diff --git a/inside.go b/inside.go index 25cb1c1f..d2000129 100644 --- a/inside.go +++ b/inside.go @@ -18,14 +18,16 @@ import ( // outs: slice of output buffers (one per packet) with virtio headroom // q: queue index // localCache: firewall conntrack cache -func (f *Interface) consumeInsidePackets(packets [][]byte, sizes []int, count int, outs [][]byte, q int, localCache firewall.ConntrackCache) { +// batchPackets: pre-allocated slice for accumulating encrypted packets +// batchAddrs: pre-allocated slice for accumulating destination addresses +func (f *Interface) consumeInsidePackets(packets [][]byte, sizes []int, count int, outs [][]byte, q int, localCache firewall.ConntrackCache, batchPackets *[][]byte, batchAddrs *[]netip.AddrPort) { // Reusable per-packet state fwPacket := &firewall.Packet{} nb := make([]byte, 12, 12) - // Accumulate encrypted packets for batch sending - batchPackets := make([][]byte, 0, count) - batchAddrs := make([]netip.AddrPort, 0, count) + // Reset batch accumulation slices (reuse capacity) + *batchPackets = (*batchPackets)[:0] + *batchAddrs = (*batchAddrs)[:0] // Process each packet in the batch for i := 0; i < count; i++ { @@ -137,15 +139,15 @@ func (f *Interface) consumeInsidePackets(packets [][]byte, sizes []int, count in } // Add to batch - batchPackets = append(batchPackets, out) - batchAddrs = append(batchAddrs, hostinfo.remote) + *batchPackets = append(*batchPackets, out) + *batchAddrs = append(*batchAddrs, hostinfo.remote) } // Send all accumulated packets in one batch - if len(batchPackets) > 0 { - n, err := f.writers[q].WriteMulti(batchPackets, batchAddrs) + if len(*batchPackets) > 0 { + n, err := f.writers[q].WriteMulti(*batchPackets, *batchAddrs) if err != nil { - f.l.WithError(err).WithField("sent", n).WithField("total", len(batchPackets)).Error("Failed to send batch") + f.l.WithError(err).WithField("sent", n).WithField("total", len(*batchPackets)).Error("Failed to send batch") } } } diff --git a/interface.go b/interface.go index 01e7dbfd..fe8dcc75 100644 --- a/interface.go +++ b/interface.go @@ -348,6 +348,10 @@ func (f *Interface) listenInBatch(reader io.ReadWriteCloser, batchReader BatchRe outs[idx] = outBuf[virtioNetHdrLen:] // Slice starting after headroom } + // Pre-allocate batch accumulation buffers for sending + batchPackets := make([][]byte, 0, batchSize) + batchAddrs := make([]netip.AddrPort, 0, batchSize) + conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout) for { @@ -363,7 +367,7 @@ func (f *Interface) listenInBatch(reader io.ReadWriteCloser, batchReader BatchRe } // Process all packets in the batch at once - f.consumeInsidePackets(bufs, sizes, n, outs, i, conntrackCache.Get(f.l)) + f.consumeInsidePackets(bufs, sizes, n, outs, i, conntrackCache.Get(f.l), &batchPackets, &batchAddrs) } } diff --git a/udp/udp_linux.go b/udp/udp_linux.go index a6dfa7ae..2f8b28c0 100644 --- a/udp/udp_linux.go +++ b/udp/udp_linux.go @@ -22,6 +22,11 @@ type StdConn struct { isV4 bool l *logrus.Logger batch int + + // Pre-allocated buffers for batch writes (sized for IPv6, works for both) + writeMsgs []rawMessage + writeIovecs []iovec + writeNames [][]byte } func maybeIPV4(ip net.IP) (net.IP, bool) { @@ -69,7 +74,26 @@ func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch in return nil, fmt.Errorf("unable to bind to socket: %s", err) } - return &StdConn{sysFd: fd, isV4: ip.Is4(), l: l, batch: batch}, err + c := &StdConn{sysFd: fd, isV4: ip.Is4(), l: l, batch: batch} + + // Pre-allocate write message structures for batching (sized for IPv6, works for both) + c.writeMsgs = make([]rawMessage, batch) + c.writeIovecs = make([]iovec, batch) + c.writeNames = make([][]byte, batch) + + for i := range c.writeMsgs { + // Allocate for IPv6 size (larger than IPv4, works for both) + c.writeNames[i] = make([]byte, unix.SizeofSockaddrInet6) + + // Point to the iovec in the slice + c.writeMsgs[i].Hdr.Iov = &c.writeIovecs[i] + c.writeMsgs[i].Hdr.Iovlen = 1 + + c.writeMsgs[i].Hdr.Name = &c.writeNames[i][0] + // Namelen will be set appropriately in writeMulti4/writeMulti6 + } + + return c, err } func (u *StdConn) SupportsMultipleReaders() bool { @@ -266,75 +290,120 @@ func (u *StdConn) writeTo4(b []byte, ip netip.AddrPort) error { } func (u *StdConn) writeMulti4(packets [][]byte, addrs []netip.AddrPort) (int, error) { - msgs, iovecs, names := u.PrepareWriteMessages4(len(packets)) - - for i := range packets { - if !addrs[i].Addr().Is4() { - return i, ErrInvalidIPv6RemoteForSocket + sent := 0 + for sent < len(packets) { + // Determine batch size based on remaining packets and buffer capacity + batchSize := len(packets) - sent + if batchSize > len(u.writeMsgs) { + batchSize = len(u.writeMsgs) } - // Setup the packet buffer - iovecs[i].Base = &packets[i][0] - iovecs[i].Len = uint64(len(packets[i])) + // Use pre-allocated buffers + msgs := u.writeMsgs[:batchSize] + iovecs := u.writeIovecs[:batchSize] + names := u.writeNames[:batchSize] - // Setup the destination address - rsa := (*unix.RawSockaddrInet4)(unsafe.Pointer(&names[i][0])) - rsa.Family = unix.AF_INET - rsa.Addr = addrs[i].Addr().As4() - binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&rsa.Port))[:], addrs[i].Port()) - } + // Setup message structures for this batch + for i := 0; i < batchSize; i++ { + pktIdx := sent + i + if !addrs[pktIdx].Addr().Is4() { + return sent + i, ErrInvalidIPv6RemoteForSocket + } - for { - n, _, err := unix.Syscall6( + // Setup the packet buffer + iovecs[i].Base = &packets[pktIdx][0] + iovecs[i].Len = uint64(len(packets[pktIdx])) + + // Setup the destination address + rsa := (*unix.RawSockaddrInet4)(unsafe.Pointer(&names[i][0])) + rsa.Family = unix.AF_INET + rsa.Addr = addrs[pktIdx].Addr().As4() + binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&rsa.Port))[:], addrs[pktIdx].Port()) + + // Set the appropriate address length for IPv4 + msgs[i].Hdr.Namelen = unix.SizeofSockaddrInet4 + } + + // Send this batch + nsent, _, err := unix.Syscall6( unix.SYS_SENDMMSG, uintptr(u.sysFd), uintptr(unsafe.Pointer(&msgs[0])), - uintptr(len(msgs)), + uintptr(batchSize), 0, 0, 0, ) if err != 0 { - return int(n), &net.OpError{Op: "sendmmsg", Err: err} + return sent + int(nsent), &net.OpError{Op: "sendmmsg", Err: err} } - return int(n), nil + sent += int(nsent) + if int(nsent) < batchSize { + // Couldn't send all packets in batch, return what we sent + return sent, nil + } } + + return sent, nil } func (u *StdConn) writeMulti6(packets [][]byte, addrs []netip.AddrPort) (int, error) { - msgs, iovecs, names := u.PrepareWriteMessages6(len(packets)) + sent := 0 + for sent < len(packets) { + // Determine batch size based on remaining packets and buffer capacity + batchSize := len(packets) - sent + if batchSize > len(u.writeMsgs) { + batchSize = len(u.writeMsgs) + } - for i := range packets { - // Setup the packet buffer - iovecs[i].Base = &packets[i][0] - iovecs[i].Len = uint64(len(packets[i])) + // Use pre-allocated buffers + msgs := u.writeMsgs[:batchSize] + iovecs := u.writeIovecs[:batchSize] + names := u.writeNames[:batchSize] - // Setup the destination address - rsa := (*unix.RawSockaddrInet6)(unsafe.Pointer(&names[i][0])) - rsa.Family = unix.AF_INET6 - rsa.Addr = addrs[i].Addr().As16() - binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&rsa.Port))[:], addrs[i].Port()) - } + // Setup message structures for this batch + for i := 0; i < batchSize; i++ { + pktIdx := sent + i - for { - n, _, err := unix.Syscall6( + // Setup the packet buffer + iovecs[i].Base = &packets[pktIdx][0] + iovecs[i].Len = uint64(len(packets[pktIdx])) + + // Setup the destination address + rsa := (*unix.RawSockaddrInet6)(unsafe.Pointer(&names[i][0])) + rsa.Family = unix.AF_INET6 + rsa.Addr = addrs[pktIdx].Addr().As16() + binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&rsa.Port))[:], addrs[pktIdx].Port()) + + // Set the appropriate address length for IPv6 + msgs[i].Hdr.Namelen = unix.SizeofSockaddrInet6 + } + + // Send this batch + nsent, _, err := unix.Syscall6( unix.SYS_SENDMMSG, uintptr(u.sysFd), uintptr(unsafe.Pointer(&msgs[0])), - uintptr(len(msgs)), + uintptr(batchSize), 0, 0, 0, ) if err != 0 { - return int(n), &net.OpError{Op: "sendmmsg", Err: err} + return sent + int(nsent), &net.OpError{Op: "sendmmsg", Err: err} } - return int(n), nil + sent += int(nsent) + if int(nsent) < batchSize { + // Couldn't send all packets in batch, return what we sent + return sent, nil + } } + + return sent, nil } func (u *StdConn) ReloadConfig(c *config.C) { diff --git a/udp/udp_linux_64.go b/udp/udp_linux_64.go index 36ce8a4a..48c5a978 100644 --- a/udp/udp_linux_64.go +++ b/udp/udp_linux_64.go @@ -55,41 +55,3 @@ func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) { return msgs, buffers, names } - -func (u *StdConn) PrepareWriteMessages4(n int) ([]rawMessage, []iovec, [][]byte) { - msgs := make([]rawMessage, n) - iovecs := make([]iovec, n) - names := make([][]byte, n) - - for i := range msgs { - names[i] = make([]byte, unix.SizeofSockaddrInet4) - - // Point to the iovec in the slice - msgs[i].Hdr.Iov = &iovecs[i] - msgs[i].Hdr.Iovlen = 1 - - msgs[i].Hdr.Name = &names[i][0] - msgs[i].Hdr.Namelen = unix.SizeofSockaddrInet4 - } - - return msgs, iovecs, names -} - -func (u *StdConn) PrepareWriteMessages6(n int) ([]rawMessage, []iovec, [][]byte) { - msgs := make([]rawMessage, n) - iovecs := make([]iovec, n) - names := make([][]byte, n) - - for i := range msgs { - names[i] = make([]byte, unix.SizeofSockaddrInet6) - - // Point to the iovec in the slice - msgs[i].Hdr.Iov = &iovecs[i] - msgs[i].Hdr.Iovlen = 1 - - msgs[i].Hdr.Name = &names[i][0] - msgs[i].Hdr.Namelen = unix.SizeofSockaddrInet6 - } - - return msgs, iovecs, names -}