Mirror of https://github.com/slackhq/nebula.git (synced 2025-12-06 02:30:57 -08:00)

remove tons of dead code

This commit is contained in:
parent 625acb7cc0
commit bdfc2f5809

5 changed files with 1 addition and 853 deletions
@@ -123,9 +123,6 @@ func NewDevice(options ...Option) (*Device, error) {
    if err = dev.refillReceiveQueue(); err != nil {
        return nil, fmt.Errorf("refill receive queue: %w", err)
    }
    if err = dev.refillTransmitQueue(); err != nil {
        return nil, fmt.Errorf("refill transmit queue: %w", err)
    }

    dev.initialized = true

@@ -153,22 +150,6 @@ func (dev *Device) refillReceiveQueue() error {
    }
}

func (dev *Device) refillTransmitQueue() error {
    //for {
    //    desc, err := dev.TransmitQueue.DescriptorTable().CreateDescriptorForOutputs()
    //    if err != nil {
    //        if errors.Is(err, virtqueue.ErrNotEnoughFreeDescriptors) {
    //            // Queue is full, job is done.
    //            return nil
    //        }
    //        return fmt.Errorf("offer descriptor chain: %w", err)
    //    } else {
    //        dev.TransmitQueue.UsedRing().InitOfferSingle(desc, 0)
    //    }
    //}
    return nil
}

// Close cleans up the vhost networking device within the kernel and releases
// all resources used for it.
// The implementation will try to release as many resources as possible and

@@ -214,14 +195,6 @@ func (dev *Device) Close() error {
    return errors.Join(errs...)
}

// ensureInitialized is used as a guard to prevent methods from being called
// on an uninitialized instance.
func (dev *Device) ensureInitialized() {
    if !dev.initialized {
        panic("device is not initialized")
    }
}

// createQueue creates a new virtqueue and registers it with the vhost device
// using the given index.
func createQueue(controlFD int, queueIndex int, queueSize int, itemSize int) (*virtqueue.SplitQueue, error) {

@@ -238,30 +211,10 @@ func createQueue(controlFD int, queueIndex int, queueSize int, itemSize int) (*virtqueue.SplitQueue, error) {
    return queue, nil
}

// truncateBuffers returns a new list of buffers whose combined length matches
// exactly the specified length. When the specified length exceeds the combined
// length of the buffers, this is an error. When it is smaller, the buffer list
// will be truncated accordingly.
func truncateBuffers(buffers [][]byte, length int) (out [][]byte) {
    for _, buffer := range buffers {
        if length < len(buffer) {
            out = append(out, buffer[:length])
            return
        }
        out = append(out, buffer)
        length -= len(buffer)
    }
    if length > 0 {
        panic("length exceeds the combined length of all buffers")
    }
    return
}
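A short sketch of the truncation semantics of the removed truncateBuffers helper, using hypothetical buffer sizes:

    bufs := [][]byte{make([]byte, 4), make([]byte, 4)} // combined length 8
    out := truncateBuffers(bufs, 6)
    // out[0] keeps all 4 bytes, out[1] is cut down to 2 bytes.
    // Asking for more than 8 bytes would panic instead.
    _ = out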
func (dev *Device) GetPacketForTx() (uint16, []byte, error) {
    var err error
    var idx uint16
    if !dev.fullTable {

        idx, err = dev.TransmitQueue.DescriptorTable().CreateDescriptorForOutputs()
        if err == virtqueue.ErrNotEnoughFreeDescriptors {
            dev.fullTable = true

@@ -393,7 +346,7 @@ func (dev *Device) ReceivePackets(out []*packet.VirtIOPacket) (int, error) {
    //todo optimize?
    var chains []virtqueue.UsedElement
    var err error
    //if len(dev.extraRx) == 0 {

    chains, err = dev.ReceiveQueue.BlockAndGetHeadsCapped(context.TODO(), len(out))
    if err != nil {
        return 0, err

@@ -401,9 +354,6 @@ func (dev *Device) ReceivePackets(out []*packet.VirtIOPacket) (int, error) {
    if len(chains) == 0 {
        return 0, nil
    }
    //} else {
    //    chains = dev.extraRx
    //}

    numPackets := 0
    chainsIdx := 0

@@ -418,10 +368,5 @@ func (dev *Device) ReceivePackets(out []*packet.VirtIOPacket) (int, error) {
        chainsIdx += numChains
    }

    // Now that we have copied all buffers, we can recycle the used descriptor chains
    //if err = dev.ReceiveQueue.OfferDescriptorChains(chains); err != nil {
    //    return 0, err
    //}

    return numPackets, nil
}
@@ -172,115 +172,6 @@ func (dt *DescriptorTable) releaseBuffers() error {
    return nil
}

// createDescriptorChain creates a new descriptor chain within the descriptor
// table which contains a number of device-readable buffers (out buffers) and
// device-writable buffers (in buffers).
//
// All buffers in the outBuffers slice will be concatenated by chaining
// descriptors, one for each buffer in the slice. The size of the single buffers
// must not exceed the size of a memory page (see [os.Getpagesize]).
// When numInBuffers is greater than zero, the given number of device-writable
// descriptors will be appended to the end of the chain, each referencing a
// whole memory page.
//
// The index of the head of the new descriptor chain will be returned. Callers
// should make sure to free the descriptor chain using [freeDescriptorChain]
// after it was used by the device.
//
// When there are not enough free descriptors to hold the given number of
// buffers, an [ErrNotEnoughFreeDescriptors] will be returned. In this case, the
// caller should try again after some descriptor chains were used by the device
// and returned back into the free chain.
func (dt *DescriptorTable) createDescriptorChain(outBuffers [][]byte, numInBuffers int) (uint16, error) {
    // Calculate the number of descriptors needed to build the chain.
    numDesc := uint16(len(outBuffers) + numInBuffers)

    // Descriptor chains must always contain at least one descriptor.
    if numDesc < 1 {
        return 0, ErrDescriptorChainEmpty
    }

    // Do we still have enough free descriptors?
    if numDesc > dt.freeNum {
        return 0, ErrNotEnoughFreeDescriptors
    }

    // Above validation ensured that there is at least one free descriptor, so
    // the free descriptor chain head should be valid.
    if dt.freeHeadIndex == noFreeHead {
        panic("free descriptor chain head is unset but there should be free descriptors")
    }

    // To avoid having to iterate over the whole table to find the descriptor
    // pointing to the head just to replace the free head, we instead always
    // create descriptor chains from the descriptors coming after the head.
    // This way we only have to touch the head as a last resort, when all other
    // descriptors are already used.
    head := dt.descriptors[dt.freeHeadIndex].next
    next := head
    tail := head
    for i, buffer := range outBuffers {
        desc := &dt.descriptors[next]
        checkUnusedDescriptorLength(next, desc)

        if len(buffer) > dt.itemSize {
            // The caller should already prevent that from happening.
            panic(fmt.Sprintf("out buffer %d has size %d which exceeds desc length %d", i, len(buffer), dt.itemSize))
        }

        // Copy the buffer to the memory referenced by the descriptor.
        // The descriptor address points to memory not managed by Go, so this
        // conversion is safe. See https://github.com/golang/go/issues/58625
        //goland:noinspection GoVetUnsafePointer
        copy(unsafe.Slice((*byte)(unsafe.Pointer(desc.address)), dt.itemSize), buffer)
        desc.length = uint32(len(buffer))

        // Clear the flags in case there were any others set.
        desc.flags = descriptorFlagHasNext

        tail = next
        next = desc.next
    }
    for range numInBuffers {
        desc := &dt.descriptors[next]
        checkUnusedDescriptorLength(next, desc)

        // Give the device the maximum available number of bytes to write into.
        desc.length = uint32(dt.itemSize)

        // Mark the descriptor as device-writable.
        desc.flags = descriptorFlagHasNext | descriptorFlagWritable

        tail = next
        next = desc.next
    }

    // The last descriptor should end the chain.
    tailDesc := &dt.descriptors[tail]
    tailDesc.flags &= ^descriptorFlagHasNext
    tailDesc.next = 0 // Not necessary to clear this, it's just for looks.

    dt.freeNum -= numDesc

    if dt.freeNum == 0 {
        // The last descriptor in the chain should be the free chain head
        // itself.
        if tail != dt.freeHeadIndex {
            panic("descriptor chain takes up all free descriptors but does not end with the free chain head")
        }

        // When this new chain takes up all remaining descriptors, we no longer
        // have a free chain.
        dt.freeHeadIndex = noFreeHead
    } else {
        // We took some descriptors out of the free chain, so make sure to close
        // the circle again.
        dt.descriptors[dt.freeHeadIndex].next = next
    }

    return head, nil
}
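A sketch of the layout produced by the removed createDescriptorChain, mirroring the first chain built in TestDescriptorTable_DescriptorChains further below:

    // With a freshly initialized table the free head is descriptor 0, so the
    // new chain starts at descriptors[0].next, which is index 1:
    head, err := dt.createDescriptorChain([][]byte{
        make([]byte, 26),  // descriptor 1: device-readable, next = 2
        make([]byte, 256), // descriptor 2: device-readable, next = 3
    }, 1) // descriptor 3: device-writable page, terminates the chain
    // head == 1, err == nil, and dt.freeNum has dropped by 3.
    _, _ = head, err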
func (dt *DescriptorTable) CreateDescriptorForOutputs() (uint16, error) {
    //todo just fill the damn table
    // Do we still have enough free descriptors?

@@ -490,73 +381,6 @@ func (dt *DescriptorTable) getDescriptorInbuffers(head uint16, inBuffers *[][]byte) error {
    return nil
}

func (dt *DescriptorTable) getDescriptorChainContents(head uint16, out []byte, maxLen int) (int, error) {
    if int(head) >= len(dt.descriptors) {
        return 0, fmt.Errorf("%w: index out of range", ErrInvalidDescriptorChain)
    }

    // Iterate over the chain. The iteration is limited to the queue size to
    // avoid ending up in an endless loop when things go very wrong.

    length := 0
    //find length
    next := head
    for range len(dt.descriptors) {
        if next == dt.freeHeadIndex {
            return 0, fmt.Errorf("%w: must not be part of the free chain", ErrInvalidDescriptorChain)
        }

        desc := &dt.descriptors[next]

        if desc.flags&descriptorFlagWritable == 0 {
            return 0, fmt.Errorf("receive queue contains device-readable buffer")
        }
        length += int(desc.length)

        // Is this the tail of the chain?
        if desc.flags&descriptorFlagHasNext == 0 {
            break
        }

        // Detect loops.
        if desc.next == head {
            return 0, fmt.Errorf("%w: contains a loop", ErrInvalidDescriptorChain)
        }

        next = desc.next
    }
    if maxLen > 0 {
        //todo length = min(maxLen, length)
    }
    //set out to length:
    out = out[:length]

    //now do the copying, starting over from the head of the chain
    next = head
    copied := 0
    for range len(dt.descriptors) {
        desc := &dt.descriptors[next]

        // The descriptor address points to memory not managed by Go, so this
        // conversion is safe. See https://github.com/golang/go/issues/58625
        //goland:noinspection GoVetUnsafePointer
        bs := unsafe.Slice((*byte)(unsafe.Pointer(desc.address)), min(uint32(length-copied), desc.length))
        copied += copy(out[copied:], bs)

        // Is this the tail of the chain?
        if desc.flags&descriptorFlagHasNext == 0 {
            break
        }

        // We already validated the chain above, so no need to detect loops again.
        next = desc.next
    }
    if copied != length {
        panic(fmt.Sprintf("expected to copy %d bytes but only copied %d bytes", length, copied))
    }

    return length, nil
}
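A usage sketch for the removed two-pass helper above; the head index and capacity are hypothetical, and the out slice only needs sufficient capacity since the helper reslices it internally:

    var head uint16 = 1           // a chain head obtained elsewhere
    buf := make([]byte, 0, 65536) // capacity must cover the whole chain
    n, err := dt.getDescriptorChainContents(head, buf, 0) // maxLen 0: no cap
    if err == nil {
        data := buf[:n] // the gathered chain contents
        _ = data
    }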
// freeDescriptorChain can be used to free a descriptor chain when it is no
// longer in use. The descriptor chain that starts with the given index will be
// put back into the free chain, so the descriptors can be used for later calls
@@ -1,407 +0,0 @@
package virtqueue

import (
    "os"
    "testing"
    "unsafe"

    "github.com/stretchr/testify/assert"
)

func TestDescriptorTable_InitializeDescriptors(t *testing.T) {
    const queueSize = 32

    dt := DescriptorTable{
        descriptors: make([]Descriptor, queueSize),
    }

    assert.NoError(t, dt.initializeDescriptors())
    t.Cleanup(func() {
        assert.NoError(t, dt.releaseBuffers())
    })

    for i, descriptor := range dt.descriptors {
        assert.NotZero(t, descriptor.address)
        assert.Zero(t, descriptor.length)
        assert.EqualValues(t, descriptorFlagHasNext, descriptor.flags)
        assert.EqualValues(t, (i+1)%queueSize, descriptor.next)
    }
}

func TestDescriptorTable_DescriptorChains(t *testing.T) {
    // Use a very short queue size to not make this test overly verbose.
    const queueSize = 8

    pageSize := os.Getpagesize() * 2

    // Initialize descriptor table.
    dt := DescriptorTable{
        descriptors: make([]Descriptor, queueSize),
    }
    assert.NoError(t, dt.initializeDescriptors())
    t.Cleanup(func() {
        assert.NoError(t, dt.releaseBuffers())
    })

    // Some utilities for easier checking if the descriptor table looks as
    // expected.
    type desc struct {
        buffer []byte
        flags  descriptorFlag
        next   uint16
    }
    assertDescriptorTable := func(expected [queueSize]desc) {
        for i := 0; i < queueSize; i++ {
            actualDesc := &dt.descriptors[i]
            expectedDesc := &expected[i]
            assert.Equal(t, uint32(len(expectedDesc.buffer)), actualDesc.length)
            if len(expectedDesc.buffer) > 0 {
                //goland:noinspection GoVetUnsafePointer
                assert.EqualValues(t,
                    unsafe.Slice((*byte)(unsafe.Pointer(actualDesc.address)), actualDesc.length),
                    expectedDesc.buffer)
            }
            assert.Equal(t, expectedDesc.flags, actualDesc.flags)
            if expectedDesc.flags&descriptorFlagHasNext != 0 {
                assert.Equal(t, expectedDesc.next, actualDesc.next)
            }
        }
    }

    // Initial state: All descriptors are in the free chain.
    assert.Equal(t, uint16(0), dt.freeHeadIndex)
    assert.Equal(t, uint16(8), dt.freeNum)
    assertDescriptorTable([queueSize]desc{
        {
            // Free head.
            flags: descriptorFlagHasNext,
            next:  1,
        },
        {
            flags: descriptorFlagHasNext,
            next:  2,
        },
        {
            flags: descriptorFlagHasNext,
            next:  3,
        },
        {
            flags: descriptorFlagHasNext,
            next:  4,
        },
        {
            flags: descriptorFlagHasNext,
            next:  5,
        },
        {
            flags: descriptorFlagHasNext,
            next:  6,
        },
        {
            flags: descriptorFlagHasNext,
            next:  7,
        },
        {
            flags: descriptorFlagHasNext,
            next:  0,
        },
    })

    // Create the first chain.
    firstChain, err := dt.createDescriptorChain([][]byte{
        makeTestBuffer(t, 26),
        makeTestBuffer(t, 256),
    }, 1)
    assert.NoError(t, err)
    assert.Equal(t, uint16(1), firstChain)

    // Now there should be a new chain next to the free chain.
    assert.Equal(t, uint16(0), dt.freeHeadIndex)
    assert.Equal(t, uint16(5), dt.freeNum)
    assertDescriptorTable([queueSize]desc{
        {
            // Free head.
            flags: descriptorFlagHasNext,
            next:  4,
        },
        {
            // Head of first chain.
            buffer: makeTestBuffer(t, 26),
            flags:  descriptorFlagHasNext,
            next:   2,
        },
        {
            buffer: makeTestBuffer(t, 256),
            flags:  descriptorFlagHasNext,
            next:   3,
        },
        {
            // Tail of first chain.
            buffer: make([]byte, pageSize),
            flags:  descriptorFlagWritable,
        },
        {
            flags: descriptorFlagHasNext,
            next:  5,
        },
        {
            flags: descriptorFlagHasNext,
            next:  6,
        },
        {
            flags: descriptorFlagHasNext,
            next:  7,
        },
        {
            flags: descriptorFlagHasNext,
            next:  0,
        },
    })

    // Create a second chain with only a single in buffer.
    secondChain, err := dt.createDescriptorChain(nil, 1)
    assert.NoError(t, err)
    assert.Equal(t, uint16(4), secondChain)

    // Now there should be two chains next to the free chain.
    assert.Equal(t, uint16(0), dt.freeHeadIndex)
    assert.Equal(t, uint16(4), dt.freeNum)
    assertDescriptorTable([queueSize]desc{
        {
            // Free head.
            flags: descriptorFlagHasNext,
            next:  5,
        },
        {
            // Head of the first chain.
            buffer: makeTestBuffer(t, 26),
            flags:  descriptorFlagHasNext,
            next:   2,
        },
        {
            buffer: makeTestBuffer(t, 256),
            flags:  descriptorFlagHasNext,
            next:   3,
        },
        {
            // Tail of the first chain.
            buffer: make([]byte, pageSize),
            flags:  descriptorFlagWritable,
        },
        {
            // Head and tail of the second chain.
            buffer: make([]byte, pageSize),
            flags:  descriptorFlagWritable,
        },
        {
            flags: descriptorFlagHasNext,
            next:  6,
        },
        {
            flags: descriptorFlagHasNext,
            next:  7,
        },
        {
            flags: descriptorFlagHasNext,
            next:  0,
        },
    })

    // Create a third chain taking up all remaining descriptors.
    thirdChain, err := dt.createDescriptorChain([][]byte{
        makeTestBuffer(t, 42),
        makeTestBuffer(t, 96),
        makeTestBuffer(t, 33),
        makeTestBuffer(t, 222),
    }, 0)
    assert.NoError(t, err)
    assert.Equal(t, uint16(5), thirdChain)

    // Now there should be three chains and no free chain.
    assert.Equal(t, noFreeHead, dt.freeHeadIndex)
    assert.Equal(t, uint16(0), dt.freeNum)
    assertDescriptorTable([queueSize]desc{
        {
            // Tail of the third chain.
            buffer: makeTestBuffer(t, 222),
        },
        {
            // Head of the first chain.
            buffer: makeTestBuffer(t, 26),
            flags:  descriptorFlagHasNext,
            next:   2,
        },
        {
            buffer: makeTestBuffer(t, 256),
            flags:  descriptorFlagHasNext,
            next:   3,
        },
        {
            // Tail of the first chain.
            buffer: make([]byte, pageSize),
            flags:  descriptorFlagWritable,
        },
        {
            // Head and tail of the second chain.
            buffer: make([]byte, pageSize),
            flags:  descriptorFlagWritable,
        },
        {
            // Head of the third chain.
            buffer: makeTestBuffer(t, 42),
            flags:  descriptorFlagHasNext,
            next:   6,
        },
        {
            buffer: makeTestBuffer(t, 96),
            flags:  descriptorFlagHasNext,
            next:   7,
        },
        {
            buffer: makeTestBuffer(t, 33),
            flags:  descriptorFlagHasNext,
            next:   0,
        },
    })

    // Free the third chain.
    assert.NoError(t, dt.freeDescriptorChain(thirdChain))

    // Now there should be two chains and a free chain again.
    assert.Equal(t, uint16(5), dt.freeHeadIndex)
    assert.Equal(t, uint16(4), dt.freeNum)
    assertDescriptorTable([queueSize]desc{
        {
            flags: descriptorFlagHasNext,
            next:  5,
        },
        {
            // Head of the first chain.
            buffer: makeTestBuffer(t, 26),
            flags:  descriptorFlagHasNext,
            next:   2,
        },
        {
            buffer: makeTestBuffer(t, 256),
            flags:  descriptorFlagHasNext,
            next:   3,
        },
        {
            // Tail of the first chain.
            buffer: make([]byte, pageSize),
            flags:  descriptorFlagWritable,
        },
        {
            // Head and tail of the second chain.
            buffer: make([]byte, pageSize),
            flags:  descriptorFlagWritable,
        },
        {
            // Free head.
            flags: descriptorFlagHasNext,
            next:  6,
        },
        {
            flags: descriptorFlagHasNext,
            next:  7,
        },
        {
            flags: descriptorFlagHasNext,
            next:  0,
        },
    })

    // Free the first chain.
    assert.NoError(t, dt.freeDescriptorChain(firstChain))

    // Now there should be only a single chain next to the free chain.
    assert.Equal(t, uint16(5), dt.freeHeadIndex)
    assert.Equal(t, uint16(7), dt.freeNum)
    assertDescriptorTable([queueSize]desc{
        {
            flags: descriptorFlagHasNext,
            next:  5,
        },
        {
            flags: descriptorFlagHasNext,
            next:  2,
        },
        {
            flags: descriptorFlagHasNext,
            next:  3,
        },
        {
            flags: descriptorFlagHasNext,
            next:  6,
        },
        {
            // Head and tail of the second chain.
            buffer: make([]byte, pageSize),
            flags:  descriptorFlagWritable,
        },
        {
            // Free head.
            flags: descriptorFlagHasNext,
            next:  1,
        },
        {
            flags: descriptorFlagHasNext,
            next:  7,
        },
        {
            flags: descriptorFlagHasNext,
            next:  0,
        },
    })

    // Free the second chain.
    assert.NoError(t, dt.freeDescriptorChain(secondChain))

    // Now all descriptors should be in the free chain again.
    assert.Equal(t, uint16(5), dt.freeHeadIndex)
    assert.Equal(t, uint16(8), dt.freeNum)
    assertDescriptorTable([queueSize]desc{
        {
            flags: descriptorFlagHasNext,
            next:  5,
        },
        {
            flags: descriptorFlagHasNext,
            next:  2,
        },
        {
            flags: descriptorFlagHasNext,
            next:  3,
        },
        {
            flags: descriptorFlagHasNext,
            next:  6,
        },
        {
            flags: descriptorFlagHasNext,
            next:  1,
        },
        {
            // Free head.
            flags: descriptorFlagHasNext,
            next:  4,
        },
        {
            flags: descriptorFlagHasNext,
            next:  7,
        },
        {
            flags: descriptorFlagHasNext,
            next:  0,
        },
    })
}

func makeTestBuffer(t *testing.T, length int) []byte {
    t.Helper()
    buf := make([]byte, length)
    for i := 0; i < length; i++ {
        buf[i] = byte(length - i)
    }
    return buf
}
@@ -5,7 +5,6 @@ import (
    "errors"
    "fmt"
    "os"
    "syscall"

    "github.com/slackhq/nebula/overlay/eventfd"
    "golang.org/x/sys/unix"

@@ -186,28 +185,6 @@ func (sq *SplitQueue) startConsumeUsedRing() func() error {
    }
}

// BlockAndGetHeads waits for the device to signal that it has used descriptor
// chains and returns all [UsedElement]s.
func (sq *SplitQueue) BlockAndGetHeads(ctx context.Context) ([]UsedElement, error) {
    var n int
    var err error
    for ctx.Err() == nil {

        // Wait for a signal from the device.
        if n, err = sq.epoll.Block(); err != nil {
            return nil, fmt.Errorf("wait: %w", err)
        }
        if n > 0 {
            stillNeedToTake, out := sq.usedRing.take(-1)
            sq.more = stillNeedToTake
            if stillNeedToTake == 0 {
                _ = sq.epoll.Clear() //???
            }
            return out, nil
        }
    }
    return nil, ctx.Err()
}
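A sketch of how a caller might drain the used ring with the removed BlockAndGetHeads; the loop structure is illustrative, not part of this commit:

    // Repeatedly block until the device reports used chains, then process
    // each one. Cancelling ctx is the only way out of the loop.
    for {
        elems, err := sq.BlockAndGetHeads(ctx)
        if err != nil {
            return err // typically ctx.Err() after cancellation
        }
        for _, elem := range elems {
            // Each element references a used descriptor chain to recycle.
            _ = elem
        }
    }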
func (sq *SplitQueue) TakeSingle(ctx context.Context) (uint16, error) {
    var n int
    var err error

@@ -326,53 +303,6 @@ func (sq *SplitQueue) OfferInDescriptorChains() (uint16, error) {
    return head, nil
}

func (sq *SplitQueue) OfferOutDescriptorChains(prepend []byte, outBuffers [][]byte) ([]uint16, error) {
    // TODO change this
    // Each descriptor can only hold a whole memory page, so split large out
    // buffers into multiple smaller ones.
    outBuffers = splitBuffers(outBuffers, sq.itemSize)

    chains := make([]uint16, len(outBuffers))

    // Create a descriptor chain for the given buffers.
    var (
        head uint16
        err  error
    )
    for i := range outBuffers {
        for {
            bufs := [][]byte{prepend, outBuffers[i]}
            head, err = sq.descriptorTable.createDescriptorChain(bufs, 0)
            if err == nil {
                break
            }

            // Avoid errors.Is here; it is too slow for this hot path.
            //goland:noinspection GoDirectComparisonOfErrors
            if err == ErrNotEnoughFreeDescriptors {
                // Wait for more free descriptors to be put back into the queue.
                // If the number of free descriptors is still not sufficient, we'll
                // land here again.
                //todo should never happen
                syscall.Syscall(syscall.SYS_SCHED_YIELD, 0, 0, 0) // Cheap barrier
                continue
            }
            return nil, fmt.Errorf("create descriptor chain: %w", err)
        }
        chains[i] = head
    }

    // Make the descriptor chains available to the device.
    sq.availableRing.offer(chains)

    // Notify the device to make it process the updated available ring.
    if err := sq.kickEventFD.Kick(); err != nil {
        return chains, fmt.Errorf("notify device: %w", err)
    }

    return chains, nil
}
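A sketch of the intended transmit path for the removed OfferOutDescriptorChains; the header size is a hypothetical value, inferred only from the prepend parameter:

    // Each packet becomes its own descriptor chain: the shared header is
    // prepended as the first out buffer, the payload as the second.
    header := make([]byte, 12) // hypothetical virtio-net header
    payload := []byte("packet payload")
    heads, err := sq.OfferOutDescriptorChains(header, [][]byte{payload})
    if err != nil {
        // The device could not be notified; heads may still hold the chains.
    }
    _ = heads // one chain head per offered out buffer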
// GetDescriptorChain returns the device-readable buffers (out buffers) and
// device-writable buffers (in buffers) of the descriptor chain with the given
// head index.

@@ -392,10 +322,6 @@ func (sq *SplitQueue) GetDescriptorItem(head uint16) ([]byte, error) {
    return sq.descriptorTable.getDescriptorItem(head)
}

func (sq *SplitQueue) GetDescriptorChainContents(head uint16, out []byte, maxLen int) (int, error) {
    return sq.descriptorTable.getDescriptorChainContents(head, out, maxLen)
}

func (sq *SplitQueue) GetDescriptorInbuffers(head uint16, inBuffers *[][]byte) error {
    return sq.descriptorTable.getDescriptorInbuffers(head, inBuffers)
}

@@ -486,14 +412,6 @@ func (sq *SplitQueue) Close() error {
    return errors.Join(errs...)
}

// ensureInitialized is used as a guard to prevent methods from being called
// on an uninitialized instance.
func (sq *SplitQueue) ensureInitialized() {
    if sq.buf == nil {
        panic("used ring is not initialized")
    }
}

func align(index, alignment int) int {
    remainder := index % alignment
    if remainder == 0 {

@@ -501,30 +419,3 @@ func align(index, alignment int) int {
    }
    return index + alignment - remainder
}
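For reference, the rounding behavior of align with illustrative values; the remainder == 0 branch is elided by the hunk boundary above and presumably just returns index unchanged:

    // align rounds an index up to the next multiple of alignment:
    //   align(4096, 4096) == 4096 (already aligned)
    //   align(4097, 4096) == 8192 (remainder 1, so 4097 + 4096 - 1)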
// splitBuffers processes a list of buffers and splits each buffer that is
// larger than the size limit into multiple smaller buffers.
// If none of the buffers is too big, the input is returned unchanged to avoid
// an allocation for now.
func splitBuffers(buffers [][]byte, sizeLimit int) [][]byte {
    for i := range buffers {
        if len(buffers[i]) > sizeLimit {
            return reallySplitBuffers(buffers, sizeLimit)
        }
    }
    return buffers
}

func reallySplitBuffers(buffers [][]byte, sizeLimit int) [][]byte {
    result := make([][]byte, 0, len(buffers))
    for _, buffer := range buffers {
        for added := 0; added < len(buffer); added += sizeLimit {
            if len(buffer)-added <= sizeLimit {
                result = append(result, buffer[added:])
                break
            }
            result = append(result, buffer[added:added+sizeLimit])
        }
    }

    return result
}
@@ -1,105 +0,0 @@
package virtqueue

import (
    "testing"

    "github.com/stretchr/testify/assert"
    "github.com/stretchr/testify/require"
)

func TestSplitQueue_MemoryAlignment(t *testing.T) {
    tests := []struct {
        name      string
        queueSize int
    }{
        {
            name:      "minimal queue size",
            queueSize: 1,
        },
        {
            name:      "small queue size",
            queueSize: 8,
        },
        {
            name:      "large queue size",
            queueSize: 256,
        },
    }
    for _, tt := range tests {
        t.Run(tt.name, func(t *testing.T) {
            sq, err := NewSplitQueue(tt.queueSize)
            require.NoError(t, err)

            assert.Zero(t, sq.descriptorTable.Address()%descriptorTableAlignment)
            assert.Zero(t, sq.availableRing.Address()%availableRingAlignment)
            assert.Zero(t, sq.usedRing.Address()%usedRingAlignment)
        })
    }
}

func TestSplitBuffers(t *testing.T) {
    const sizeLimit = 16
    tests := []struct {
        name     string
        buffers  [][]byte
        expected [][]byte
    }{
        {
            name:     "no buffers",
            buffers:  make([][]byte, 0),
            expected: make([][]byte, 0),
        },
        {
            name: "small",
            buffers: [][]byte{
                make([]byte, 11),
            },
            expected: [][]byte{
                make([]byte, 11),
            },
        },
        {
            name: "exact size",
            buffers: [][]byte{
                make([]byte, sizeLimit),
            },
            expected: [][]byte{
                make([]byte, sizeLimit),
            },
        },
        {
            name: "large",
            buffers: [][]byte{
                make([]byte, 42),
            },
            expected: [][]byte{
                make([]byte, 16),
                make([]byte, 16),
                make([]byte, 10),
            },
        },
        {
            name: "mixed",
            buffers: [][]byte{
                make([]byte, 7),
                make([]byte, 30),
                make([]byte, 15),
                make([]byte, 32),
            },
            expected: [][]byte{
                make([]byte, 7),
                make([]byte, 16),
                make([]byte, 14),
                make([]byte, 15),
                make([]byte, 16),
                make([]byte, 16),
            },
        },
    }
    for _, tt := range tests {
        t.Run(tt.name, func(t *testing.T) {
            actual := splitBuffers(tt.buffers, sizeLimit)
            assert.Equal(t, tt.expected, actual)
        })
    }
}