From 2bc9863e66f3ca11e73610d592b9a0416030e428 Mon Sep 17 00:00:00 2001 From: Jay Wren Date: Tue, 4 Nov 2025 15:04:24 -0500 Subject: [PATCH 01/13] only wg tun, no batching --- overlay/tun_linux.go | 85 +++++++++++++++++++++++++++----------------- 1 file changed, 52 insertions(+), 33 deletions(-) diff --git a/overlay/tun_linux.go b/overlay/tun_linux.go index 44d87465..366a559c 100644 --- a/overlay/tun_linux.go +++ b/overlay/tun_linux.go @@ -9,7 +9,6 @@ import ( "net" "net/netip" "os" - "strings" "sync/atomic" "time" "unsafe" @@ -21,10 +20,12 @@ import ( "github.com/slackhq/nebula/util" "github.com/vishvananda/netlink" "golang.org/x/sys/unix" + wgtun "golang.zx2c4.com/wireguard/tun" ) type tun struct { io.ReadWriteCloser + wgDevice wgtun.Device fd int Device string vpnNetworks []netip.Prefix @@ -66,58 +67,58 @@ type ifreqQLEN struct { } func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) { - file := os.NewFile(uintptr(deviceFd), "/dev/net/tun") + wgDev, name, err := wgtun.CreateUnmonitoredTUNFromFD(deviceFd) + if err != nil { + return nil, fmt.Errorf("failed to create TUN from FD: %w", err) + } + file := wgDev.File() t, err := newTunGeneric(c, l, file, vpnNetworks) if err != nil { + _ = wgDev.Close() return nil, err } - t.Device = "tun0" + t.wgDevice = wgDev + t.Device = name return t, nil } func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueue bool) (*tun, error) { - fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0) - if err != nil { - // If /dev/net/tun doesn't exist, try to create it (will happen in docker) - if os.IsNotExist(err) { - err = os.MkdirAll("/dev/net", 0755) - if err != nil { - return nil, fmt.Errorf("/dev/net/tun doesn't exist, failed to mkdir -p /dev/net: %w", err) - } - err = unix.Mknod("/dev/net/tun", unix.S_IFCHR|0600, int(unix.Mkdev(10, 200))) - if err != nil { - return nil, fmt.Errorf("failed to create /dev/net/tun: %w", err) - } - - fd, err = unix.Open("/dev/net/tun", os.O_RDWR, 0) - if err != nil { - return nil, fmt.Errorf("created /dev/net/tun, but still failed: %w", err) - } - } else { - return nil, err + // Check if /dev/net/tun exists, create if needed (for docker containers) + if _, err := os.Stat("/dev/net/tun"); os.IsNotExist(err) { + if err := os.MkdirAll("/dev/net", 0755); err != nil { + return nil, fmt.Errorf("/dev/net/tun doesn't exist, failed to mkdir -p /dev/net: %w", err) + } + if err := unix.Mknod("/dev/net/tun", unix.S_IFCHR|0600, int(unix.Mkdev(10, 200))); err != nil { + return nil, fmt.Errorf("failed to create /dev/net/tun: %w", err) } } - var req ifReq - req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI) - if multiqueue { - req.Flags |= unix.IFF_MULTI_QUEUE - } - copy(req.Name[:], c.GetString("tun.dev", "")) - if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil { - return nil, err - } - name := strings.Trim(string(req.Name[:]), "\x00") + devName := c.GetString("tun.dev", "") + mtu := c.GetInt("tun.mtu", DefaultMTU) - file := os.NewFile(uintptr(fd), "/dev/net/tun") + // Create TUN device using wireguard library + wgDev, err := wgtun.CreateTUN(devName, mtu) + if err != nil { + return nil, fmt.Errorf("failed to create TUN device: %w", err) + } + + name, err := wgDev.Name() + if err != nil { + _ = wgDev.Close() + return nil, fmt.Errorf("failed to get TUN device name: %w", err) + } + + file := wgDev.File() t, err := newTunGeneric(c, l, file, vpnNetworks) if err != nil { + _ = wgDev.Close() return nil, err } + t.wgDevice = wgDev t.Device = name return t, nil @@ -240,6 +241,20 @@ func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways { } func (t *tun) Write(b []byte) (int, error) { + if t.wgDevice != nil { + // Use wireguard device for writing + bufs := [][]byte{b} + n, err := t.wgDevice.Write(bufs, 0) + if err != nil { + return 0, err + } + if n != 1 { + return 0, fmt.Errorf("expected to write 1 packet, wrote %d", n) + } + return len(b), nil + } + + // Fallback to direct fd write if no wireguard device var nn int maximum := len(b) @@ -674,6 +689,10 @@ func (t *tun) Close() error { close(t.routeChan) } + if t.wgDevice != nil { + _ = t.wgDevice.Close() + } + if t.ReadWriteCloser != nil { _ = t.ReadWriteCloser.Close() } From 3344a840d1c1b7587651b438ea90b90d7ed1a592 Mon Sep 17 00:00:00 2001 From: Jay Wren Date: Tue, 11 Nov 2025 10:55:39 -0500 Subject: [PATCH 02/13] just using the wg library works --- overlay/tun_linux.go | 50 ++++++++++++++++++++++++++++---------------- 1 file changed, 32 insertions(+), 18 deletions(-) diff --git a/overlay/tun_linux.go b/overlay/tun_linux.go index 366a559c..9df45812 100644 --- a/overlay/tun_linux.go +++ b/overlay/tun_linux.go @@ -99,10 +99,36 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu devName := c.GetString("tun.dev", "") mtu := c.GetInt("tun.mtu", DefaultMTU) - // Create TUN device using wireguard library - wgDev, err := wgtun.CreateTUN(devName, mtu) + // Create TUN device manually to support multiqueue + fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0) if err != nil { - return nil, fmt.Errorf("failed to create TUN device: %w", err) + return nil, err + } + + var req ifReq + req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI) + if multiqueue { + req.Flags |= unix.IFF_MULTI_QUEUE + } + copy(req.Name[:], devName) + if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil { + unix.Close(fd) + return nil, err + } + + // Set nonblocking + if err = unix.SetNonblock(fd, true); err != nil { + unix.Close(fd) + return nil, err + } + + file := os.NewFile(uintptr(fd), "/dev/net/tun") + + // Create wireguard device from file descriptor + wgDev, err := wgtun.CreateTUNFromFile(file, mtu) + if err != nil { + file.Close() + return nil, fmt.Errorf("failed to create TUN from file: %w", err) } name, err := wgDev.Name() @@ -111,7 +137,8 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu return nil, fmt.Errorf("failed to get TUN device name: %w", err) } - file := wgDev.File() + // file is now owned by wgDev, get a new reference + file = wgDev.File() t, err := newTunGeneric(c, l, file, vpnNetworks) if err != nil { _ = wgDev.Close() @@ -224,6 +251,7 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { } var req ifReq + // MUST match the flags used in newTun req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_MULTI_QUEUE) copy(req.Name[:], t.Device) if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil { @@ -241,20 +269,6 @@ func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways { } func (t *tun) Write(b []byte) (int, error) { - if t.wgDevice != nil { - // Use wireguard device for writing - bufs := [][]byte{b} - n, err := t.wgDevice.Write(bufs, 0) - if err != nil { - return 0, err - } - if n != 1 { - return 0, fmt.Errorf("expected to write 1 packet, wrote %d", n) - } - return len(b), nil - } - - // Fallback to direct fd write if no wireguard device var nn int maximum := len(b) From b68e50486580bbae00ca8f9230e423b608197972 Mon Sep 17 00:00:00 2001 From: Jay Wren Date: Tue, 11 Nov 2025 13:15:30 -0500 Subject: [PATCH 03/13] hrm --- interface.go | 53 +++++++++++++++ overlay/tun_linux.go | 157 +++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 206 insertions(+), 4 deletions(-) diff --git a/interface.go b/interface.go index 082906d9..7e8dd4c0 100644 --- a/interface.go +++ b/interface.go @@ -276,9 +276,26 @@ func (f *Interface) listenOut(i int) { }) } +// BatchReader is an interface for devices that support reading multiple packets at once +type BatchReader interface { + BatchRead(bufs [][]byte, sizes []int) (int, error) + BatchSize() int +} + func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) { runtime.LockOSThread() + // Check if reader supports batching + batchReader, supportsBatching := reader.(BatchReader) + + if supportsBatching { + f.listenInBatch(reader, batchReader, i) + } else { + f.listenInSingle(reader, i) + } +} + +func (f *Interface) listenInSingle(reader io.ReadWriteCloser, i int) { packet := make([]byte, mtu) out := make([]byte, mtu) fwPacket := &firewall.Packet{} @@ -302,6 +319,42 @@ func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) { } } +func (f *Interface) listenInBatch(reader io.ReadWriteCloser, batchReader BatchReader, i int) { + batchSize := batchReader.BatchSize() + + // Allocate buffers for batch reading + bufs := make([][]byte, batchSize) + for idx := range bufs { + bufs[idx] = make([]byte, mtu) + } + sizes := make([]int, batchSize) + + // Per-packet state (reused across batches) + out := make([]byte, mtu) + fwPacket := &firewall.Packet{} + nb := make([]byte, 12, 12) + + conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout) + + for { + n, err := batchReader.BatchRead(bufs, sizes) + if err != nil { + if errors.Is(err, os.ErrClosed) && f.closed.Load() { + return + } + + f.l.WithError(err).Error("Error while batch reading outbound packets") + // This only seems to happen when something fatal happens to the fd, so exit. + os.Exit(2) + } + + // Process each packet in the batch + for j := 0; j < n; j++ { + f.consumeInsidePacket(bufs[j][:sizes[j]], fwPacket, nb, out, i, conntrackCache.Get(f.l)) + } + } +} + func (f *Interface) RegisterConfigChangeCallbacks(c *config.C) { c.RegisterReloadCallback(f.reloadFirewall) c.RegisterReloadCallback(f.reloadSendRecvError) diff --git a/overlay/tun_linux.go b/overlay/tun_linux.go index 9df45812..882a7e7d 100644 --- a/overlay/tun_linux.go +++ b/overlay/tun_linux.go @@ -66,6 +66,78 @@ type ifreqQLEN struct { pad [8]byte } +const ( + virtioNetHdrLen = 10 // Size of virtio_net_hdr structure +) + +// tunVirtioReader wraps a file descriptor that has IFF_VNET_HDR enabled +// and strips the virtio header on reads, adds it on writes +type tunVirtioReader struct { + f *os.File + buf [virtioNetHdrLen + 65535]byte // Space for header + max packet +} + +func (r *tunVirtioReader) Read(b []byte) (int, error) { + // Read into our buffer which has space for the virtio header + n, err := r.f.Read(r.buf[:]) + if err != nil { + return 0, err + } + + // Strip the virtio header (first 10 bytes) + if n < virtioNetHdrLen { + return 0, fmt.Errorf("packet too short: %d bytes", n) + } + + // Copy payload (after header) to destination + copy(b, r.buf[virtioNetHdrLen:n]) + return n - virtioNetHdrLen, nil +} + +func (r *tunVirtioReader) Write(b []byte) (int, error) { + // Zero out the virtio header (no offload from userspace write) + for i := 0; i < virtioNetHdrLen; i++ { + r.buf[i] = 0 + } + + // Copy packet data after header + copy(r.buf[virtioNetHdrLen:], b) + + // Write with header prepended + n, err := r.f.Write(r.buf[:virtioNetHdrLen+len(b)]) + if err != nil { + return 0, err + } + + // Return payload size (excluding header) + return n - virtioNetHdrLen, nil +} + +func (r *tunVirtioReader) Close() error { + return r.f.Close() +} + +// BatchRead reads multiple packets at once for performance +// This is not used for multiqueue readers as they use direct file I/O +// Returns number of packets read +func (r *tunVirtioReader) BatchRead(bufs [][]byte, sizes []int) (int, error) { + // For multiqueue file descriptors, we don't have the wireguard Device interface + // Fall back to single packet reads + // TODO: Could implement proper batching with unix.Recvmmsg + n, err := r.Read(bufs[0]) + if err != nil { + return 0, err + } + sizes[0] = n + return 1, nil +} + +// BatchSize returns the batch size for multiqueue readers +func (r *tunVirtioReader) BatchSize() int { + // Multiqueue readers use single packet mode for now + return 1 +} + func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) { wgDev, name, err := wgtun.CreateUnmonitoredTUNFromFD(deviceFd) if err != nil { @@ -106,7 +178,7 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu } var req ifReq - req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI) + req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_VNET_HDR) if multiqueue { req.Flags |= unix.IFF_MULTI_QUEUE } @@ -122,6 +194,18 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu return nil, err } + // Enable TCP and UDP offload (TSO/GRO) for performance + // This allows the kernel to handle segmentation/coalescing + const ( + tunTCPOffloads = unix.TUN_F_CSUM | unix.TUN_F_TSO4 | unix.TUN_F_TSO6 + tunUDPOffloads = unix.TUN_F_USO4 | unix.TUN_F_USO6 + ) + offloads := tunTCPOffloads | tunUDPOffloads + if err = unix.IoctlSetInt(fd, unix.TUNSETOFFLOAD, offloads); err != nil { + // Log warning but don't fail - offload is optional + l.WithError(err).Warn("Failed to enable TUN offload (TSO/GRO), performance may be reduced") + } + file := os.NewFile(uintptr(fd), "/dev/net/tun") // Create wireguard device from file descriptor @@ -251,16 +335,18 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { } var req ifReq - // MUST match the flags used in newTun - req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_MULTI_QUEUE) + // MUST match the flags used in newTun - includes IFF_VNET_HDR + req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_VNET_HDR | unix.IFF_MULTI_QUEUE) copy(req.Name[:], t.Device) if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil { + unix.Close(fd) return nil, err } file := os.NewFile(uintptr(fd), "/dev/net/tun") - return file, nil + // Wrap in virtio header handler + return &tunVirtioReader{f: file}, nil } func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways { @@ -268,7 +354,70 @@ func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways { return r } +func (t *tun) Read(b []byte) (int, error) { + if t.wgDevice != nil { + // Use wireguard device which handles virtio headers internally + bufs := [][]byte{b} + sizes := make([]int, 1) + n, err := t.wgDevice.Read(bufs, sizes, 0) + if err != nil { + return 0, err + } + if n == 0 { + return 0, io.EOF + } + return sizes[0], nil + } + + // Fallback: direct read from file (shouldn't happen in normal operation) + return t.ReadWriteCloser.Read(b) +} + +// BatchRead reads multiple packets at once for improved performance +// bufs: slice of buffers to read into +// sizes: slice that will be filled with packet sizes +// Returns number of packets read +func (t *tun) BatchRead(bufs [][]byte, sizes []int) (int, error) { + if t.wgDevice != nil { + return t.wgDevice.Read(bufs, sizes, 0) + } + + // Fallback: single packet read + n, err := t.ReadWriteCloser.Read(bufs[0]) + if err != nil { + return 0, err + } + sizes[0] = n + return 1, nil +} + +// BatchSize returns the optimal number of packets to read/write in a batch +func (t *tun) BatchSize() int { + if t.wgDevice != nil { + return t.wgDevice.BatchSize() + } + return 1 +} + func (t *tun) Write(b []byte) (int, error) { + if t.wgDevice != nil { + // Use wireguard device which handles virtio headers internally + // Allocate buffer with space for virtio header + buf := make([]byte, virtioNetHdrLen+len(b)) + copy(buf[virtioNetHdrLen:], b) + + bufs := [][]byte{buf} + n, err := t.wgDevice.Write(bufs, virtioNetHdrLen) + if err != nil { + return 0, err + } + if n == 0 { + return 0, io.ErrShortWrite + } + return len(b), nil + } + + // Fallback: direct write (shouldn't happen in normal operation) var nn int maximum := len(b) From ef0a022375d4754bec26dc8c81b3f1f2099057ae Mon Sep 17 00:00:00 2001 From: Jay Wren Date: Tue, 11 Nov 2025 14:22:40 -0500 Subject: [PATCH 04/13] more nonblocking --- overlay/tun_linux.go | 108 ++++++++++++++++++++++--------------------- 1 file changed, 55 insertions(+), 53 deletions(-) diff --git a/overlay/tun_linux.go b/overlay/tun_linux.go index 882a7e7d..0d7ad2a8 100644 --- a/overlay/tun_linux.go +++ b/overlay/tun_linux.go @@ -70,72 +70,55 @@ const ( virtioNetHdrLen = 10 // Size of virtio_net_hdr structure ) -// tunVirtioReader wraps a file descriptor that has IFF_VNET_HDR enabled -// and strips the virtio header on reads, adds it on writes -type tunVirtioReader struct { - f *os.File - buf [virtioNetHdrLen + 65535]byte // Space for header + max packet +// wgDeviceWrapper wraps a wireguard Device to implement io.ReadWriteCloser +// This allows multiqueue readers to use the same wireguard Device batching as the main device +type wgDeviceWrapper struct { + dev wgtun.Device + buf []byte // Reusable buffer for single packet reads } -func (r *tunVirtioReader) Read(b []byte) (int, error) { - // Read into our buffer which has space for the virtio header - n, err := r.f.Read(r.buf[:]) +func (w *wgDeviceWrapper) Read(b []byte) (int, error) { + // Use wireguard Device's batch API for single packet + bufs := [][]byte{b} + sizes := make([]int, 1) + n, err := w.dev.Read(bufs, sizes, 0) if err != nil { return 0, err } - - // Strip the virtio header (first 10 bytes) - if n < virtioNetHdrLen { - return 0, fmt.Errorf("packet too short: %d bytes", n) + if n == 0 { + return 0, io.EOF } - - // Copy payload (after header) to destination - copy(b, r.buf[virtioNetHdrLen:n]) - return n - virtioNetHdrLen, nil + return sizes[0], nil } -func (r *tunVirtioReader) Write(b []byte) (int, error) { - // Zero out the virtio header (no offload from userspace write) - for i := 0; i < virtioNetHdrLen; i++ { - r.buf[i] = 0 - } +func (w *wgDeviceWrapper) Write(b []byte) (int, error) { + // Allocate buffer with space for virtio header + buf := make([]byte, virtioNetHdrLen+len(b)) + copy(buf[virtioNetHdrLen:], b) - // Copy packet data after header - copy(r.buf[virtioNetHdrLen:], b) - - // Write with header prepended - n, err := r.f.Write(r.buf[:virtioNetHdrLen+len(b)]) + bufs := [][]byte{buf} + n, err := w.dev.Write(bufs, virtioNetHdrLen) if err != nil { return 0, err } - - // Return payload size (excluding header) - return n - virtioNetHdrLen, nil -} - -func (r *tunVirtioReader) Close() error { - return r.f.Close() -} - -// BatchRead reads multiple packets at once for performance -// This is not used for multiqueue readers as they use direct file I/O -// Returns number of packets read -func (r *tunVirtioReader) BatchRead(bufs [][]byte, sizes []int) (int, error) { - // For multiqueue file descriptors, we don't have the wireguard Device interface - // Fall back to single packet reads - // TODO: Could implement proper batching with unix.Recvmmsg - n, err := r.Read(bufs[0]) - if err != nil { - return 0, err + if n == 0 { + return 0, io.ErrShortWrite } - sizes[0] = n - return 1, nil + return len(b), nil } -// BatchSize returns the batch size for multiqueue readers -func (r *tunVirtioReader) BatchSize() int { - // Multiqueue readers use single packet mode for now - return 1 +func (w *wgDeviceWrapper) Close() error { + return w.dev.Close() +} + +// BatchRead implements batching for multiqueue readers +func (w *wgDeviceWrapper) BatchRead(bufs [][]byte, sizes []int) (int, error) { + return w.dev.Read(bufs, sizes, 0) +} + +// BatchSize returns the optimal batch size +func (w *wgDeviceWrapper) BatchSize() int { + return w.dev.BatchSize() } func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) { @@ -343,10 +326,29 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { return nil, err } + // Set nonblocking mode - CRITICAL for proper netpoller integration + if err = unix.SetNonblock(fd, true); err != nil { + unix.Close(fd) + return nil, err + } + + // Get MTU from main device + mtu := t.MaxMTU + if mtu == 0 { + mtu = DefaultMTU + } + file := os.NewFile(uintptr(fd), "/dev/net/tun") - // Wrap in virtio header handler - return &tunVirtioReader{f: file}, nil + // Create wireguard Device from the file descriptor (just like the main device) + wgDev, err := wgtun.CreateTUNFromFile(file, mtu) + if err != nil { + file.Close() + return nil, fmt.Errorf("failed to create multiqueue TUN device: %w", err) + } + + // Return a wrapper that uses the wireguard Device for all I/O + return &wgDeviceWrapper{dev: wgDev}, nil } func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways { From 0f9b33aa36cf90d555d781765427ee2ebf265ace Mon Sep 17 00:00:00 2001 From: Jay Wren Date: Tue, 11 Nov 2025 14:51:53 -0500 Subject: [PATCH 05/13] reduce copying --- interface.go | 10 ++++++++-- overlay/tun_linux.go | 46 ++++++++++++++++++++++++++++++++++++++------ 2 files changed, 48 insertions(+), 8 deletions(-) diff --git a/interface.go b/interface.go index 7e8dd4c0..df4c3d37 100644 --- a/interface.go +++ b/interface.go @@ -297,7 +297,10 @@ func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) { func (f *Interface) listenInSingle(reader io.ReadWriteCloser, i int) { packet := make([]byte, mtu) - out := make([]byte, mtu) + // Allocate out buffer with virtio header headroom (10 bytes) to avoid copies on write + const virtioNetHdrLen = 10 + outBuf := make([]byte, virtioNetHdrLen+mtu) + out := outBuf[virtioNetHdrLen:] // Use slice starting after headroom fwPacket := &firewall.Packet{} nb := make([]byte, 12, 12) @@ -321,6 +324,7 @@ func (f *Interface) listenInSingle(reader io.ReadWriteCloser, i int) { func (f *Interface) listenInBatch(reader io.ReadWriteCloser, batchReader BatchReader, i int) { batchSize := batchReader.BatchSize() + const virtioNetHdrLen = 10 // Allocate buffers for batch reading bufs := make([][]byte, batchSize) @@ -330,7 +334,9 @@ func (f *Interface) listenInBatch(reader io.ReadWriteCloser, batchReader BatchRe sizes := make([]int, batchSize) // Per-packet state (reused across batches) - out := make([]byte, mtu) + // Allocate out buffer with virtio header headroom to avoid copies on write + outBuf := make([]byte, virtioNetHdrLen+mtu) + out := outBuf[virtioNetHdrLen:] fwPacket := &firewall.Packet{} nb := make([]byte, 12, 12) diff --git a/overlay/tun_linux.go b/overlay/tun_linux.go index 0d7ad2a8..e2e926a8 100644 --- a/overlay/tun_linux.go +++ b/overlay/tun_linux.go @@ -92,9 +92,24 @@ func (w *wgDeviceWrapper) Read(b []byte) (int, error) { } func (w *wgDeviceWrapper) Write(b []byte) (int, error) { - // Allocate buffer with space for virtio header - buf := make([]byte, virtioNetHdrLen+len(b)) - copy(buf[virtioNetHdrLen:], b) + // Check if buffer has the expected headroom pattern to avoid copy + var buf []byte + + if cap(b) >= len(b)+virtioNetHdrLen { + buf = b[:cap(b)] + if len(buf) == len(b)+virtioNetHdrLen { + // Perfect! Buffer has headroom, no copy needed + buf = buf[:len(b)+virtioNetHdrLen] + } else { + // Unexpected capacity, safer to copy + buf = make([]byte, virtioNetHdrLen+len(b)) + copy(buf[virtioNetHdrLen:], b) + } + } else { + // No headroom, need to allocate and copy + buf = make([]byte, virtioNetHdrLen+len(b)) + copy(buf[virtioNetHdrLen:], b) + } bufs := [][]byte{buf} n, err := w.dev.Write(bufs, virtioNetHdrLen) @@ -404,9 +419,28 @@ func (t *tun) BatchSize() int { func (t *tun) Write(b []byte) (int, error) { if t.wgDevice != nil { // Use wireguard device which handles virtio headers internally - // Allocate buffer with space for virtio header - buf := make([]byte, virtioNetHdrLen+len(b)) - copy(buf[virtioNetHdrLen:], b) + // Check if buffer has the expected headroom pattern: + // cap(b) should be len(b) + virtioNetHdrLen, indicating pre-allocated headroom + var buf []byte + + if cap(b) >= len(b)+virtioNetHdrLen { + // Buffer likely has headroom - use unsafe to access it + // Create a slice that includes the headroom by re-slicing from capacity + buf = b[:cap(b)] + // Check if we have exactly the right amount of extra capacity + if len(buf) == len(b)+virtioNetHdrLen { + // Perfect! This buffer was allocated with headroom, no copy needed + buf = buf[:len(b)+virtioNetHdrLen] + } else { + // Unexpected capacity, safer to copy + buf = make([]byte, virtioNetHdrLen+len(b)) + copy(buf[virtioNetHdrLen:], b) + } + } else { + // No headroom, need to allocate and copy + buf = make([]byte, virtioNetHdrLen+len(b)) + copy(buf[virtioNetHdrLen:], b) + } bufs := [][]byte{buf} n, err := t.wgDevice.Write(bufs, virtioNetHdrLen) From b2bc6a09ca9cb30f43b0166b08bb6ec4bcb94570 Mon Sep 17 00:00:00 2001 From: Jay Wren Date: Tue, 11 Nov 2025 15:06:45 -0500 Subject: [PATCH 06/13] write in batches --- inside.go | 139 ++++++++++++++++++++++++++++++++++++++++++++ interface.go | 19 +++--- udp/conn.go | 4 ++ udp/udp_darwin.go | 11 ++++ udp/udp_linux.go | 85 +++++++++++++++++++++++++++ udp/udp_linux_64.go | 38 ++++++++++++ 6 files changed, 286 insertions(+), 10 deletions(-) diff --git a/inside.go b/inside.go index d24ed31b..d675f78a 100644 --- a/inside.go +++ b/inside.go @@ -11,6 +11,145 @@ import ( "github.com/slackhq/nebula/routing" ) +// consumeInsidePackets processes multiple packets in a batch for improved performance +// packets: slice of packet buffers to process +// sizes: slice of packet sizes +// count: number of packets to process +// outs: slice of output buffers (one per packet) with virtio headroom +// q: queue index +// localCache: firewall conntrack cache +func (f *Interface) consumeInsidePackets(packets [][]byte, sizes []int, count int, outs [][]byte, q int, localCache firewall.ConntrackCache) { + // Reusable per-packet state + fwPacket := &firewall.Packet{} + nb := make([]byte, 12, 12) + + // Accumulate encrypted packets for batch sending + batchPackets := make([][]byte, 0, count) + batchAddrs := make([]netip.AddrPort, 0, count) + + // Process each packet in the batch + for i := 0; i < count; i++ { + packet := packets[i][:sizes[i]] + out := outs[i] + + // Inline the consumeInsidePacket logic for better performance + err := newPacket(packet, false, fwPacket) + if err != nil { + if f.l.Level >= logrus.DebugLevel { + f.l.WithField("packet", packet).Debugf("Error while validating outbound packet: %s", err) + } + continue + } + + // Ignore local broadcast packets + if f.dropLocalBroadcast { + if f.myBroadcastAddrsTable.Contains(fwPacket.RemoteAddr) { + continue + } + } + + if f.myVpnAddrsTable.Contains(fwPacket.RemoteAddr) { + // Immediately forward packets from self to self. + if immediatelyForwardToSelf { + _, err := f.readers[q].Write(packet) + if err != nil { + f.l.WithError(err).Error("Failed to forward to tun") + } + } + continue + } + + // Ignore multicast packets + if f.dropMulticast && fwPacket.RemoteAddr.IsMulticast() { + continue + } + + hostinfo, ready := f.getOrHandshakeConsiderRouting(fwPacket, func(hh *HandshakeHostInfo) { + hh.cachePacket(f.l, header.Message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics) + }) + + if hostinfo == nil { + f.rejectInside(packet, out, q) + if f.l.Level >= logrus.DebugLevel { + f.l.WithField("vpnAddr", fwPacket.RemoteAddr). + WithField("fwPacket", fwPacket). + Debugln("dropping outbound packet, vpnAddr not in our vpn networks or in unsafe networks") + } + continue + } + + if !ready { + continue + } + + dropReason := f.firewall.Drop(*fwPacket, false, hostinfo, f.pki.GetCAPool(), localCache) + if dropReason != nil { + f.rejectInside(packet, out, q) + if f.l.Level >= logrus.DebugLevel { + hostinfo.logger(f.l). + WithField("fwPacket", fwPacket). + WithField("reason", dropReason). + Debugln("dropping outbound packet") + } + continue + } + + // Encrypt and prepare packet for batch sending + ci := hostinfo.ConnectionState + if ci.eKey == nil { + continue + } + + // Check if this needs relay - if so, send immediately and skip batching + useRelay := !hostinfo.remote.IsValid() + if useRelay { + // Handle relay sends individually (less common path) + f.sendNoMetrics(header.Message, 0, ci, hostinfo, netip.AddrPort{}, packet, nb, out, q) + continue + } + + // Encrypt the packet for batch sending + if noiseutil.EncryptLockNeeded { + ci.writeLock.Lock() + } + c := ci.messageCounter.Add(1) + out = header.Encode(out, header.Version, header.Message, 0, hostinfo.remoteIndexId, c) + f.connectionManager.Out(hostinfo) + + // Query lighthouse if needed + if hostinfo.lastRebindCount != f.rebindCount { + f.lightHouse.QueryServer(hostinfo.vpnAddrs[0]) + hostinfo.lastRebindCount = f.rebindCount + if f.l.Level >= logrus.DebugLevel { + f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).Debug("Lighthouse update triggered for punch due to rebind counter") + } + } + + out, err = ci.eKey.EncryptDanger(out, out, packet, c, nb) + if noiseutil.EncryptLockNeeded { + ci.writeLock.Unlock() + } + if err != nil { + hostinfo.logger(f.l).WithError(err). + WithField("counter", c). + Error("Failed to encrypt outgoing packet") + continue + } + + // Add to batch + batchPackets = append(batchPackets, out) + batchAddrs = append(batchAddrs, hostinfo.remote) + } + + // Send all accumulated packets in one batch + if len(batchPackets) > 0 { + n, err := f.writers[q].WriteMulti(batchPackets, batchAddrs) + if err != nil { + f.l.WithError(err).WithField("sent", n).WithField("total", len(batchPackets)).Error("Failed to send batch") + } + } +} + func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb, out []byte, q int, localCache firewall.ConntrackCache) { err := newPacket(packet, false, fwPacket) if err != nil { diff --git a/interface.go b/interface.go index df4c3d37..8180cbf2 100644 --- a/interface.go +++ b/interface.go @@ -333,12 +333,13 @@ func (f *Interface) listenInBatch(reader io.ReadWriteCloser, batchReader BatchRe } sizes := make([]int, batchSize) - // Per-packet state (reused across batches) - // Allocate out buffer with virtio header headroom to avoid copies on write - outBuf := make([]byte, virtioNetHdrLen+mtu) - out := outBuf[virtioNetHdrLen:] - fwPacket := &firewall.Packet{} - nb := make([]byte, 12, 12) + // Allocate output buffers for batch processing (one per packet) + // Each has virtio header headroom to avoid copies on write + outs := make([][]byte, batchSize) + for idx := range outs { + outBuf := make([]byte, virtioNetHdrLen+mtu) + outs[idx] = outBuf[virtioNetHdrLen:] // Slice starting after headroom + } conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout) @@ -354,10 +355,8 @@ func (f *Interface) listenInBatch(reader io.ReadWriteCloser, batchReader BatchRe os.Exit(2) } - // Process each packet in the batch - for j := 0; j < n; j++ { - f.consumeInsidePacket(bufs[j][:sizes[j]], fwPacket, nb, out, i, conntrackCache.Get(f.l)) - } + // Process all packets in the batch at once + f.consumeInsidePackets(bufs, sizes, n, outs, i, conntrackCache.Get(f.l)) } } diff --git a/udp/conn.go b/udp/conn.go index 895b0df3..8c821d33 100644 --- a/udp/conn.go +++ b/udp/conn.go @@ -18,6 +18,7 @@ type Conn interface { LocalAddr() (netip.AddrPort, error) ListenOut(r EncReader) WriteTo(b []byte, addr netip.AddrPort) error + WriteMulti(packets [][]byte, addrs []netip.AddrPort) (int, error) ReloadConfig(c *config.C) Close() error } @@ -36,6 +37,9 @@ func (NoopConn) ListenOut(_ EncReader) { func (NoopConn) WriteTo(_ []byte, _ netip.AddrPort) error { return nil } +func (NoopConn) WriteMulti(_ [][]byte, _ []netip.AddrPort) (int, error) { + return 0, nil +} func (NoopConn) ReloadConfig(_ *config.C) { return } diff --git a/udp/udp_darwin.go b/udp/udp_darwin.go index c0c6233c..b409d774 100644 --- a/udp/udp_darwin.go +++ b/udp/udp_darwin.go @@ -140,6 +140,17 @@ func (u *StdConn) WriteTo(b []byte, ap netip.AddrPort) error { } } +// WriteMulti sends multiple packets - fallback implementation without sendmmsg +func (u *StdConn) WriteMulti(packets [][]byte, addrs []netip.AddrPort) (int, error) { + for i := range packets { + err := u.WriteTo(packets[i], addrs[i]) + if err != nil { + return i, err + } + } + return len(packets), nil +} + func (u *StdConn) LocalAddr() (netip.AddrPort, error) { a := u.UDPConn.LocalAddr() diff --git a/udp/udp_linux.go b/udp/udp_linux.go index ec0bf64b..aec52154 100644 --- a/udp/udp_linux.go +++ b/udp/udp_linux.go @@ -194,6 +194,19 @@ func (u *StdConn) WriteTo(b []byte, ip netip.AddrPort) error { return u.writeTo6(b, ip) } +func (u *StdConn) WriteMulti(packets [][]byte, addrs []netip.AddrPort) (int, error) { + if len(packets) != len(addrs) { + return 0, fmt.Errorf("packets and addrs length mismatch") + } + if len(packets) == 0 { + return 0, nil + } + if u.isV4 { + return u.writeMulti4(packets, addrs) + } + return u.writeMulti6(packets, addrs) +} + func (u *StdConn) writeTo6(b []byte, ip netip.AddrPort) error { var rsa unix.RawSockaddrInet6 rsa.Family = unix.AF_INET6 @@ -248,6 +261,78 @@ func (u *StdConn) writeTo4(b []byte, ip netip.AddrPort) error { } } +func (u *StdConn) writeMulti4(packets [][]byte, addrs []netip.AddrPort) (int, error) { + msgs, iovecs, names := u.PrepareWriteMessages4(len(packets)) + + for i := range packets { + if !addrs[i].Addr().Is4() { + return i, ErrInvalidIPv6RemoteForSocket + } + + // Setup the packet buffer + iovecs[i].Base = &packets[i][0] + iovecs[i].Len = uint64(len(packets[i])) + + // Setup the destination address + rsa := (*unix.RawSockaddrInet4)(unsafe.Pointer(&names[i][0])) + rsa.Family = unix.AF_INET + rsa.Addr = addrs[i].Addr().As4() + binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&rsa.Port))[:], addrs[i].Port()) + } + + for { + n, _, err := unix.Syscall6( + unix.SYS_SENDMMSG, + uintptr(u.sysFd), + uintptr(unsafe.Pointer(&msgs[0])), + uintptr(len(msgs)), + 0, + 0, + 0, + ) + + if err != 0 { + return int(n), &net.OpError{Op: "sendmmsg", Err: err} + } + + return int(n), nil + } +} + +func (u *StdConn) writeMulti6(packets [][]byte, addrs []netip.AddrPort) (int, error) { + msgs, iovecs, names := u.PrepareWriteMessages6(len(packets)) + + for i := range packets { + // Setup the packet buffer + iovecs[i].Base = &packets[i][0] + iovecs[i].Len = uint64(len(packets[i])) + + // Setup the destination address + rsa := (*unix.RawSockaddrInet6)(unsafe.Pointer(&names[i][0])) + rsa.Family = unix.AF_INET6 + rsa.Addr = addrs[i].Addr().As16() + binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&rsa.Port))[:], addrs[i].Port()) + } + + for { + n, _, err := unix.Syscall6( + unix.SYS_SENDMMSG, + uintptr(u.sysFd), + uintptr(unsafe.Pointer(&msgs[0])), + uintptr(len(msgs)), + 0, + 0, + 0, + ) + + if err != 0 { + return int(n), &net.OpError{Op: "sendmmsg", Err: err} + } + + return int(n), nil + } +} + func (u *StdConn) ReloadConfig(c *config.C) { b := c.GetInt("listen.read_buffer", 0) if b > 0 { diff --git a/udp/udp_linux_64.go b/udp/udp_linux_64.go index 48c5a978..36ce8a4a 100644 --- a/udp/udp_linux_64.go +++ b/udp/udp_linux_64.go @@ -55,3 +55,41 @@ func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) { return msgs, buffers, names } + +func (u *StdConn) PrepareWriteMessages4(n int) ([]rawMessage, []iovec, [][]byte) { + msgs := make([]rawMessage, n) + iovecs := make([]iovec, n) + names := make([][]byte, n) + + for i := range msgs { + names[i] = make([]byte, unix.SizeofSockaddrInet4) + + // Point to the iovec in the slice + msgs[i].Hdr.Iov = &iovecs[i] + msgs[i].Hdr.Iovlen = 1 + + msgs[i].Hdr.Name = &names[i][0] + msgs[i].Hdr.Namelen = unix.SizeofSockaddrInet4 + } + + return msgs, iovecs, names +} + +func (u *StdConn) PrepareWriteMessages6(n int) ([]rawMessage, []iovec, [][]byte) { + msgs := make([]rawMessage, n) + iovecs := make([]iovec, n) + names := make([][]byte, n) + + for i := range msgs { + names[i] = make([]byte, unix.SizeofSockaddrInet6) + + // Point to the iovec in the slice + msgs[i].Hdr.Iov = &iovecs[i] + msgs[i].Hdr.Iovlen = 1 + + msgs[i].Hdr.Name = &names[i][0] + msgs[i].Hdr.Namelen = unix.SizeofSockaddrInet6 + } + + return msgs, iovecs, names +} From 226787ea1f84870b70aef7fb6e15c1a8683f0e83 Mon Sep 17 00:00:00 2001 From: Jay Wren Date: Tue, 11 Nov 2025 15:20:50 -0500 Subject: [PATCH 07/13] prealloc them buffers --- inside.go | 20 ++++--- interface.go | 6 +- udp/udp_linux.go | 141 +++++++++++++++++++++++++++++++++----------- udp/udp_linux_64.go | 38 ------------ 4 files changed, 121 insertions(+), 84 deletions(-) diff --git a/inside.go b/inside.go index d675f78a..12fad380 100644 --- a/inside.go +++ b/inside.go @@ -18,14 +18,16 @@ import ( // outs: slice of output buffers (one per packet) with virtio headroom // q: queue index // localCache: firewall conntrack cache -func (f *Interface) consumeInsidePackets(packets [][]byte, sizes []int, count int, outs [][]byte, q int, localCache firewall.ConntrackCache) { +// batchPackets: pre-allocated slice for accumulating encrypted packets +// batchAddrs: pre-allocated slice for accumulating destination addresses +func (f *Interface) consumeInsidePackets(packets [][]byte, sizes []int, count int, outs [][]byte, q int, localCache firewall.ConntrackCache, batchPackets *[][]byte, batchAddrs *[]netip.AddrPort) { // Reusable per-packet state fwPacket := &firewall.Packet{} nb := make([]byte, 12, 12) - // Accumulate encrypted packets for batch sending - batchPackets := make([][]byte, 0, count) - batchAddrs := make([]netip.AddrPort, 0, count) + // Reset batch accumulation slices (reuse capacity) + *batchPackets = (*batchPackets)[:0] + *batchAddrs = (*batchAddrs)[:0] // Process each packet in the batch for i := 0; i < count; i++ { @@ -137,15 +139,15 @@ func (f *Interface) consumeInsidePackets(packets [][]byte, sizes []int, count in } // Add to batch - batchPackets = append(batchPackets, out) - batchAddrs = append(batchAddrs, hostinfo.remote) + *batchPackets = append(*batchPackets, out) + *batchAddrs = append(*batchAddrs, hostinfo.remote) } // Send all accumulated packets in one batch - if len(batchPackets) > 0 { - n, err := f.writers[q].WriteMulti(batchPackets, batchAddrs) + if len(*batchPackets) > 0 { + n, err := f.writers[q].WriteMulti(*batchPackets, *batchAddrs) if err != nil { - f.l.WithError(err).WithField("sent", n).WithField("total", len(batchPackets)).Error("Failed to send batch") + f.l.WithError(err).WithField("sent", n).WithField("total", len(*batchPackets)).Error("Failed to send batch") } } } diff --git a/interface.go b/interface.go index 8180cbf2..0486378a 100644 --- a/interface.go +++ b/interface.go @@ -341,6 +341,10 @@ func (f *Interface) listenInBatch(reader io.ReadWriteCloser, batchReader BatchRe outs[idx] = outBuf[virtioNetHdrLen:] // Slice starting after headroom } + // Pre-allocate batch accumulation buffers for sending + batchPackets := make([][]byte, 0, batchSize) + batchAddrs := make([]netip.AddrPort, 0, batchSize) + conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout) for { @@ -356,7 +360,7 @@ func (f *Interface) listenInBatch(reader io.ReadWriteCloser, batchReader BatchRe } // Process all packets in the batch at once - f.consumeInsidePackets(bufs, sizes, n, outs, i, conntrackCache.Get(f.l)) + f.consumeInsidePackets(bufs, sizes, n, outs, i, conntrackCache.Get(f.l), &batchPackets, &batchAddrs) } } diff --git a/udp/udp_linux.go b/udp/udp_linux.go index aec52154..c83f7469 100644 --- a/udp/udp_linux.go +++ b/udp/udp_linux.go @@ -22,6 +22,11 @@ type StdConn struct { isV4 bool l *logrus.Logger batch int + + // Pre-allocated buffers for batch writes (sized for IPv6, works for both) + writeMsgs []rawMessage + writeIovecs []iovec + writeNames [][]byte } func maybeIPV4(ip net.IP) (net.IP, bool) { @@ -69,7 +74,26 @@ func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch in return nil, fmt.Errorf("unable to bind to socket: %s", err) } - return &StdConn{sysFd: fd, isV4: ip.Is4(), l: l, batch: batch}, err + c := &StdConn{sysFd: fd, isV4: ip.Is4(), l: l, batch: batch} + + // Pre-allocate write message structures for batching (sized for IPv6, works for both) + c.writeMsgs = make([]rawMessage, batch) + c.writeIovecs = make([]iovec, batch) + c.writeNames = make([][]byte, batch) + + for i := range c.writeMsgs { + // Allocate for IPv6 size (larger than IPv4, works for both) + c.writeNames[i] = make([]byte, unix.SizeofSockaddrInet6) + + // Point to the iovec in the slice + c.writeMsgs[i].Hdr.Iov = &c.writeIovecs[i] + c.writeMsgs[i].Hdr.Iovlen = 1 + + c.writeMsgs[i].Hdr.Name = &c.writeNames[i][0] + // Namelen will be set appropriately in writeMulti4/writeMulti6 + } + + return c, err } func (u *StdConn) Rebind() error { @@ -262,75 +286,120 @@ func (u *StdConn) writeTo4(b []byte, ip netip.AddrPort) error { } func (u *StdConn) writeMulti4(packets [][]byte, addrs []netip.AddrPort) (int, error) { - msgs, iovecs, names := u.PrepareWriteMessages4(len(packets)) - - for i := range packets { - if !addrs[i].Addr().Is4() { - return i, ErrInvalidIPv6RemoteForSocket + sent := 0 + for sent < len(packets) { + // Determine batch size based on remaining packets and buffer capacity + batchSize := len(packets) - sent + if batchSize > len(u.writeMsgs) { + batchSize = len(u.writeMsgs) } - // Setup the packet buffer - iovecs[i].Base = &packets[i][0] - iovecs[i].Len = uint64(len(packets[i])) + // Use pre-allocated buffers + msgs := u.writeMsgs[:batchSize] + iovecs := u.writeIovecs[:batchSize] + names := u.writeNames[:batchSize] - // Setup the destination address - rsa := (*unix.RawSockaddrInet4)(unsafe.Pointer(&names[i][0])) - rsa.Family = unix.AF_INET - rsa.Addr = addrs[i].Addr().As4() - binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&rsa.Port))[:], addrs[i].Port()) - } + // Setup message structures for this batch + for i := 0; i < batchSize; i++ { + pktIdx := sent + i + if !addrs[pktIdx].Addr().Is4() { + return sent + i, ErrInvalidIPv6RemoteForSocket + } - for { - n, _, err := unix.Syscall6( + // Setup the packet buffer + iovecs[i].Base = &packets[pktIdx][0] + iovecs[i].Len = uint64(len(packets[pktIdx])) + + // Setup the destination address + rsa := (*unix.RawSockaddrInet4)(unsafe.Pointer(&names[i][0])) + rsa.Family = unix.AF_INET + rsa.Addr = addrs[pktIdx].Addr().As4() + binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&rsa.Port))[:], addrs[pktIdx].Port()) + + // Set the appropriate address length for IPv4 + msgs[i].Hdr.Namelen = unix.SizeofSockaddrInet4 + } + + // Send this batch + nsent, _, err := unix.Syscall6( unix.SYS_SENDMMSG, uintptr(u.sysFd), uintptr(unsafe.Pointer(&msgs[0])), - uintptr(len(msgs)), + uintptr(batchSize), 0, 0, 0, ) if err != 0 { - return int(n), &net.OpError{Op: "sendmmsg", Err: err} + return sent + int(nsent), &net.OpError{Op: "sendmmsg", Err: err} } - return int(n), nil + sent += int(nsent) + if int(nsent) < batchSize { + // Couldn't send all packets in batch, return what we sent + return sent, nil + } } + + return sent, nil } func (u *StdConn) writeMulti6(packets [][]byte, addrs []netip.AddrPort) (int, error) { - msgs, iovecs, names := u.PrepareWriteMessages6(len(packets)) + sent := 0 + for sent < len(packets) { + // Determine batch size based on remaining packets and buffer capacity + batchSize := len(packets) - sent + if batchSize > len(u.writeMsgs) { + batchSize = len(u.writeMsgs) + } - for i := range packets { - // Setup the packet buffer - iovecs[i].Base = &packets[i][0] - iovecs[i].Len = uint64(len(packets[i])) + // Use pre-allocated buffers + msgs := u.writeMsgs[:batchSize] + iovecs := u.writeIovecs[:batchSize] + names := u.writeNames[:batchSize] - // Setup the destination address - rsa := (*unix.RawSockaddrInet6)(unsafe.Pointer(&names[i][0])) - rsa.Family = unix.AF_INET6 - rsa.Addr = addrs[i].Addr().As16() - binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&rsa.Port))[:], addrs[i].Port()) - } + // Setup message structures for this batch + for i := 0; i < batchSize; i++ { + pktIdx := sent + i - for { - n, _, err := unix.Syscall6( + // Setup the packet buffer + iovecs[i].Base = &packets[pktIdx][0] + iovecs[i].Len = uint64(len(packets[pktIdx])) + + // Setup the destination address + rsa := (*unix.RawSockaddrInet6)(unsafe.Pointer(&names[i][0])) + rsa.Family = unix.AF_INET6 + rsa.Addr = addrs[pktIdx].Addr().As16() + binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&rsa.Port))[:], addrs[pktIdx].Port()) + + // Set the appropriate address length for IPv6 + msgs[i].Hdr.Namelen = unix.SizeofSockaddrInet6 + } + + // Send this batch + nsent, _, err := unix.Syscall6( unix.SYS_SENDMMSG, uintptr(u.sysFd), uintptr(unsafe.Pointer(&msgs[0])), - uintptr(len(msgs)), + uintptr(batchSize), 0, 0, 0, ) if err != 0 { - return int(n), &net.OpError{Op: "sendmmsg", Err: err} + return sent + int(nsent), &net.OpError{Op: "sendmmsg", Err: err} } - return int(n), nil + sent += int(nsent) + if int(nsent) < batchSize { + // Couldn't send all packets in batch, return what we sent + return sent, nil + } } + + return sent, nil } func (u *StdConn) ReloadConfig(c *config.C) { diff --git a/udp/udp_linux_64.go b/udp/udp_linux_64.go index 36ce8a4a..48c5a978 100644 --- a/udp/udp_linux_64.go +++ b/udp/udp_linux_64.go @@ -55,41 +55,3 @@ func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) { return msgs, buffers, names } - -func (u *StdConn) PrepareWriteMessages4(n int) ([]rawMessage, []iovec, [][]byte) { - msgs := make([]rawMessage, n) - iovecs := make([]iovec, n) - names := make([][]byte, n) - - for i := range msgs { - names[i] = make([]byte, unix.SizeofSockaddrInet4) - - // Point to the iovec in the slice - msgs[i].Hdr.Iov = &iovecs[i] - msgs[i].Hdr.Iovlen = 1 - - msgs[i].Hdr.Name = &names[i][0] - msgs[i].Hdr.Namelen = unix.SizeofSockaddrInet4 - } - - return msgs, iovecs, names -} - -func (u *StdConn) PrepareWriteMessages6(n int) ([]rawMessage, []iovec, [][]byte) { - msgs := make([]rawMessage, n) - iovecs := make([]iovec, n) - names := make([][]byte, n) - - for i := range msgs { - names[i] = make([]byte, unix.SizeofSockaddrInet6) - - // Point to the iovec in the slice - msgs[i].Hdr.Iov = &iovecs[i] - msgs[i].Hdr.Iovlen = 1 - - msgs[i].Hdr.Name = &names[i][0] - msgs[i].Hdr.Namelen = unix.SizeofSockaddrInet6 - } - - return msgs, iovecs, names -} From a62ffca97532507733bdcf6287c0b982a9334782 Mon Sep 17 00:00:00 2001 From: Jay Wren Date: Thu, 13 Nov 2025 15:09:39 -0500 Subject: [PATCH 08/13] fix 32bit --- udp/udp_linux.go | 4 ++-- udp/udp_linux_32.go | 4 ++-- udp/udp_linux_64.go | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/udp/udp_linux.go b/udp/udp_linux.go index c83f7469..a8f300dd 100644 --- a/udp/udp_linux.go +++ b/udp/udp_linux.go @@ -308,7 +308,7 @@ func (u *StdConn) writeMulti4(packets [][]byte, addrs []netip.AddrPort) (int, er // Setup the packet buffer iovecs[i].Base = &packets[pktIdx][0] - iovecs[i].Len = uint64(len(packets[pktIdx])) + iovecs[i].Len = uint(len(packets[pktIdx])) // Setup the destination address rsa := (*unix.RawSockaddrInet4)(unsafe.Pointer(&names[i][0])) @@ -365,7 +365,7 @@ func (u *StdConn) writeMulti6(packets [][]byte, addrs []netip.AddrPort) (int, er // Setup the packet buffer iovecs[i].Base = &packets[pktIdx][0] - iovecs[i].Len = uint64(len(packets[pktIdx])) + iovecs[i].Len = uint(len(packets[pktIdx])) // Setup the destination address rsa := (*unix.RawSockaddrInet6)(unsafe.Pointer(&names[i][0])) diff --git a/udp/udp_linux_32.go b/udp/udp_linux_32.go index de8f1cdf..707a2b1f 100644 --- a/udp/udp_linux_32.go +++ b/udp/udp_linux_32.go @@ -12,7 +12,7 @@ import ( type iovec struct { Base *byte - Len uint32 + Len uint } type msghdr struct { @@ -40,7 +40,7 @@ func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) { names[i] = make([]byte, unix.SizeofSockaddrInet6) vs := []iovec{ - {Base: &buffers[i][0], Len: uint32(len(buffers[i]))}, + {Base: &buffers[i][0], Len: uint(len(buffers[i]))}, } msgs[i].Hdr.Iov = &vs[0] diff --git a/udp/udp_linux_64.go b/udp/udp_linux_64.go index 48c5a978..89c6695d 100644 --- a/udp/udp_linux_64.go +++ b/udp/udp_linux_64.go @@ -12,7 +12,7 @@ import ( type iovec struct { Base *byte - Len uint64 + Len uint } type msghdr struct { @@ -43,7 +43,7 @@ func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) { names[i] = make([]byte, unix.SizeofSockaddrInet6) vs := []iovec{ - {Base: &buffers[i][0], Len: uint64(len(buffers[i]))}, + {Base: &buffers[i][0], Len: uint(len(buffers[i]))}, } msgs[i].Hdr.Iov = &vs[0] From 7c3708561d0af2b022b9c30955f609788b484851 Mon Sep 17 00:00:00 2001 From: Jay Wren Date: Fri, 14 Nov 2025 14:43:51 -0500 Subject: [PATCH 09/13] instruments --- inside.go | 6 +++++- interface.go | 4 ++++ overlay/tun_linux.go | 1 + udp/udp_linux.go | 4 ++++ 4 files changed, 14 insertions(+), 1 deletion(-) diff --git a/inside.go b/inside.go index 12fad380..07f0418a 100644 --- a/inside.go +++ b/inside.go @@ -3,6 +3,7 @@ package nebula import ( "net/netip" + "github.com/rcrowley/go-metrics" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/header" @@ -145,9 +146,12 @@ func (f *Interface) consumeInsidePackets(packets [][]byte, sizes []int, count in // Send all accumulated packets in one batch if len(*batchPackets) > 0 { + batchSize := len(*batchPackets) + metrics.GetOrRegisterHistogram("batch.udp_write_size", nil, metrics.NewUniformSample(1024)).Update(int64(batchSize)) + n, err := f.writers[q].WriteMulti(*batchPackets, *batchAddrs) if err != nil { - f.l.WithError(err).WithField("sent", n).WithField("total", len(*batchPackets)).Error("Failed to send batch") + f.l.WithError(err).WithField("sent", n).WithField("total", batchSize).Error("Failed to send batch") } } } diff --git a/interface.go b/interface.go index 0486378a..d80a7214 100644 --- a/interface.go +++ b/interface.go @@ -347,6 +347,8 @@ func (f *Interface) listenInBatch(reader io.ReadWriteCloser, batchReader BatchRe conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout) + tunBatchHist := metrics.GetOrRegisterHistogram("batch.tun_read_size", nil, metrics.NewUniformSample(1024)) + for { n, err := batchReader.BatchRead(bufs, sizes) if err != nil { @@ -359,6 +361,8 @@ func (f *Interface) listenInBatch(reader io.ReadWriteCloser, batchReader BatchRe os.Exit(2) } + tunBatchHist.Update(int64(n)) + // Process all packets in the batch at once f.consumeInsidePackets(bufs, sizes, n, outs, i, conntrackCache.Get(f.l), &batchPackets, &batchAddrs) } diff --git a/overlay/tun_linux.go b/overlay/tun_linux.go index e2e926a8..c5d7739e 100644 --- a/overlay/tun_linux.go +++ b/overlay/tun_linux.go @@ -128,6 +128,7 @@ func (w *wgDeviceWrapper) Close() error { // BatchRead implements batching for multiqueue readers func (w *wgDeviceWrapper) BatchRead(bufs [][]byte, sizes []int) (int, error) { + // The zero here is offset. return w.dev.Read(bufs, sizes, 0) } diff --git a/udp/udp_linux.go b/udp/udp_linux.go index a8f300dd..c591a09c 100644 --- a/udp/udp_linux.go +++ b/udp/udp_linux.go @@ -151,6 +151,8 @@ func (u *StdConn) ListenOut(r EncReader) { read = u.ReadSingle } + udpBatchHist := metrics.GetOrRegisterHistogram("batch.udp_read_size", nil, metrics.NewUniformSample(1024)) + for { n, err := read(msgs) if err != nil { @@ -158,6 +160,8 @@ func (u *StdConn) ListenOut(r EncReader) { return } + udpBatchHist.Update(int64(n)) + for i := 0; i < n; i++ { // Its ok to skip the ok check here, the slicing is the only error that can occur and it will panic if u.isV4 { From 518a78c9d231a9873c756b7edf1a06990bd0f6fb Mon Sep 17 00:00:00 2001 From: Jay Wren Date: Tue, 18 Nov 2025 14:19:05 -0500 Subject: [PATCH 10/13] preallocate nonce buffer --- inside.go | 3 +-- interface.go | 5 ++++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/inside.go b/inside.go index 07f0418a..54627093 100644 --- a/inside.go +++ b/inside.go @@ -21,10 +21,9 @@ import ( // localCache: firewall conntrack cache // batchPackets: pre-allocated slice for accumulating encrypted packets // batchAddrs: pre-allocated slice for accumulating destination addresses -func (f *Interface) consumeInsidePackets(packets [][]byte, sizes []int, count int, outs [][]byte, q int, localCache firewall.ConntrackCache, batchPackets *[][]byte, batchAddrs *[]netip.AddrPort) { +func (f *Interface) consumeInsidePackets(packets [][]byte, sizes []int, count int, outs [][]byte, nb []byte, q int, localCache firewall.ConntrackCache, batchPackets *[][]byte, batchAddrs *[]netip.AddrPort) { // Reusable per-packet state fwPacket := &firewall.Packet{} - nb := make([]byte, 12, 12) // Reset batch accumulation slices (reuse capacity) *batchPackets = (*batchPackets)[:0] diff --git a/interface.go b/interface.go index d80a7214..725a6dd2 100644 --- a/interface.go +++ b/interface.go @@ -345,6 +345,9 @@ func (f *Interface) listenInBatch(reader io.ReadWriteCloser, batchReader BatchRe batchPackets := make([][]byte, 0, batchSize) batchAddrs := make([]netip.AddrPort, 0, batchSize) + // Pre-allocate nonce buffer (reused for all encryptions) + nb := make([]byte, 12, 12) + conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout) tunBatchHist := metrics.GetOrRegisterHistogram("batch.tun_read_size", nil, metrics.NewUniformSample(1024)) @@ -364,7 +367,7 @@ func (f *Interface) listenInBatch(reader io.ReadWriteCloser, batchReader BatchRe tunBatchHist.Update(int64(n)) // Process all packets in the batch at once - f.consumeInsidePackets(bufs, sizes, n, outs, i, conntrackCache.Get(f.l), &batchPackets, &batchAddrs) + f.consumeInsidePackets(bufs, sizes, n, outs, nb, i, conntrackCache.Get(f.l), &batchPackets, &batchAddrs) } } From 8b32382cd97919a65162c2f66cc8bcf271147c20 Mon Sep 17 00:00:00 2001 From: Jay Wren Date: Wed, 19 Nov 2025 12:03:38 -0500 Subject: [PATCH 11/13] zero copy even with virtioheder --- interface.go | 16 +++++++----- outside.go | 11 ++++---- overlay/tun.go | 1 + overlay/tun_linux.go | 61 ++++++++------------------------------------ stats.go | 1 + 5 files changed, 27 insertions(+), 63 deletions(-) diff --git a/interface.go b/interface.go index 725a6dd2..5b96d1c5 100644 --- a/interface.go +++ b/interface.go @@ -22,6 +22,7 @@ import ( ) const mtu = 9001 +const virtioNetHdrLen = overlay.VirtioNetHdrLen type InterfaceConfig struct { HostMap *HostMap @@ -266,13 +267,16 @@ func (f *Interface) listenOut(i int) { ctCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout) lhh := f.lightHouse.NewRequestHandler() - plaintext := make([]byte, udp.MTU) + + // Allocate plaintext buffer with virtio header headroom to avoid copies on TUN write + plaintext := make([]byte, virtioNetHdrLen+udp.MTU) + h := &header.H{} fwPacket := &firewall.Packet{} - nb := make([]byte, 12, 12) + nb := make([]byte, 12) li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) { - f.readOutsidePackets(fromUdpAddr, nil, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get(f.l)) + f.readOutsidePackets(fromUdpAddr, nil, plaintext[:virtioNetHdrLen], payload, h, fwPacket, lhh, nb, i, ctCache.Get(f.l)) }) } @@ -298,11 +302,10 @@ func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) { func (f *Interface) listenInSingle(reader io.ReadWriteCloser, i int) { packet := make([]byte, mtu) // Allocate out buffer with virtio header headroom (10 bytes) to avoid copies on write - const virtioNetHdrLen = 10 outBuf := make([]byte, virtioNetHdrLen+mtu) out := outBuf[virtioNetHdrLen:] // Use slice starting after headroom fwPacket := &firewall.Packet{} - nb := make([]byte, 12, 12) + nb := make([]byte, 12) conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout) @@ -324,7 +327,6 @@ func (f *Interface) listenInSingle(reader io.ReadWriteCloser, i int) { func (f *Interface) listenInBatch(reader io.ReadWriteCloser, batchReader BatchReader, i int) { batchSize := batchReader.BatchSize() - const virtioNetHdrLen = 10 // Allocate buffers for batch reading bufs := make([][]byte, batchSize) @@ -346,7 +348,7 @@ func (f *Interface) listenInBatch(reader io.ReadWriteCloser, batchReader BatchRe batchAddrs := make([]netip.AddrPort, 0, batchSize) // Pre-allocate nonce buffer (reused for all encryptions) - nb := make([]byte, 12, 12) + nb := make([]byte, 12) conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout) diff --git a/outside.go b/outside.go index 5ff87bd8..eae15f37 100644 --- a/outside.go +++ b/outside.go @@ -95,8 +95,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out [] switch relay.Type { case TerminalType: // If I am the target of this relay, process the unwrapped packet - // From this recursive point, all these variables are 'burned'. We shouldn't rely on them again. - f.readOutsidePackets(netip.AddrPort{}, &ViaSender{relayHI: hostinfo, remoteIdx: relay.RemoteIndex, relay: relay}, out[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache) + f.readOutsidePackets(netip.AddrPort{}, &ViaSender{relayHI: hostinfo, remoteIdx: relay.RemoteIndex, relay: relay}, out[:virtioNetHdrLen], signedPayload, h, fwPacket, lhf, nb, q, localCache) return case ForwardingType: // Find the target HostInfo relay object @@ -474,9 +473,11 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out return false } - err = newPacket(out, true, fwPacket) + packetData := out[virtioNetHdrLen:] + + err = newPacket(packetData, true, fwPacket) if err != nil { - hostinfo.logger(f.l).WithError(err).WithField("packet", out). + hostinfo.logger(f.l).WithError(err).WithField("packet", packetData). Warnf("Error while validating inbound packet") return false } @@ -491,7 +492,7 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out if dropReason != nil { // NOTE: We give `packet` as the `out` here since we already decrypted from it and we don't need it anymore // This gives us a buffer to build the reject packet in - f.rejectOutside(out, hostinfo.ConnectionState, hostinfo, nb, packet, q) + f.rejectOutside(packetData, hostinfo.ConnectionState, hostinfo, nb, packet, q) if f.l.Level >= logrus.DebugLevel { hostinfo.logger(f.l).WithField("fwPacket", fwPacket). WithField("reason", dropReason). diff --git a/overlay/tun.go b/overlay/tun.go index 3a61d186..7947e295 100644 --- a/overlay/tun.go +++ b/overlay/tun.go @@ -11,6 +11,7 @@ import ( ) const DefaultMTU = 1300 +const VirtioNetHdrLen = 10 // Size of virtio_net_hdr structure // TODO: We may be able to remove routines type DeviceFactory func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) diff --git a/overlay/tun_linux.go b/overlay/tun_linux.go index c5d7739e..f40031c2 100644 --- a/overlay/tun_linux.go +++ b/overlay/tun_linux.go @@ -66,10 +66,6 @@ type ifreqQLEN struct { pad [8]byte } -const ( - virtioNetHdrLen = 10 // Size of virtio_net_hdr structure -) - // wgDeviceWrapper wraps a wireguard Device to implement io.ReadWriteCloser // This allows multiqueue readers to use the same wireguard Device batching as the main device type wgDeviceWrapper struct { @@ -92,27 +88,11 @@ func (w *wgDeviceWrapper) Read(b []byte) (int, error) { } func (w *wgDeviceWrapper) Write(b []byte) (int, error) { - // Check if buffer has the expected headroom pattern to avoid copy - var buf []byte - - if cap(b) >= len(b)+virtioNetHdrLen { - buf = b[:cap(b)] - if len(buf) == len(b)+virtioNetHdrLen { - // Perfect! Buffer has headroom, no copy needed - buf = buf[:len(b)+virtioNetHdrLen] - } else { - // Unexpected capacity, safer to copy - buf = make([]byte, virtioNetHdrLen+len(b)) - copy(buf[virtioNetHdrLen:], b) - } - } else { - // No headroom, need to allocate and copy - buf = make([]byte, virtioNetHdrLen+len(b)) - copy(buf[virtioNetHdrLen:], b) - } - - bufs := [][]byte{buf} - n, err := w.dev.Write(bufs, virtioNetHdrLen) + // Buffer b should have virtio header space (10 bytes) at the beginning + // The decrypted packet data starts at offset 10 + // Pass the full buffer to WireGuard with offset=virtioNetHdrLen + bufs := [][]byte{b} + n, err := w.dev.Write(bufs, VirtioNetHdrLen) if err != nil { return 0, err } @@ -419,32 +399,11 @@ func (t *tun) BatchSize() int { func (t *tun) Write(b []byte) (int, error) { if t.wgDevice != nil { - // Use wireguard device which handles virtio headers internally - // Check if buffer has the expected headroom pattern: - // cap(b) should be len(b) + virtioNetHdrLen, indicating pre-allocated headroom - var buf []byte - - if cap(b) >= len(b)+virtioNetHdrLen { - // Buffer likely has headroom - use unsafe to access it - // Create a slice that includes the headroom by re-slicing from capacity - buf = b[:cap(b)] - // Check if we have exactly the right amount of extra capacity - if len(buf) == len(b)+virtioNetHdrLen { - // Perfect! This buffer was allocated with headroom, no copy needed - buf = buf[:len(b)+virtioNetHdrLen] - } else { - // Unexpected capacity, safer to copy - buf = make([]byte, virtioNetHdrLen+len(b)) - copy(buf[virtioNetHdrLen:], b) - } - } else { - // No headroom, need to allocate and copy - buf = make([]byte, virtioNetHdrLen+len(b)) - copy(buf[virtioNetHdrLen:], b) - } - - bufs := [][]byte{buf} - n, err := t.wgDevice.Write(bufs, virtioNetHdrLen) + // Buffer b should have virtio header space (10 bytes) at the beginning + // The decrypted packet data starts at offset 10 + // Pass the full buffer to WireGuard with offset=virtioNetHdrLen + bufs := [][]byte{b} + n, err := t.wgDevice.Write(bufs, VirtioNetHdrLen) if err != nil { return 0, err } diff --git a/stats.go b/stats.go index c88c45cc..b86919cc 100644 --- a/stats.go +++ b/stats.go @@ -6,6 +6,7 @@ import ( "log" "net" "net/http" + _ "net/http/pprof" "runtime" "strconv" "time" From f29e21b411752fb308e669388dbbe3c25389d435 Mon Sep 17 00:00:00 2001 From: Jay Wren Date: Wed, 19 Nov 2025 13:25:25 -0500 Subject: [PATCH 12/13] don't register metrics in loops --- inside.go | 3 +-- interface.go | 16 +++++++++++++--- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/inside.go b/inside.go index 54627093..198fc55c 100644 --- a/inside.go +++ b/inside.go @@ -3,7 +3,6 @@ package nebula import ( "net/netip" - "github.com/rcrowley/go-metrics" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/header" @@ -146,7 +145,7 @@ func (f *Interface) consumeInsidePackets(packets [][]byte, sizes []int, count in // Send all accumulated packets in one batch if len(*batchPackets) > 0 { batchSize := len(*batchPackets) - metrics.GetOrRegisterHistogram("batch.udp_write_size", nil, metrics.NewUniformSample(1024)).Update(int64(batchSize)) + f.batchMetrics.udpWriteSize.Update(int64(batchSize)) n, err := f.writers[q].WriteMulti(*batchPackets, *batchAddrs) if err != nil { diff --git a/interface.go b/interface.go index 5b96d1c5..2ef1c31b 100644 --- a/interface.go +++ b/interface.go @@ -51,6 +51,12 @@ type InterfaceConfig struct { l *logrus.Logger } +type batchMetrics struct { + udpReadSize metrics.Histogram + tunReadSize metrics.Histogram + udpWriteSize metrics.Histogram +} + type Interface struct { hostMap *HostMap outside udp.Conn @@ -92,6 +98,7 @@ type Interface struct { metricHandshakes metrics.Histogram messageMetrics *MessageMetrics cachedPacketMetrics *cachedPacketMetrics + batchMetrics *batchMetrics l *logrus.Logger } @@ -194,6 +201,11 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) { sent: metrics.GetOrRegisterCounter("hostinfo.cached_packets.sent", nil), dropped: metrics.GetOrRegisterCounter("hostinfo.cached_packets.dropped", nil), }, + batchMetrics: &batchMetrics{ + udpReadSize: metrics.GetOrRegisterHistogram("batch.udp_read_size", nil, metrics.NewUniformSample(1024)), + tunReadSize: metrics.GetOrRegisterHistogram("batch.tun_read_size", nil, metrics.NewUniformSample(1024)), + udpWriteSize: metrics.GetOrRegisterHistogram("batch.udp_write_size", nil, metrics.NewUniformSample(1024)), + }, l: c.l, } @@ -352,8 +364,6 @@ func (f *Interface) listenInBatch(reader io.ReadWriteCloser, batchReader BatchRe conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout) - tunBatchHist := metrics.GetOrRegisterHistogram("batch.tun_read_size", nil, metrics.NewUniformSample(1024)) - for { n, err := batchReader.BatchRead(bufs, sizes) if err != nil { @@ -366,7 +376,7 @@ func (f *Interface) listenInBatch(reader io.ReadWriteCloser, batchReader BatchRe os.Exit(2) } - tunBatchHist.Update(int64(n)) + f.batchMetrics.tunReadSize.Update(int64(n)) // Process all packets in the batch at once f.consumeInsidePackets(bufs, sizes, n, outs, nb, i, conntrackCache.Get(f.l), &batchPackets, &batchAddrs) From 4e333c76baa72c58a478f5969283bd96efcb8bac Mon Sep 17 00:00:00 2001 From: Jay Wren Date: Wed, 19 Nov 2025 14:03:36 -0500 Subject: [PATCH 13/13] write batching --- interface.go | 13 +++++++++---- overlay/tun_linux.go | 21 +++++++++++++++++++++ udp/conn.go | 14 ++++++++++++++ udp/udp_darwin.go | 28 ++++++++++++++++++++++++++++ udp/udp_generic.go | 39 +++++++++++++++++++++++++++++++++++++++ udp/udp_linux.go | 44 ++++++++++++++++++++++++++++++++++++++++++++ udp/udp_tester.go | 29 +++++++++++++++++++++++++++++ 7 files changed, 184 insertions(+), 4 deletions(-) diff --git a/interface.go b/interface.go index 2ef1c31b..74e2c84b 100644 --- a/interface.go +++ b/interface.go @@ -280,15 +280,20 @@ func (f *Interface) listenOut(i int) { ctCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout) lhh := f.lightHouse.NewRequestHandler() - // Allocate plaintext buffer with virtio header headroom to avoid copies on TUN write - plaintext := make([]byte, virtioNetHdrLen+udp.MTU) + // Pre-allocate output buffers for batch processing + batchSize := li.BatchSize() + outs := make([][]byte, batchSize) + for idx := range outs { + // Allocate full buffer with virtio header space + outs[idx] = make([]byte, virtioNetHdrLen, virtioNetHdrLen+udp.MTU) + } h := &header.H{} fwPacket := &firewall.Packet{} nb := make([]byte, 12) - li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) { - f.readOutsidePackets(fromUdpAddr, nil, plaintext[:virtioNetHdrLen], payload, h, fwPacket, lhh, nb, i, ctCache.Get(f.l)) + li.ListenOutBatch(func(addrs []netip.AddrPort, payloads [][]byte, count int) { + f.readOutsidePacketsBatch(addrs, payloads, count, outs[:count], nb, i, h, fwPacket, lhh, ctCache.Get(f.l)) }) } diff --git a/overlay/tun_linux.go b/overlay/tun_linux.go index f40031c2..3c98d727 100644 --- a/overlay/tun_linux.go +++ b/overlay/tun_linux.go @@ -102,6 +102,11 @@ func (w *wgDeviceWrapper) Write(b []byte) (int, error) { return len(b), nil } +func (w *wgDeviceWrapper) WriteBatch(bufs [][]byte, offset int) (int, error) { + // Pass all buffers to WireGuard's batch write + return w.dev.Write(bufs, offset) +} + func (w *wgDeviceWrapper) Close() error { return w.dev.Close() } @@ -436,6 +441,22 @@ func (t *tun) Write(b []byte) (int, error) { } } +// WriteBatch writes multiple packets to the TUN device in a single syscall +func (t *tun) WriteBatch(bufs [][]byte, offset int) (int, error) { + if t.wgDevice != nil { + return t.wgDevice.Write(bufs, offset) + } + + // Fallback: write individually (shouldn't happen in normal operation) + for i, buf := range bufs { + _, err := t.Write(buf) + if err != nil { + return i, err + } + } + return len(bufs), nil +} + func (t *tun) deviceBytes() (o [16]byte) { for i, c := range t.Device { o[i] = byte(c) diff --git a/udp/conn.go b/udp/conn.go index 8c821d33..f3267c9e 100644 --- a/udp/conn.go +++ b/udp/conn.go @@ -13,13 +13,21 @@ type EncReader func( payload []byte, ) +type EncBatchReader func( + addrs []netip.AddrPort, + payloads [][]byte, + count int, +) + type Conn interface { Rebind() error LocalAddr() (netip.AddrPort, error) ListenOut(r EncReader) + ListenOutBatch(r EncBatchReader) WriteTo(b []byte, addr netip.AddrPort) error WriteMulti(packets [][]byte, addrs []netip.AddrPort) (int, error) ReloadConfig(c *config.C) + BatchSize() int Close() error } @@ -34,6 +42,9 @@ func (NoopConn) LocalAddr() (netip.AddrPort, error) { func (NoopConn) ListenOut(_ EncReader) { return } +func (NoopConn) ListenOutBatch(_ EncBatchReader) { + return +} func (NoopConn) WriteTo(_ []byte, _ netip.AddrPort) error { return nil } @@ -43,6 +54,9 @@ func (NoopConn) WriteMulti(_ [][]byte, _ []netip.AddrPort) (int, error) { func (NoopConn) ReloadConfig(_ *config.C) { return } +func (NoopConn) BatchSize() int { + return 1 +} func (NoopConn) Close() error { return nil } diff --git a/udp/udp_darwin.go b/udp/udp_darwin.go index b409d774..787f1c4d 100644 --- a/udp/udp_darwin.go +++ b/udp/udp_darwin.go @@ -195,6 +195,34 @@ func (u *StdConn) ListenOut(r EncReader) { } } +// ListenOutBatch - fallback to single-packet reads for Darwin +func (u *StdConn) ListenOutBatch(r EncBatchReader) { + buffer := make([]byte, MTU) + addrs := make([]netip.AddrPort, 1) + payloads := make([][]byte, 1) + + for { + // Just read one packet at a time and call batch callback with count=1 + n, rua, err := u.ReadFromUDPAddrPort(buffer) + if err != nil { + if errors.Is(err, net.ErrClosed) { + u.l.WithError(err).Debug("udp socket is closed, exiting read loop") + return + } + + u.l.WithError(err).Error("unexpected udp socket receive error") + } + + addrs[0] = netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()) + payloads[0] = buffer[:n] + r(addrs, payloads, 1) + } +} + +func (u *StdConn) BatchSize() int { + return 1 +} + func (u *StdConn) Rebind() error { var err error if u.isV4 { diff --git a/udp/udp_generic.go b/udp/udp_generic.go index cb21e574..7c8cdf4b 100644 --- a/udp/udp_generic.go +++ b/udp/udp_generic.go @@ -85,3 +85,42 @@ func (u *GenericConn) ListenOut(r EncReader) { r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n]) } } + +// ListenOutBatch - fallback to single-packet reads for generic platforms +func (u *GenericConn) ListenOutBatch(r EncBatchReader) { + buffer := make([]byte, MTU) + addrs := make([]netip.AddrPort, 1) + payloads := make([][]byte, 1) + + for { + // Just read one packet at a time and call batch callback with count=1 + n, rua, err := u.ReadFromUDPAddrPort(buffer) + if err != nil { + u.l.WithError(err).Debug("udp socket is closed, exiting read loop") + return + } + + addrs[0] = netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()) + payloads[0] = buffer[:n] + r(addrs, payloads, 1) + } +} + +// WriteMulti sends multiple packets - fallback implementation +func (u *GenericConn) WriteMulti(packets [][]byte, addrs []netip.AddrPort) (int, error) { + for i := range packets { + err := u.WriteTo(packets[i], addrs[i]) + if err != nil { + return i, err + } + } + return len(packets), nil +} + +func (u *GenericConn) BatchSize() int { + return 1 +} + +func (u *GenericConn) Rebind() error { + return nil +} diff --git a/udp/udp_linux.go b/udp/udp_linux.go index c591a09c..efb71c3f 100644 --- a/udp/udp_linux.go +++ b/udp/udp_linux.go @@ -174,6 +174,46 @@ func (u *StdConn) ListenOut(r EncReader) { } } +func (u *StdConn) ListenOutBatch(r EncBatchReader) { + var ip netip.Addr + + msgs, buffers, names := u.PrepareRawMessages(u.batch) + read := u.ReadMulti + if u.batch == 1 { + read = u.ReadSingle + } + + udpBatchHist := metrics.GetOrRegisterHistogram("batch.udp_read_size", nil, metrics.NewUniformSample(1024)) + + // Pre-allocate slices for batch callback + addrs := make([]netip.AddrPort, u.batch) + payloads := make([][]byte, u.batch) + + for { + n, err := read(msgs) + if err != nil { + u.l.WithError(err).Debug("udp socket is closed, exiting read loop") + return + } + + udpBatchHist.Update(int64(n)) + + // Prepare batch data + for i := 0; i < n; i++ { + if u.isV4 { + ip, _ = netip.AddrFromSlice(names[i][4:8]) + } else { + ip, _ = netip.AddrFromSlice(names[i][8:24]) + } + addrs[i] = netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(names[i][2:4])) + payloads[i] = buffers[i][:msgs[i].Len] + } + + // Call batch callback with all packets + r(addrs, payloads, n) + } +} + func (u *StdConn) ReadSingle(msgs []rawMessage) (int, error) { for { n, _, err := unix.Syscall6( @@ -463,6 +503,10 @@ func (u *StdConn) getMemInfo(meminfo *[unix.SK_MEMINFO_VARS]uint32) error { return nil } +func (u *StdConn) BatchSize() int { + return u.batch +} + func (u *StdConn) Close() error { return syscall.Close(u.sysFd) } diff --git a/udp/udp_tester.go b/udp/udp_tester.go index 8d5e6c14..d88d8aa6 100644 --- a/udp/udp_tester.go +++ b/udp/udp_tester.go @@ -116,6 +116,31 @@ func (u *TesterConn) ListenOut(r EncReader) { } } +func (u *TesterConn) ListenOutBatch(r EncBatchReader) { + addrs := make([]netip.AddrPort, 1) + payloads := make([][]byte, 1) + + for { + p, ok := <-u.RxPackets + if !ok { + return + } + addrs[0] = p.From + payloads[0] = p.Data + r(addrs, payloads, 1) + } +} + +func (u *TesterConn) WriteMulti(packets [][]byte, addrs []netip.AddrPort) (int, error) { + for i := range packets { + err := u.WriteTo(packets[i], addrs[i]) + if err != nil { + return i, err + } + } + return len(packets), nil +} + func (u *TesterConn) ReloadConfig(*config.C) {} func NewUDPStatsEmitter(_ []Conn) func() { @@ -127,6 +152,10 @@ func (u *TesterConn) LocalAddr() (netip.AddrPort, error) { return u.Addr, nil } +func (u *TesterConn) BatchSize() int { + return 1 +} + func (u *TesterConn) Rebind() error { return nil }