diff --git a/overlay/tun_linux.go b/overlay/tun_linux.go index af6d2e85..7c2bcd03 100644 --- a/overlay/tun_linux.go +++ b/overlay/tun_linux.go @@ -70,72 +70,55 @@ const ( virtioNetHdrLen = 10 // Size of virtio_net_hdr structure ) -// tunVirtioReader wraps a file descriptor that has IFF_VNET_HDR enabled -// and strips the virtio header on reads, adds it on writes -type tunVirtioReader struct { - f *os.File - buf [virtioNetHdrLen + 65535]byte // Space for header + max packet +// wgDeviceWrapper wraps a wireguard Device to implement io.ReadWriteCloser +// This allows multiqueue readers to use the same wireguard Device batching as the main device +type wgDeviceWrapper struct { + dev wgtun.Device + buf []byte // Reusable buffer for single packet reads } -func (r *tunVirtioReader) Read(b []byte) (int, error) { - // Read into our buffer which has space for the virtio header - n, err := r.f.Read(r.buf[:]) +func (w *wgDeviceWrapper) Read(b []byte) (int, error) { + // Use wireguard Device's batch API for single packet + bufs := [][]byte{b} + sizes := make([]int, 1) + n, err := w.dev.Read(bufs, sizes, 0) if err != nil { return 0, err } - - // Strip the virtio header (first 10 bytes) - if n < virtioNetHdrLen { - return 0, fmt.Errorf("packet too short: %d bytes", n) + if n == 0 { + return 0, io.EOF } - - // Copy payload (after header) to destination - copy(b, r.buf[virtioNetHdrLen:n]) - return n - virtioNetHdrLen, nil + return sizes[0], nil } -func (r *tunVirtioReader) Write(b []byte) (int, error) { - // Zero out the virtio header (no offload from userspace write) - for i := 0; i < virtioNetHdrLen; i++ { - r.buf[i] = 0 - } +func (w *wgDeviceWrapper) Write(b []byte) (int, error) { + // Allocate buffer with space for virtio header + buf := make([]byte, virtioNetHdrLen+len(b)) + copy(buf[virtioNetHdrLen:], b) - // Copy packet data after header - copy(r.buf[virtioNetHdrLen:], b) - - // Write with header prepended - n, err := r.f.Write(r.buf[:virtioNetHdrLen+len(b)]) + bufs := [][]byte{buf} + n, err := w.dev.Write(bufs, virtioNetHdrLen) if err != nil { return 0, err } - - // Return payload size (excluding header) - return n - virtioNetHdrLen, nil -} - -func (r *tunVirtioReader) Close() error { - return r.f.Close() -} - -// BatchRead reads multiple packets at once for performance -// This is not used for multiqueue readers as they use direct file I/O -// Returns number of packets read -func (r *tunVirtioReader) BatchRead(bufs [][]byte, sizes []int) (int, error) { - // For multiqueue file descriptors, we don't have the wireguard Device interface - // Fall back to single packet reads - // TODO: Could implement proper batching with unix.Recvmmsg - n, err := r.Read(bufs[0]) - if err != nil { - return 0, err + if n == 0 { + return 0, io.ErrShortWrite } - sizes[0] = n - return 1, nil + return len(b), nil } -// BatchSize returns the batch size for multiqueue readers -func (r *tunVirtioReader) BatchSize() int { - // Multiqueue readers use single packet mode for now - return 1 +func (w *wgDeviceWrapper) Close() error { + return w.dev.Close() +} + +// BatchRead implements batching for multiqueue readers +func (w *wgDeviceWrapper) BatchRead(bufs [][]byte, sizes []int) (int, error) { + return w.dev.Read(bufs, sizes, 0) +} + +// BatchSize returns the optimal batch size +func (w *wgDeviceWrapper) BatchSize() int { + return w.dev.BatchSize() } func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) { @@ -347,10 +330,29 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { return nil, err } + // Set nonblocking mode - CRITICAL for proper netpoller integration + if err = unix.SetNonblock(fd, true); err != nil { + unix.Close(fd) + return nil, err + } + + // Get MTU from main device + mtu := t.MaxMTU + if mtu == 0 { + mtu = DefaultMTU + } + file := os.NewFile(uintptr(fd), "/dev/net/tun") - // Wrap in virtio header handler - return &tunVirtioReader{f: file}, nil + // Create wireguard Device from the file descriptor (just like the main device) + wgDev, err := wgtun.CreateTUNFromFile(file, mtu) + if err != nil { + file.Close() + return nil, fmt.Errorf("failed to create multiqueue TUN device: %w", err) + } + + // Return a wrapper that uses the wireguard Device for all I/O + return &wgDeviceWrapper{dev: wgDev}, nil } func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {