diff --git a/inside.go b/inside.go index 0d53f952..32c9a99f 100644 --- a/inside.go +++ b/inside.go @@ -11,6 +11,149 @@ import ( "github.com/slackhq/nebula/routing" ) +// consumeInsidePackets processes multiple packets in a batch for improved performance +// packets: slice of packet buffers to process +// sizes: slice of packet sizes +// count: number of packets to process +// outs: slice of output buffers (one per packet) with virtio headroom +// q: queue index +// localCache: firewall conntrack cache +// batchPackets: pre-allocated slice for accumulating encrypted packets +// batchAddrs: pre-allocated slice for accumulating destination addresses +func (f *Interface) consumeInsidePackets(packets [][]byte, sizes []int, count int, outs [][]byte, nb []byte, q int, localCache firewall.ConntrackCache, batchPackets *[][]byte, batchAddrs *[]netip.AddrPort) { + // Reusable per-packet state + fwPacket := &firewall.Packet{} + + // Reset batch accumulation slices (reuse capacity) + *batchPackets = (*batchPackets)[:0] + *batchAddrs = (*batchAddrs)[:0] + + // Process each packet in the batch + for i := 0; i < count; i++ { + packet := packets[i][:sizes[i]] + out := outs[i] + + // Inline the consumeInsidePacket logic for better performance + err := newPacket(packet, false, fwPacket) + if err != nil { + if f.l.Level >= logrus.DebugLevel { + f.l.WithField("packet", packet).Debugf("Error while validating outbound packet: %s", err) + } + continue + } + + // Ignore local broadcast packets + if f.dropLocalBroadcast { + if f.myBroadcastAddrsTable.Contains(fwPacket.RemoteAddr) { + continue + } + } + + if f.myVpnAddrsTable.Contains(fwPacket.RemoteAddr) { + // Immediately forward packets from self to self. + if immediatelyForwardToSelf { + _, err := f.readers[q].Write(packet) + if err != nil { + f.l.WithError(err).Error("Failed to forward to tun") + } + } + continue + } + + // Ignore multicast packets + if f.dropMulticast && fwPacket.RemoteAddr.IsMulticast() { + continue + } + + hostinfo, ready := f.getOrHandshakeConsiderRouting(fwPacket, func(hh *HandshakeHostInfo) { + hh.cachePacket(f.l, header.Message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics) + }) + + if hostinfo == nil { + f.rejectInside(packet, out, q) + if f.l.Level >= logrus.DebugLevel { + f.l.WithField("vpnAddr", fwPacket.RemoteAddr). + WithField("fwPacket", fwPacket). + Debugln("dropping outbound packet, vpnAddr not in our vpn networks or in unsafe networks") + } + continue + } + + if !ready { + continue + } + + dropReason := f.firewall.Drop(*fwPacket, false, hostinfo, f.pki.GetCAPool(), localCache) + if dropReason != nil { + f.rejectInside(packet, out, q) + if f.l.Level >= logrus.DebugLevel { + hostinfo.logger(f.l). + WithField("fwPacket", fwPacket). + WithField("reason", dropReason). + Debugln("dropping outbound packet") + } + continue + } + + // Encrypt and prepare packet for batch sending + ci := hostinfo.ConnectionState + if ci.eKey == nil { + continue + } + + // Check if this needs relay - if so, send immediately and skip batching + useRelay := !hostinfo.remote.IsValid() + if useRelay { + // Handle relay sends individually (less common path) + f.sendNoMetrics(header.Message, 0, ci, hostinfo, netip.AddrPort{}, packet, nb, out, q) + continue + } + + // Encrypt the packet for batch sending + if noiseutil.EncryptLockNeeded { + ci.writeLock.Lock() + } + c := ci.messageCounter.Add(1) + out = header.Encode(out, header.Version, header.Message, 0, hostinfo.remoteIndexId, c) + f.connectionManager.Out(hostinfo) + + // Query lighthouse if needed + if hostinfo.lastRebindCount != f.rebindCount { + f.lightHouse.QueryServer(hostinfo.vpnAddrs[0]) + hostinfo.lastRebindCount = f.rebindCount + if f.l.Level >= logrus.DebugLevel { + f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).Debug("Lighthouse update triggered for punch due to rebind counter") + } + } + + out, err = ci.eKey.EncryptDanger(out, out, packet, c, nb) + if noiseutil.EncryptLockNeeded { + ci.writeLock.Unlock() + } + if err != nil { + hostinfo.logger(f.l).WithError(err). + WithField("counter", c). + Error("Failed to encrypt outgoing packet") + continue + } + + // Add to batch + *batchPackets = append(*batchPackets, out) + *batchAddrs = append(*batchAddrs, hostinfo.remote) + } + + // Send all accumulated packets in one batch + if len(*batchPackets) > 0 { + batchSize := len(*batchPackets) + f.batchMetrics.udpWriteSize.Update(int64(batchSize)) + + n, err := f.writers[q].WriteMulti(*batchPackets, *batchAddrs) + if err != nil { + f.l.WithError(err).WithField("sent", n).WithField("total", batchSize).Error("Failed to send batch") + } + } +} + func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb, out []byte, q int, localCache firewall.ConntrackCache) { err := newPacket(packet, false, fwPacket) if err != nil { diff --git a/interface.go b/interface.go index 082906d9..74e2c84b 100644 --- a/interface.go +++ b/interface.go @@ -22,6 +22,7 @@ import ( ) const mtu = 9001 +const virtioNetHdrLen = overlay.VirtioNetHdrLen type InterfaceConfig struct { HostMap *HostMap @@ -50,6 +51,12 @@ type InterfaceConfig struct { l *logrus.Logger } +type batchMetrics struct { + udpReadSize metrics.Histogram + tunReadSize metrics.Histogram + udpWriteSize metrics.Histogram +} + type Interface struct { hostMap *HostMap outside udp.Conn @@ -91,6 +98,7 @@ type Interface struct { metricHandshakes metrics.Histogram messageMetrics *MessageMetrics cachedPacketMetrics *cachedPacketMetrics + batchMetrics *batchMetrics l *logrus.Logger } @@ -193,6 +201,11 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) { sent: metrics.GetOrRegisterCounter("hostinfo.cached_packets.sent", nil), dropped: metrics.GetOrRegisterCounter("hostinfo.cached_packets.dropped", nil), }, + batchMetrics: &batchMetrics{ + udpReadSize: metrics.GetOrRegisterHistogram("batch.udp_read_size", nil, metrics.NewUniformSample(1024)), + tunReadSize: metrics.GetOrRegisterHistogram("batch.tun_read_size", nil, metrics.NewUniformSample(1024)), + udpWriteSize: metrics.GetOrRegisterHistogram("batch.udp_write_size", nil, metrics.NewUniformSample(1024)), + }, l: c.l, } @@ -266,23 +279,50 @@ func (f *Interface) listenOut(i int) { ctCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout) lhh := f.lightHouse.NewRequestHandler() - plaintext := make([]byte, udp.MTU) + + // Pre-allocate output buffers for batch processing + batchSize := li.BatchSize() + outs := make([][]byte, batchSize) + for idx := range outs { + // Allocate full buffer with virtio header space + outs[idx] = make([]byte, virtioNetHdrLen, virtioNetHdrLen+udp.MTU) + } + h := &header.H{} fwPacket := &firewall.Packet{} - nb := make([]byte, 12, 12) + nb := make([]byte, 12) - li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) { - f.readOutsidePackets(fromUdpAddr, nil, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get(f.l)) + li.ListenOutBatch(func(addrs []netip.AddrPort, payloads [][]byte, count int) { + f.readOutsidePacketsBatch(addrs, payloads, count, outs[:count], nb, i, h, fwPacket, lhh, ctCache.Get(f.l)) }) } +// BatchReader is an interface for devices that support reading multiple packets at once +type BatchReader interface { + BatchRead(bufs [][]byte, sizes []int) (int, error) + BatchSize() int +} + func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) { runtime.LockOSThread() + // Check if reader supports batching + batchReader, supportsBatching := reader.(BatchReader) + + if supportsBatching { + f.listenInBatch(reader, batchReader, i) + } else { + f.listenInSingle(reader, i) + } +} + +func (f *Interface) listenInSingle(reader io.ReadWriteCloser, i int) { packet := make([]byte, mtu) - out := make([]byte, mtu) + // Allocate out buffer with virtio header headroom (10 bytes) to avoid copies on write + outBuf := make([]byte, virtioNetHdrLen+mtu) + out := outBuf[virtioNetHdrLen:] // Use slice starting after headroom fwPacket := &firewall.Packet{} - nb := make([]byte, 12, 12) + nb := make([]byte, 12) conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout) @@ -302,6 +342,52 @@ func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) { } } +func (f *Interface) listenInBatch(reader io.ReadWriteCloser, batchReader BatchReader, i int) { + batchSize := batchReader.BatchSize() + + // Allocate buffers for batch reading + bufs := make([][]byte, batchSize) + for idx := range bufs { + bufs[idx] = make([]byte, mtu) + } + sizes := make([]int, batchSize) + + // Allocate output buffers for batch processing (one per packet) + // Each has virtio header headroom to avoid copies on write + outs := make([][]byte, batchSize) + for idx := range outs { + outBuf := make([]byte, virtioNetHdrLen+mtu) + outs[idx] = outBuf[virtioNetHdrLen:] // Slice starting after headroom + } + + // Pre-allocate batch accumulation buffers for sending + batchPackets := make([][]byte, 0, batchSize) + batchAddrs := make([]netip.AddrPort, 0, batchSize) + + // Pre-allocate nonce buffer (reused for all encryptions) + nb := make([]byte, 12) + + conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout) + + for { + n, err := batchReader.BatchRead(bufs, sizes) + if err != nil { + if errors.Is(err, os.ErrClosed) && f.closed.Load() { + return + } + + f.l.WithError(err).Error("Error while batch reading outbound packets") + // This only seems to happen when something fatal happens to the fd, so exit. + os.Exit(2) + } + + f.batchMetrics.tunReadSize.Update(int64(n)) + + // Process all packets in the batch at once + f.consumeInsidePackets(bufs, sizes, n, outs, nb, i, conntrackCache.Get(f.l), &batchPackets, &batchAddrs) + } +} + func (f *Interface) RegisterConfigChangeCallbacks(c *config.C) { c.RegisterReloadCallback(f.reloadFirewall) c.RegisterReloadCallback(f.reloadSendRecvError) diff --git a/outside.go b/outside.go index 5ff87bd8..eae15f37 100644 --- a/outside.go +++ b/outside.go @@ -95,8 +95,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out [] switch relay.Type { case TerminalType: // If I am the target of this relay, process the unwrapped packet - // From this recursive point, all these variables are 'burned'. We shouldn't rely on them again. - f.readOutsidePackets(netip.AddrPort{}, &ViaSender{relayHI: hostinfo, remoteIdx: relay.RemoteIndex, relay: relay}, out[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache) + f.readOutsidePackets(netip.AddrPort{}, &ViaSender{relayHI: hostinfo, remoteIdx: relay.RemoteIndex, relay: relay}, out[:virtioNetHdrLen], signedPayload, h, fwPacket, lhf, nb, q, localCache) return case ForwardingType: // Find the target HostInfo relay object @@ -474,9 +473,11 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out return false } - err = newPacket(out, true, fwPacket) + packetData := out[virtioNetHdrLen:] + + err = newPacket(packetData, true, fwPacket) if err != nil { - hostinfo.logger(f.l).WithError(err).WithField("packet", out). + hostinfo.logger(f.l).WithError(err).WithField("packet", packetData). Warnf("Error while validating inbound packet") return false } @@ -491,7 +492,7 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out if dropReason != nil { // NOTE: We give `packet` as the `out` here since we already decrypted from it and we don't need it anymore // This gives us a buffer to build the reject packet in - f.rejectOutside(out, hostinfo.ConnectionState, hostinfo, nb, packet, q) + f.rejectOutside(packetData, hostinfo.ConnectionState, hostinfo, nb, packet, q) if f.l.Level >= logrus.DebugLevel { hostinfo.logger(f.l).WithField("fwPacket", fwPacket). WithField("reason", dropReason). diff --git a/overlay/tun.go b/overlay/tun.go index 3a61d186..7947e295 100644 --- a/overlay/tun.go +++ b/overlay/tun.go @@ -11,6 +11,7 @@ import ( ) const DefaultMTU = 1300 +const VirtioNetHdrLen = 10 // Size of virtio_net_hdr structure // TODO: We may be able to remove routines type DeviceFactory func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) diff --git a/overlay/tun_linux.go b/overlay/tun_linux.go index 44d87465..3c98d727 100644 --- a/overlay/tun_linux.go +++ b/overlay/tun_linux.go @@ -9,7 +9,6 @@ import ( "net" "net/netip" "os" - "strings" "sync/atomic" "time" "unsafe" @@ -21,10 +20,12 @@ import ( "github.com/slackhq/nebula/util" "github.com/vishvananda/netlink" "golang.org/x/sys/unix" + wgtun "golang.zx2c4.com/wireguard/tun" ) type tun struct { io.ReadWriteCloser + wgDevice wgtun.Device fd int Device string vpnNetworks []netip.Prefix @@ -65,59 +66,154 @@ type ifreqQLEN struct { pad [8]byte } -func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) { - file := os.NewFile(uintptr(deviceFd), "/dev/net/tun") +// wgDeviceWrapper wraps a wireguard Device to implement io.ReadWriteCloser +// This allows multiqueue readers to use the same wireguard Device batching as the main device +type wgDeviceWrapper struct { + dev wgtun.Device + buf []byte // Reusable buffer for single packet reads +} +func (w *wgDeviceWrapper) Read(b []byte) (int, error) { + // Use wireguard Device's batch API for single packet + bufs := [][]byte{b} + sizes := make([]int, 1) + n, err := w.dev.Read(bufs, sizes, 0) + if err != nil { + return 0, err + } + if n == 0 { + return 0, io.EOF + } + return sizes[0], nil +} + +func (w *wgDeviceWrapper) Write(b []byte) (int, error) { + // Buffer b should have virtio header space (10 bytes) at the beginning + // The decrypted packet data starts at offset 10 + // Pass the full buffer to WireGuard with offset=virtioNetHdrLen + bufs := [][]byte{b} + n, err := w.dev.Write(bufs, VirtioNetHdrLen) + if err != nil { + return 0, err + } + if n == 0 { + return 0, io.ErrShortWrite + } + return len(b), nil +} + +func (w *wgDeviceWrapper) WriteBatch(bufs [][]byte, offset int) (int, error) { + // Pass all buffers to WireGuard's batch write + return w.dev.Write(bufs, offset) +} + +func (w *wgDeviceWrapper) Close() error { + return w.dev.Close() +} + +// BatchRead implements batching for multiqueue readers +func (w *wgDeviceWrapper) BatchRead(bufs [][]byte, sizes []int) (int, error) { + // The zero here is offset. + return w.dev.Read(bufs, sizes, 0) +} + +// BatchSize returns the optimal batch size +func (w *wgDeviceWrapper) BatchSize() int { + return w.dev.BatchSize() +} + +func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) { + wgDev, name, err := wgtun.CreateUnmonitoredTUNFromFD(deviceFd) + if err != nil { + return nil, fmt.Errorf("failed to create TUN from FD: %w", err) + } + + file := wgDev.File() t, err := newTunGeneric(c, l, file, vpnNetworks) if err != nil { + _ = wgDev.Close() return nil, err } - t.Device = "tun0" + t.wgDevice = wgDev + t.Device = name return t, nil } func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueue bool) (*tun, error) { - fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0) - if err != nil { - // If /dev/net/tun doesn't exist, try to create it (will happen in docker) - if os.IsNotExist(err) { - err = os.MkdirAll("/dev/net", 0755) - if err != nil { - return nil, fmt.Errorf("/dev/net/tun doesn't exist, failed to mkdir -p /dev/net: %w", err) - } - err = unix.Mknod("/dev/net/tun", unix.S_IFCHR|0600, int(unix.Mkdev(10, 200))) - if err != nil { - return nil, fmt.Errorf("failed to create /dev/net/tun: %w", err) - } - - fd, err = unix.Open("/dev/net/tun", os.O_RDWR, 0) - if err != nil { - return nil, fmt.Errorf("created /dev/net/tun, but still failed: %w", err) - } - } else { - return nil, err + // Check if /dev/net/tun exists, create if needed (for docker containers) + if _, err := os.Stat("/dev/net/tun"); os.IsNotExist(err) { + if err := os.MkdirAll("/dev/net", 0755); err != nil { + return nil, fmt.Errorf("/dev/net/tun doesn't exist, failed to mkdir -p /dev/net: %w", err) + } + if err := unix.Mknod("/dev/net/tun", unix.S_IFCHR|0600, int(unix.Mkdev(10, 200))); err != nil { + return nil, fmt.Errorf("failed to create /dev/net/tun: %w", err) } } - var req ifReq - req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI) - if multiqueue { - req.Flags |= unix.IFF_MULTI_QUEUE - } - copy(req.Name[:], c.GetString("tun.dev", "")) - if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil { - return nil, err - } - name := strings.Trim(string(req.Name[:]), "\x00") + devName := c.GetString("tun.dev", "") + mtu := c.GetInt("tun.mtu", DefaultMTU) - file := os.NewFile(uintptr(fd), "/dev/net/tun") - t, err := newTunGeneric(c, l, file, vpnNetworks) + // Create TUN device manually to support multiqueue + fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0) if err != nil { return nil, err } + var req ifReq + req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_VNET_HDR) + if multiqueue { + req.Flags |= unix.IFF_MULTI_QUEUE + } + copy(req.Name[:], devName) + if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil { + unix.Close(fd) + return nil, err + } + + // Set nonblocking + if err = unix.SetNonblock(fd, true); err != nil { + unix.Close(fd) + return nil, err + } + + // Enable TCP and UDP offload (TSO/GRO) for performance + // This allows the kernel to handle segmentation/coalescing + const ( + tunTCPOffloads = unix.TUN_F_CSUM | unix.TUN_F_TSO4 | unix.TUN_F_TSO6 + tunUDPOffloads = unix.TUN_F_USO4 | unix.TUN_F_USO6 + ) + offloads := tunTCPOffloads | tunUDPOffloads + if err = unix.IoctlSetInt(fd, unix.TUNSETOFFLOAD, offloads); err != nil { + // Log warning but don't fail - offload is optional + l.WithError(err).Warn("Failed to enable TUN offload (TSO/GRO), performance may be reduced") + } + + file := os.NewFile(uintptr(fd), "/dev/net/tun") + + // Create wireguard device from file descriptor + wgDev, err := wgtun.CreateTUNFromFile(file, mtu) + if err != nil { + file.Close() + return nil, fmt.Errorf("failed to create TUN from file: %w", err) + } + + name, err := wgDev.Name() + if err != nil { + _ = wgDev.Close() + return nil, fmt.Errorf("failed to get TUN device name: %w", err) + } + + // file is now owned by wgDev, get a new reference + file = wgDev.File() + t, err := newTunGeneric(c, l, file, vpnNetworks) + if err != nil { + _ = wgDev.Close() + return nil, err + } + + t.wgDevice = wgDev t.Device = name return t, nil @@ -223,15 +319,37 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { } var req ifReq - req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_MULTI_QUEUE) + // MUST match the flags used in newTun - includes IFF_VNET_HDR + req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_VNET_HDR | unix.IFF_MULTI_QUEUE) copy(req.Name[:], t.Device) if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil { + unix.Close(fd) return nil, err } + // Set nonblocking mode - CRITICAL for proper netpoller integration + if err = unix.SetNonblock(fd, true); err != nil { + unix.Close(fd) + return nil, err + } + + // Get MTU from main device + mtu := t.MaxMTU + if mtu == 0 { + mtu = DefaultMTU + } + file := os.NewFile(uintptr(fd), "/dev/net/tun") - return file, nil + // Create wireguard Device from the file descriptor (just like the main device) + wgDev, err := wgtun.CreateTUNFromFile(file, mtu) + if err != nil { + file.Close() + return nil, fmt.Errorf("failed to create multiqueue TUN device: %w", err) + } + + // Return a wrapper that uses the wireguard Device for all I/O + return &wgDeviceWrapper{dev: wgDev}, nil } func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways { @@ -239,7 +357,68 @@ func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways { return r } +func (t *tun) Read(b []byte) (int, error) { + if t.wgDevice != nil { + // Use wireguard device which handles virtio headers internally + bufs := [][]byte{b} + sizes := make([]int, 1) + n, err := t.wgDevice.Read(bufs, sizes, 0) + if err != nil { + return 0, err + } + if n == 0 { + return 0, io.EOF + } + return sizes[0], nil + } + + // Fallback: direct read from file (shouldn't happen in normal operation) + return t.ReadWriteCloser.Read(b) +} + +// BatchRead reads multiple packets at once for improved performance +// bufs: slice of buffers to read into +// sizes: slice that will be filled with packet sizes +// Returns number of packets read +func (t *tun) BatchRead(bufs [][]byte, sizes []int) (int, error) { + if t.wgDevice != nil { + return t.wgDevice.Read(bufs, sizes, 0) + } + + // Fallback: single packet read + n, err := t.ReadWriteCloser.Read(bufs[0]) + if err != nil { + return 0, err + } + sizes[0] = n + return 1, nil +} + +// BatchSize returns the optimal number of packets to read/write in a batch +func (t *tun) BatchSize() int { + if t.wgDevice != nil { + return t.wgDevice.BatchSize() + } + return 1 +} + func (t *tun) Write(b []byte) (int, error) { + if t.wgDevice != nil { + // Buffer b should have virtio header space (10 bytes) at the beginning + // The decrypted packet data starts at offset 10 + // Pass the full buffer to WireGuard with offset=virtioNetHdrLen + bufs := [][]byte{b} + n, err := t.wgDevice.Write(bufs, VirtioNetHdrLen) + if err != nil { + return 0, err + } + if n == 0 { + return 0, io.ErrShortWrite + } + return len(b), nil + } + + // Fallback: direct write (shouldn't happen in normal operation) var nn int maximum := len(b) @@ -262,6 +441,22 @@ func (t *tun) Write(b []byte) (int, error) { } } +// WriteBatch writes multiple packets to the TUN device in a single syscall +func (t *tun) WriteBatch(bufs [][]byte, offset int) (int, error) { + if t.wgDevice != nil { + return t.wgDevice.Write(bufs, offset) + } + + // Fallback: write individually (shouldn't happen in normal operation) + for i, buf := range bufs { + _, err := t.Write(buf) + if err != nil { + return i, err + } + } + return len(bufs), nil +} + func (t *tun) deviceBytes() (o [16]byte) { for i, c := range t.Device { o[i] = byte(c) @@ -674,6 +869,10 @@ func (t *tun) Close() error { close(t.routeChan) } + if t.wgDevice != nil { + _ = t.wgDevice.Close() + } + if t.ReadWriteCloser != nil { _ = t.ReadWriteCloser.Close() } diff --git a/stats.go b/stats.go index c88c45cc..b86919cc 100644 --- a/stats.go +++ b/stats.go @@ -6,6 +6,7 @@ import ( "log" "net" "net/http" + _ "net/http/pprof" "runtime" "strconv" "time" diff --git a/udp/conn.go b/udp/conn.go index 895b0df3..f3267c9e 100644 --- a/udp/conn.go +++ b/udp/conn.go @@ -13,12 +13,21 @@ type EncReader func( payload []byte, ) +type EncBatchReader func( + addrs []netip.AddrPort, + payloads [][]byte, + count int, +) + type Conn interface { Rebind() error LocalAddr() (netip.AddrPort, error) ListenOut(r EncReader) + ListenOutBatch(r EncBatchReader) WriteTo(b []byte, addr netip.AddrPort) error + WriteMulti(packets [][]byte, addrs []netip.AddrPort) (int, error) ReloadConfig(c *config.C) + BatchSize() int Close() error } @@ -33,12 +42,21 @@ func (NoopConn) LocalAddr() (netip.AddrPort, error) { func (NoopConn) ListenOut(_ EncReader) { return } +func (NoopConn) ListenOutBatch(_ EncBatchReader) { + return +} func (NoopConn) WriteTo(_ []byte, _ netip.AddrPort) error { return nil } +func (NoopConn) WriteMulti(_ [][]byte, _ []netip.AddrPort) (int, error) { + return 0, nil +} func (NoopConn) ReloadConfig(_ *config.C) { return } +func (NoopConn) BatchSize() int { + return 1 +} func (NoopConn) Close() error { return nil } diff --git a/udp/udp_darwin.go b/udp/udp_darwin.go index f1cc102d..2ac47d1e 100644 --- a/udp/udp_darwin.go +++ b/udp/udp_darwin.go @@ -140,6 +140,17 @@ func (u *StdConn) WriteTo(b []byte, ap netip.AddrPort) error { } } +// WriteMulti sends multiple packets - fallback implementation without sendmmsg +func (u *StdConn) WriteMulti(packets [][]byte, addrs []netip.AddrPort) (int, error) { + for i := range packets { + err := u.WriteTo(packets[i], addrs[i]) + if err != nil { + return i, err + } + } + return len(packets), nil +} + func (u *StdConn) LocalAddr() (netip.AddrPort, error) { a := u.UDPConn.LocalAddr() @@ -184,6 +195,34 @@ func (u *StdConn) ListenOut(r EncReader) { } } +// ListenOutBatch - fallback to single-packet reads for Darwin +func (u *StdConn) ListenOutBatch(r EncBatchReader) { + buffer := make([]byte, MTU) + addrs := make([]netip.AddrPort, 1) + payloads := make([][]byte, 1) + + for { + // Just read one packet at a time and call batch callback with count=1 + n, rua, err := u.ReadFromUDPAddrPort(buffer) + if err != nil { + if errors.Is(err, net.ErrClosed) { + u.l.WithError(err).Debug("udp socket is closed, exiting read loop") + return + } + + u.l.WithError(err).Error("unexpected udp socket receive error") + } + + addrs[0] = netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()) + payloads[0] = buffer[:n] + r(addrs, payloads, 1) + } +} + +func (u *StdConn) BatchSize() int { + return 1 +} + func (u *StdConn) Rebind() error { var err error if u.isV4 { diff --git a/udp/udp_generic.go b/udp/udp_generic.go index cb21e574..7c8cdf4b 100644 --- a/udp/udp_generic.go +++ b/udp/udp_generic.go @@ -85,3 +85,42 @@ func (u *GenericConn) ListenOut(r EncReader) { r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n]) } } + +// ListenOutBatch - fallback to single-packet reads for generic platforms +func (u *GenericConn) ListenOutBatch(r EncBatchReader) { + buffer := make([]byte, MTU) + addrs := make([]netip.AddrPort, 1) + payloads := make([][]byte, 1) + + for { + // Just read one packet at a time and call batch callback with count=1 + n, rua, err := u.ReadFromUDPAddrPort(buffer) + if err != nil { + u.l.WithError(err).Debug("udp socket is closed, exiting read loop") + return + } + + addrs[0] = netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()) + payloads[0] = buffer[:n] + r(addrs, payloads, 1) + } +} + +// WriteMulti sends multiple packets - fallback implementation +func (u *GenericConn) WriteMulti(packets [][]byte, addrs []netip.AddrPort) (int, error) { + for i := range packets { + err := u.WriteTo(packets[i], addrs[i]) + if err != nil { + return i, err + } + } + return len(packets), nil +} + +func (u *GenericConn) BatchSize() int { + return 1 +} + +func (u *GenericConn) Rebind() error { + return nil +} diff --git a/udp/udp_linux.go b/udp/udp_linux.go index ec0bf64b..efb71c3f 100644 --- a/udp/udp_linux.go +++ b/udp/udp_linux.go @@ -22,6 +22,11 @@ type StdConn struct { isV4 bool l *logrus.Logger batch int + + // Pre-allocated buffers for batch writes (sized for IPv6, works for both) + writeMsgs []rawMessage + writeIovecs []iovec + writeNames [][]byte } func maybeIPV4(ip net.IP) (net.IP, bool) { @@ -69,7 +74,26 @@ func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch in return nil, fmt.Errorf("unable to bind to socket: %s", err) } - return &StdConn{sysFd: fd, isV4: ip.Is4(), l: l, batch: batch}, err + c := &StdConn{sysFd: fd, isV4: ip.Is4(), l: l, batch: batch} + + // Pre-allocate write message structures for batching (sized for IPv6, works for both) + c.writeMsgs = make([]rawMessage, batch) + c.writeIovecs = make([]iovec, batch) + c.writeNames = make([][]byte, batch) + + for i := range c.writeMsgs { + // Allocate for IPv6 size (larger than IPv4, works for both) + c.writeNames[i] = make([]byte, unix.SizeofSockaddrInet6) + + // Point to the iovec in the slice + c.writeMsgs[i].Hdr.Iov = &c.writeIovecs[i] + c.writeMsgs[i].Hdr.Iovlen = 1 + + c.writeMsgs[i].Hdr.Name = &c.writeNames[i][0] + // Namelen will be set appropriately in writeMulti4/writeMulti6 + } + + return c, err } func (u *StdConn) Rebind() error { @@ -127,6 +151,8 @@ func (u *StdConn) ListenOut(r EncReader) { read = u.ReadSingle } + udpBatchHist := metrics.GetOrRegisterHistogram("batch.udp_read_size", nil, metrics.NewUniformSample(1024)) + for { n, err := read(msgs) if err != nil { @@ -134,6 +160,8 @@ func (u *StdConn) ListenOut(r EncReader) { return } + udpBatchHist.Update(int64(n)) + for i := 0; i < n; i++ { // Its ok to skip the ok check here, the slicing is the only error that can occur and it will panic if u.isV4 { @@ -146,6 +174,46 @@ func (u *StdConn) ListenOut(r EncReader) { } } +func (u *StdConn) ListenOutBatch(r EncBatchReader) { + var ip netip.Addr + + msgs, buffers, names := u.PrepareRawMessages(u.batch) + read := u.ReadMulti + if u.batch == 1 { + read = u.ReadSingle + } + + udpBatchHist := metrics.GetOrRegisterHistogram("batch.udp_read_size", nil, metrics.NewUniformSample(1024)) + + // Pre-allocate slices for batch callback + addrs := make([]netip.AddrPort, u.batch) + payloads := make([][]byte, u.batch) + + for { + n, err := read(msgs) + if err != nil { + u.l.WithError(err).Debug("udp socket is closed, exiting read loop") + return + } + + udpBatchHist.Update(int64(n)) + + // Prepare batch data + for i := 0; i < n; i++ { + if u.isV4 { + ip, _ = netip.AddrFromSlice(names[i][4:8]) + } else { + ip, _ = netip.AddrFromSlice(names[i][8:24]) + } + addrs[i] = netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(names[i][2:4])) + payloads[i] = buffers[i][:msgs[i].Len] + } + + // Call batch callback with all packets + r(addrs, payloads, n) + } +} + func (u *StdConn) ReadSingle(msgs []rawMessage) (int, error) { for { n, _, err := unix.Syscall6( @@ -194,6 +262,19 @@ func (u *StdConn) WriteTo(b []byte, ip netip.AddrPort) error { return u.writeTo6(b, ip) } +func (u *StdConn) WriteMulti(packets [][]byte, addrs []netip.AddrPort) (int, error) { + if len(packets) != len(addrs) { + return 0, fmt.Errorf("packets and addrs length mismatch") + } + if len(packets) == 0 { + return 0, nil + } + if u.isV4 { + return u.writeMulti4(packets, addrs) + } + return u.writeMulti6(packets, addrs) +} + func (u *StdConn) writeTo6(b []byte, ip netip.AddrPort) error { var rsa unix.RawSockaddrInet6 rsa.Family = unix.AF_INET6 @@ -248,6 +329,123 @@ func (u *StdConn) writeTo4(b []byte, ip netip.AddrPort) error { } } +func (u *StdConn) writeMulti4(packets [][]byte, addrs []netip.AddrPort) (int, error) { + sent := 0 + for sent < len(packets) { + // Determine batch size based on remaining packets and buffer capacity + batchSize := len(packets) - sent + if batchSize > len(u.writeMsgs) { + batchSize = len(u.writeMsgs) + } + + // Use pre-allocated buffers + msgs := u.writeMsgs[:batchSize] + iovecs := u.writeIovecs[:batchSize] + names := u.writeNames[:batchSize] + + // Setup message structures for this batch + for i := 0; i < batchSize; i++ { + pktIdx := sent + i + if !addrs[pktIdx].Addr().Is4() { + return sent + i, ErrInvalidIPv6RemoteForSocket + } + + // Setup the packet buffer + iovecs[i].Base = &packets[pktIdx][0] + iovecs[i].Len = uint(len(packets[pktIdx])) + + // Setup the destination address + rsa := (*unix.RawSockaddrInet4)(unsafe.Pointer(&names[i][0])) + rsa.Family = unix.AF_INET + rsa.Addr = addrs[pktIdx].Addr().As4() + binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&rsa.Port))[:], addrs[pktIdx].Port()) + + // Set the appropriate address length for IPv4 + msgs[i].Hdr.Namelen = unix.SizeofSockaddrInet4 + } + + // Send this batch + nsent, _, err := unix.Syscall6( + unix.SYS_SENDMMSG, + uintptr(u.sysFd), + uintptr(unsafe.Pointer(&msgs[0])), + uintptr(batchSize), + 0, + 0, + 0, + ) + + if err != 0 { + return sent + int(nsent), &net.OpError{Op: "sendmmsg", Err: err} + } + + sent += int(nsent) + if int(nsent) < batchSize { + // Couldn't send all packets in batch, return what we sent + return sent, nil + } + } + + return sent, nil +} + +func (u *StdConn) writeMulti6(packets [][]byte, addrs []netip.AddrPort) (int, error) { + sent := 0 + for sent < len(packets) { + // Determine batch size based on remaining packets and buffer capacity + batchSize := len(packets) - sent + if batchSize > len(u.writeMsgs) { + batchSize = len(u.writeMsgs) + } + + // Use pre-allocated buffers + msgs := u.writeMsgs[:batchSize] + iovecs := u.writeIovecs[:batchSize] + names := u.writeNames[:batchSize] + + // Setup message structures for this batch + for i := 0; i < batchSize; i++ { + pktIdx := sent + i + + // Setup the packet buffer + iovecs[i].Base = &packets[pktIdx][0] + iovecs[i].Len = uint(len(packets[pktIdx])) + + // Setup the destination address + rsa := (*unix.RawSockaddrInet6)(unsafe.Pointer(&names[i][0])) + rsa.Family = unix.AF_INET6 + rsa.Addr = addrs[pktIdx].Addr().As16() + binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&rsa.Port))[:], addrs[pktIdx].Port()) + + // Set the appropriate address length for IPv6 + msgs[i].Hdr.Namelen = unix.SizeofSockaddrInet6 + } + + // Send this batch + nsent, _, err := unix.Syscall6( + unix.SYS_SENDMMSG, + uintptr(u.sysFd), + uintptr(unsafe.Pointer(&msgs[0])), + uintptr(batchSize), + 0, + 0, + 0, + ) + + if err != 0 { + return sent + int(nsent), &net.OpError{Op: "sendmmsg", Err: err} + } + + sent += int(nsent) + if int(nsent) < batchSize { + // Couldn't send all packets in batch, return what we sent + return sent, nil + } + } + + return sent, nil +} + func (u *StdConn) ReloadConfig(c *config.C) { b := c.GetInt("listen.read_buffer", 0) if b > 0 { @@ -305,6 +503,10 @@ func (u *StdConn) getMemInfo(meminfo *[unix.SK_MEMINFO_VARS]uint32) error { return nil } +func (u *StdConn) BatchSize() int { + return u.batch +} + func (u *StdConn) Close() error { return syscall.Close(u.sysFd) } diff --git a/udp/udp_linux_32.go b/udp/udp_linux_32.go index de8f1cdf..707a2b1f 100644 --- a/udp/udp_linux_32.go +++ b/udp/udp_linux_32.go @@ -12,7 +12,7 @@ import ( type iovec struct { Base *byte - Len uint32 + Len uint } type msghdr struct { @@ -40,7 +40,7 @@ func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) { names[i] = make([]byte, unix.SizeofSockaddrInet6) vs := []iovec{ - {Base: &buffers[i][0], Len: uint32(len(buffers[i]))}, + {Base: &buffers[i][0], Len: uint(len(buffers[i]))}, } msgs[i].Hdr.Iov = &vs[0] diff --git a/udp/udp_linux_64.go b/udp/udp_linux_64.go index 48c5a978..89c6695d 100644 --- a/udp/udp_linux_64.go +++ b/udp/udp_linux_64.go @@ -12,7 +12,7 @@ import ( type iovec struct { Base *byte - Len uint64 + Len uint } type msghdr struct { @@ -43,7 +43,7 @@ func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) { names[i] = make([]byte, unix.SizeofSockaddrInet6) vs := []iovec{ - {Base: &buffers[i][0], Len: uint64(len(buffers[i]))}, + {Base: &buffers[i][0], Len: uint(len(buffers[i]))}, } msgs[i].Hdr.Iov = &vs[0] diff --git a/udp/udp_tester.go b/udp/udp_tester.go index 8d5e6c14..d88d8aa6 100644 --- a/udp/udp_tester.go +++ b/udp/udp_tester.go @@ -116,6 +116,31 @@ func (u *TesterConn) ListenOut(r EncReader) { } } +func (u *TesterConn) ListenOutBatch(r EncBatchReader) { + addrs := make([]netip.AddrPort, 1) + payloads := make([][]byte, 1) + + for { + p, ok := <-u.RxPackets + if !ok { + return + } + addrs[0] = p.From + payloads[0] = p.Data + r(addrs, payloads, 1) + } +} + +func (u *TesterConn) WriteMulti(packets [][]byte, addrs []netip.AddrPort) (int, error) { + for i := range packets { + err := u.WriteTo(packets[i], addrs[i]) + if err != nil { + return i, err + } + } + return len(packets), nil +} + func (u *TesterConn) ReloadConfig(*config.C) {} func NewUDPStatsEmitter(_ []Conn) func() { @@ -127,6 +152,10 @@ func (u *TesterConn) LocalAddr() (netip.AddrPort, error) { return u.Addr, nil } +func (u *TesterConn) BatchSize() int { + return 1 +} + func (u *TesterConn) Rebind() error { return nil }