// virtio_drivers/queue.rs
1//! Support for virt queues, the main mechanism for data transport on VirtIO devices.
2//!
3//! Types from this module are used to implement VirtIO device drivers. If you just want to use the
4//! drivers provided (rather than implementing drivers for other devices) then you shouldn't need to
5//! use anything from this module.
6
7#[cfg(feature = "alloc")]
8mod owning;
9
10#[cfg(feature = "alloc")]
11pub use self::owning::OwningQueue;
12use crate::hal::{BufferDirection, Dma, Hal, PhysAddr};
13use crate::transport::Transport;
14use crate::{Error, PAGE_SIZE, Result, align_up, pages};
15#[cfg(feature = "alloc")]
16use alloc::boxed::Box;
17use bitflags::bitflags;
18#[cfg(test)]
19use core::cmp::min;
20use core::convert::TryInto;
21use core::hint::spin_loop;
22use core::mem::{size_of, take};
23#[cfg(test)]
24use core::ptr;
25use core::ptr::NonNull;
26use core::sync::atomic::{AtomicU16, Ordering, fence};
27use zerocopy::{FromBytes, FromZeros, Immutable, IntoBytes, KnownLayout};
28
/// The mechanism for bulk data transport on virtio devices.
///
/// Each device can have zero or more virtqueues.
///
/// * `SIZE`: The size of the queue. This is both the number of descriptors, and the number of slots
///   in the available and used rings. It must be a power of 2 and fit in a [`u16`].
#[derive(Debug)]
pub struct VirtQueue<H: Hal, const SIZE: usize> {
    /// DMA guard
    layout: VirtQueueLayout<H>,
    /// Descriptor table
    ///
    /// The device may be able to modify this, even though it's not supposed to, so we shouldn't
    /// trust values read back from it. Use `desc_shadow` instead to keep track of what we wrote to
    /// it.
    desc: NonNull<[Descriptor]>,
    /// Available ring
    ///
    /// The device may be able to modify this, even though it's not supposed to, so we shouldn't
    /// trust values read back from it. The only field we need to read currently is `idx`, so we
    /// have `avail_idx` below to use instead.
    avail: NonNull<AvailRing<SIZE>>,
    /// Used ring
    used: NonNull<UsedRing<SIZE>>,

    /// The index of queue
    queue_idx: u16,
    /// The number of descriptors currently in use.
    num_used: u16,
    /// The head desc index of the free list.
    free_head: u16,
    /// Our trusted copy of `desc` that the device can't access.
    desc_shadow: [Descriptor; SIZE],
    /// Our trusted copy of `avail.idx`.
    avail_idx: u16,
    /// The used-ring index up to which we have already popped entries; compared against
    /// `used.idx` to tell whether there is anything new to pop.
    last_used_idx: u16,
    /// Whether the `VIRTIO_F_EVENT_IDX` feature has been negotiated.
    event_idx: bool,
    /// Whether to use indirect descriptor lists (`VIRTIO_F_INDIRECT_DESC` negotiated).
    #[cfg(feature = "alloc")]
    indirect: bool,
    /// For each direct descriptor slot, the virtual address of the heap-allocated indirect
    /// descriptor list currently attached to it (if any), so it can be freed on recycle.
    /// (The descriptor itself only stores the physical DMA address, which may differ.)
    #[cfg(feature = "alloc")]
    indirect_lists: [Option<NonNull<[Descriptor]>>; SIZE],
}
72
impl<H: Hal, const SIZE: usize> VirtQueue<H, SIZE> {
    // Compile-time check of the `SIZE` const parameter; evaluated via `let _ = Self::SIZE_OK` in
    // `new` so an invalid size fails the build rather than misbehaving at runtime.
    const SIZE_OK: () = assert!(SIZE.is_power_of_two() && SIZE <= u16::MAX as usize);

    /// Creates a new VirtQueue.
    ///
    /// * `indirect`: Whether to use indirect descriptors. This should be set if the
    ///   `VIRTIO_F_INDIRECT_DESC` feature has been negotiated with the device.
    /// * `event_idx`: Whether to use the `used_event` and `avail_event` fields for notification
    ///   suppression. This should be set if the `VIRTIO_F_EVENT_IDX` feature has been negotiated
    ///   with the device.
    ///
    /// Returns `Error::AlreadyUsed` if the transport already has this queue in use, and
    /// `Error::InvalidParam` if the device doesn't support a queue as large as `SIZE`.
    pub fn new<T: Transport>(
        transport: &mut T,
        idx: u16,
        indirect: bool,
        event_idx: bool,
    ) -> Result<Self> {
        #[allow(clippy::let_unit_value)]
        let _ = Self::SIZE_OK;

        if transport.queue_used(idx) {
            return Err(Error::AlreadyUsed);
        }
        if transport.max_queue_size(idx) < SIZE as u32 {
            return Err(Error::InvalidParam);
        }
        let size = SIZE as u16;

        // Legacy transports require all three queue parts in one contiguous region; modern ones
        // allow separate driver-to-device and device-to-driver regions.
        let layout = if transport.requires_legacy_layout() {
            VirtQueueLayout::allocate_legacy(size)?
        } else {
            VirtQueueLayout::allocate_flexible(size)?
        };

        // Tell the device where the three queue parts live (physical addresses).
        transport.queue_set(
            idx,
            size.into(),
            layout.descriptors_paddr(),
            layout.driver_area_paddr(),
            layout.device_area_paddr(),
        );

        let desc =
            NonNull::slice_from_raw_parts(layout.descriptors_vaddr().cast::<Descriptor>(), SIZE);
        let avail = layout.avail_vaddr().cast();
        let used = layout.used_vaddr().cast();

        let mut desc_shadow: [Descriptor; SIZE] = FromZeros::new_zeroed();
        // Link descriptors together so the whole table forms the initial free list
        // (free_head = 0, each descriptor's `next` pointing at the one after it).
        for i in 0..(size - 1) {
            desc_shadow[i as usize].next = i + 1;
            // SAFETY: `desc` is properly aligned, dereferenceable, initialised,
            // and the device won't access the descriptors for the duration of this unsafe block.
            unsafe {
                (*desc.as_ptr())[i as usize].next = i + 1;
            }
        }

        #[cfg(feature = "alloc")]
        const NONE: Option<NonNull<[Descriptor]>> = None;
        Ok(VirtQueue {
            layout,
            desc,
            avail,
            used,
            queue_idx: idx,
            num_used: 0,
            free_head: 0,
            desc_shadow,
            avail_idx: 0,
            last_used_idx: 0,
            event_idx,
            #[cfg(feature = "alloc")]
            indirect,
            #[cfg(feature = "alloc")]
            indirect_lists: [NONE; SIZE],
        })
    }

    /// Add buffers to the virtqueue, return a token.
    ///
    /// The buffers must not be empty.
    ///
    /// Ref: linux virtio_ring.c virtqueue_add
    ///
    /// # Safety
    ///
    /// The input and output buffers must remain valid and not be accessed until a call to
    /// `pop_used` with the returned token succeeds.
    pub unsafe fn add<'a, 'b>(
        &mut self,
        inputs: &'a [&'b [u8]],
        outputs: &'a mut [&'b mut [u8]],
    ) -> Result<u16> {
        if inputs.is_empty() && outputs.is_empty() {
            return Err(Error::InvalidParam);
        }
        let descriptors_needed = inputs.len() + outputs.len();
        // Only consider indirect descriptors if the alloc feature is enabled, as they require
        // allocation.
        // With indirect descriptors a chain only consumes one direct slot, so the capacity
        // check is looser: we need one free slot, the chain must fit in a descriptor table,
        // and only a non-indirect chain needs `descriptors_needed` direct slots.
        #[cfg(feature = "alloc")]
        if self.num_used as usize + 1 > SIZE
            || descriptors_needed > SIZE
            || (!self.indirect && self.num_used as usize + descriptors_needed > SIZE)
        {
            return Err(Error::QueueFull);
        }
        #[cfg(not(feature = "alloc"))]
        if self.num_used as usize + descriptors_needed > SIZE {
            return Err(Error::QueueFull);
        }

        // Indirect lists are only worthwhile for chains of more than one descriptor.
        #[cfg(feature = "alloc")]
        let head = if self.indirect && descriptors_needed > 1 {
            self.add_indirect(inputs, outputs)
        } else {
            self.add_direct(inputs, outputs)
        };
        #[cfg(not(feature = "alloc"))]
        let head = self.add_direct(inputs, outputs);

        // `SIZE` is a power of two (checked by `SIZE_OK`), so this masks the index into range.
        let avail_slot = self.avail_idx & (SIZE as u16 - 1);
        // SAFETY: `self.avail` is properly aligned, dereferenceable and initialised.
        unsafe {
            (*self.avail.as_ptr()).ring[avail_slot as usize] = head;
        }

        // Write barrier so that device sees changes to descriptor table and available ring before
        // change to available index.
        fence(Ordering::SeqCst);

        // increase head of avail ring
        self.avail_idx = self.avail_idx.wrapping_add(1);
        // SAFETY: `self.avail` is properly aligned, dereferenceable and initialised.
        unsafe {
            (*self.avail.as_ptr())
                .idx
                .store(self.avail_idx, Ordering::Release);
        }

        Ok(head)
    }

    /// Writes a chain of direct descriptors for the given buffers, taking them from the free
    /// list. Returns the index of the head descriptor (the token).
    ///
    /// The caller (`add`) has already checked that enough free descriptors are available and
    /// upholds the buffer-lifetime contract required by `Descriptor::set_buf`.
    fn add_direct<'a, 'b>(
        &mut self,
        inputs: &'a [&'b [u8]],
        outputs: &'a mut [&'b mut [u8]],
    ) -> u16 {
        // allocate descriptors from free list
        let head = self.free_head;
        let mut last = self.free_head;

        for (buffer, direction) in InputOutputIter::new(inputs, outputs) {
            assert_ne!(buffer.len(), 0);

            // Write to desc_shadow then copy.
            let desc = &mut self.desc_shadow[usize::from(self.free_head)];
            // SAFETY: Our caller promises that the buffers live at least until `pop_used`
            // returns them.
            unsafe {
                desc.set_buf::<H>(buffer, direction, DescFlags::NEXT);
            }
            last = self.free_head;
            self.free_head = desc.next;

            self.write_desc(last);
        }

        // set last_elem.next = NULL
        self.desc_shadow[usize::from(last)]
            .flags
            .remove(DescFlags::NEXT);
        self.write_desc(last);

        self.num_used += (inputs.len() + outputs.len()) as u16;

        head
    }

    /// Builds a heap-allocated indirect descriptor list for the given buffers and writes a
    /// single direct descriptor (flagged `INDIRECT`) pointing at it. Returns the index of that
    /// direct descriptor (the token).
    ///
    /// Only called when `descriptors_needed > 1`, so `indirect_list.last_mut()` is never `None`.
    #[cfg(feature = "alloc")]
    fn add_indirect<'a, 'b>(
        &mut self,
        inputs: &'a [&'b [u8]],
        outputs: &'a mut [&'b mut [u8]],
    ) -> u16 {
        let head = self.free_head;

        // Allocate and fill in indirect descriptor list.
        let mut indirect_list =
            <[Descriptor]>::new_box_zeroed_with_elems(inputs.len() + outputs.len()).unwrap();
        for (i, (buffer, direction)) in InputOutputIter::new(inputs, outputs).enumerate() {
            let desc = &mut indirect_list[i];
            // SAFETY: Our caller promises that the buffers live at least until `pop_used`
            // returns them.
            unsafe {
                desc.set_buf::<H>(buffer, direction, DescFlags::NEXT);
            }
            desc.next = (i + 1) as u16;
        }
        indirect_list
            .last_mut()
            .unwrap()
            .flags
            .remove(DescFlags::NEXT);

        // Need to store pointer to indirect_list too, because direct_desc.set_buf will only store
        // the physical DMA address which might be different.
        assert!(self.indirect_lists[usize::from(head)].is_none());
        self.indirect_lists[usize::from(head)] = Some(indirect_list.as_mut().into());

        // Write a descriptor pointing to indirect descriptor list. We use Box::leak to prevent the
        // indirect list from being freed when this function returns; recycle_descriptors is instead
        // responsible for freeing the memory after the buffer chain is popped.
        let direct_desc = &mut self.desc_shadow[usize::from(head)];
        self.free_head = direct_desc.next;

        // SAFETY: Using `Box::leak` on `indirect_list` guarantees it won't be deallocated
        // when this function returns. The allocation isn't freed until
        // `recycle_descriptors` is called, at which point the allocation is no longer being
        // used.
        unsafe {
            direct_desc.set_buf::<H>(
                Box::leak(indirect_list).as_bytes().into(),
                BufferDirection::DriverToDevice,
                DescFlags::INDIRECT,
            );
        }
        self.write_desc(head);
        self.num_used += 1;

        head
    }

    /// Add the given buffers to the virtqueue, notifies the device, blocks until the device uses
    /// them, then pops them.
    ///
    /// This assumes that the device isn't processing any other buffers at the same time.
    ///
    /// The buffers must not be empty.
    pub fn add_notify_wait_pop<'a>(
        &mut self,
        inputs: &'a [&'a [u8]],
        outputs: &'a mut [&'a mut [u8]],
        transport: &mut impl Transport,
    ) -> Result<u32> {
        // SAFETY: We don't return until the same token has been popped, so the buffers remain
        // valid and are not otherwise accessed until then.
        let token = unsafe { self.add(inputs, outputs) }?;

        // Notify the queue.
        if self.should_notify() {
            transport.notify(self.queue_idx);
        }

        // Wait until there is at least one element in the used ring.
        while !self.can_pop() {
            spin_loop();
        }

        // SAFETY: These are the same buffers as we passed to `add` above and they are still valid.
        unsafe { self.pop_used(token, inputs, outputs) }
    }

    /// Advise the device whether used buffer notifications are needed.
    ///
    /// See Virtio v1.1 2.6.7 Used Buffer Notification Suppression
    ///
    /// This is a no-op when `VIRTIO_F_EVENT_IDX` is negotiated, as the flag field is not used
    /// for suppression in that mode.
    pub fn set_dev_notify(&mut self, enable: bool) {
        // Bit 0 of the available ring flags is VIRTQ_AVAIL_F_NO_INTERRUPT.
        let avail_ring_flags = if enable { 0x0000 } else { 0x0001 };
        if !self.event_idx {
            // SAFETY: `self.avail` points to a valid, aligned, initialised, dereferenceable, readable
            // instance of `AvailRing`.
            unsafe {
                (*self.avail.as_ptr())
                    .flags
                    .store(avail_ring_flags, Ordering::Release)
            }
        }
    }

    /// Returns whether the driver should notify the device after adding a new buffer to the
    /// virtqueue.
    ///
    /// This will be false if the device has suppressed notifications.
    pub fn should_notify(&self) -> bool {
        if self.event_idx {
            // SAFETY: `self.used` points to a valid, aligned, initialised, dereferenceable, readable
            // instance of `UsedRing`.
            let avail_event = unsafe { (*self.used.as_ptr()).avail_event.load(Ordering::Acquire) };
            self.avail_idx >= avail_event.wrapping_add(1)
        } else {
            // Bit 0 of the used ring flags is VIRTQ_USED_F_NO_NOTIFY.
            // SAFETY: `self.used` points to a valid, aligned, initialised, dereferenceable, readable
            // instance of `UsedRing`.
            unsafe { (*self.used.as_ptr()).flags.load(Ordering::Acquire) & 0x0001 == 0 }
        }
    }

    /// Copies the descriptor at the given index from `desc_shadow` to `desc`, so it can be seen by
    /// the device.
    fn write_desc(&mut self, index: u16) {
        let index = usize::from(index);
        // SAFETY: `self.desc` is properly aligned, dereferenceable and initialised, and nothing
        // else reads or writes the descriptor during this block.
        unsafe {
            (*self.desc.as_ptr())[index] = self.desc_shadow[index].clone();
        }
    }

    /// Returns whether there is a used element that can be popped.
    pub fn can_pop(&self) -> bool {
        // SAFETY: `self.used` points to a valid, aligned, initialised, dereferenceable, readable
        // instance of `UsedRing`.
        self.last_used_idx != unsafe { (*self.used.as_ptr()).idx.load(Ordering::Acquire) }
    }

    /// Returns the descriptor index (a.k.a. token) of the next used element without popping it, or
    /// `None` if the used ring is empty.
    pub fn peek_used(&self) -> Option<u16> {
        if self.can_pop() {
            // `SIZE` is a power of two, so this masks the index into range.
            let last_used_slot = self.last_used_idx & (SIZE as u16 - 1);
            // SAFETY: `self.used` points to a valid, aligned, initialised, dereferenceable,
            // readable instance of `UsedRing`.
            Some(unsafe { (*self.used.as_ptr()).ring[last_used_slot as usize].id as u16 })
        } else {
            None
        }
    }

    /// Returns the number of free descriptors.
    pub fn available_desc(&self) -> usize {
        // With indirect descriptors, any non-full queue can accept a chain of any length
        // (each chain takes exactly one direct slot), so report all-or-nothing.
        #[cfg(feature = "alloc")]
        if self.indirect {
            return if usize::from(self.num_used) == SIZE {
                0
            } else {
                SIZE
            };
        }

        SIZE - usize::from(self.num_used)
    }

    /// Unshares buffers in the list starting at descriptor index `head` and adds them to the free
    /// list. Unsharing may involve copying data back to the original buffers, so they must be
    /// passed in too.
    ///
    /// This will push all linked descriptors at the front of the free list.
    ///
    /// # Safety
    ///
    /// The buffers in `inputs` and `outputs` must match the set of buffers originally added to the
    /// queue by `add`.
    unsafe fn recycle_descriptors<'a>(
        &mut self,
        head: u16,
        inputs: &'a [&'a [u8]],
        outputs: &'a mut [&'a mut [u8]],
    ) {
        let original_free_head = self.free_head;
        self.free_head = head;

        let head_desc = &mut self.desc_shadow[usize::from(head)];
        if head_desc.flags.contains(DescFlags::INDIRECT) {
            // INDIRECT is only ever set by `add_indirect`, which is gated on `alloc`.
            #[cfg(feature = "alloc")]
            {
                // Find the indirect descriptor list, unshare it and move its descriptor to the free
                // list.
                let indirect_list = self.indirect_lists[usize::from(head)].take().unwrap();
                // SAFETY: We allocated the indirect list in `add_indirect`, and the device has
                // finished accessing it by this point.
                let mut indirect_list = unsafe { Box::from_raw(indirect_list.as_ptr()) };
                let paddr = head_desc.addr;
                head_desc.unset_buf();
                self.num_used -= 1;
                head_desc.next = original_free_head;

                // SAFETY: `paddr` comes from a previous call `H::share` (inside
                // `Descriptor::set_buf`, which was called from `add_direct` or `add_indirect`).
                // `indirect_list` is owned by this function and is not accessed from any other threads.
                unsafe {
                    H::unshare(
                        paddr,
                        indirect_list.as_mut_bytes().into(),
                        BufferDirection::DriverToDevice,
                    );
                }

                // Unshare the buffers in the indirect descriptor list, and free it.
                assert_eq!(indirect_list.len(), inputs.len() + outputs.len());
                for (i, (buffer, direction)) in InputOutputIter::new(inputs, outputs).enumerate() {
                    assert_ne!(buffer.len(), 0);

                    // SAFETY: The caller ensures that the buffer is valid and matches the
                    // descriptor from which we got `paddr`.
                    unsafe {
                        // Unshare the buffer (and perhaps copy its contents back to the original
                        // buffer).
                        H::unshare(indirect_list[i].addr, buffer, direction);
                    }
                }
                drop(indirect_list);
            }
        } else {
            // Walk the direct descriptor chain, unsharing each buffer and returning every
            // descriptor to the free list. The chain length must match the buffer count.
            let mut next = Some(head);

            for (buffer, direction) in InputOutputIter::new(inputs, outputs) {
                assert_ne!(buffer.len(), 0);

                let desc_index = next.expect("Descriptor chain was shorter than expected.");
                let desc = &mut self.desc_shadow[usize::from(desc_index)];

                let paddr = desc.addr;
                desc.unset_buf();
                self.num_used -= 1;
                next = desc.next();
                if next.is_none() {
                    // Last descriptor in the chain: splice the chain onto the old free list.
                    desc.next = original_free_head;
                }

                self.write_desc(desc_index);

                // SAFETY: The caller ensures that the buffer is valid and matches the descriptor
                // from which we got `paddr`.
                unsafe {
                    // Unshare the buffer (and perhaps copy its contents back to the original buffer).
                    H::unshare(paddr, buffer, direction);
                }
            }

            if next.is_some() {
                panic!("Descriptor chain was longer than expected.");
            }
        }
    }

    /// If the given token is next on the device used queue, pops it and returns the total buffer
    /// length which was used (written) by the device.
    ///
    /// Ref: linux virtio_ring.c virtqueue_get_buf_ctx
    ///
    /// # Safety
    ///
    /// The buffers in `inputs` and `outputs` must match the set of buffers originally added to the
    /// queue by `add` when it returned the token being passed in here.
    pub unsafe fn pop_used<'a>(
        &mut self,
        token: u16,
        inputs: &'a [&'a [u8]],
        outputs: &'a mut [&'a mut [u8]],
    ) -> Result<u32> {
        if !self.can_pop() {
            return Err(Error::NotReady);
        }

        // Get the index of the start of the descriptor chain for the next element in the used ring.
        let last_used_slot = self.last_used_idx & (SIZE as u16 - 1);
        let index;
        let len;
        // SAFETY: `self.used` points to a valid, aligned, initialised, dereferenceable, readable
        // instance of `UsedRing`.
        unsafe {
            index = (*self.used.as_ptr()).ring[last_used_slot as usize].id as u16;
            len = (*self.used.as_ptr()).ring[last_used_slot as usize].len;
        }

        if index != token {
            // The device used a different descriptor chain to the one we were expecting.
            return Err(Error::WrongToken);
        }

        // SAFETY: The caller ensures the buffers are valid and match the descriptor.
        unsafe {
            self.recycle_descriptors(index, inputs, outputs);
        }
        self.last_used_idx = self.last_used_idx.wrapping_add(1);

        if self.event_idx {
            // Tell the device (via `used_event`) how far we have consumed, so it can suppress
            // notifications for entries we have already seen.
            // SAFETY: `self.avail` points to a valid, aligned, initialised, dereferenceable,
            // readable instance of `AvailRing`.
            unsafe {
                (*self.avail.as_ptr())
                    .used_event
                    .store(self.last_used_idx, Ordering::Release);
            }
        }

        Ok(len)
    }
}
560
// SAFETY: None of the virt queue resources are tied to a particular thread.
unsafe impl<H: Hal, const SIZE: usize> Send for VirtQueue<H, SIZE> {}

// SAFETY: A `&VirtQueue` only allows reading from the various pointers it contains, so there is no
// data race. (All mutation goes through `&mut self` methods.)
unsafe impl<H: Hal, const SIZE: usize> Sync for VirtQueue<H, SIZE> {}
567
/// The inner layout of a VirtQueue.
///
/// Ref: 2.6 Split Virtqueues
#[derive(Debug)]
enum VirtQueueLayout<H: Hal> {
    /// All three queue parts in a single contiguous DMA region, as required by legacy
    /// transports.
    Legacy {
        /// The single region holding descriptor table, available ring and used ring.
        dma: Dma<H>,
        /// Offset from the start of `dma` to the available ring.
        avail_offset: usize,
        /// Offset from the start of `dma` to the used ring.
        used_offset: usize,
    },
    /// Separate DMA regions per direction, as allowed by non-legacy transports.
    Modern {
        /// The region used for the descriptor area and driver area.
        driver_to_device_dma: Dma<H>,
        /// The region used for the device area.
        device_to_driver_dma: Dma<H>,
        /// The offset from the start of the `driver_to_device_dma` region to the driver area
        /// (available ring).
        avail_offset: usize,
    },
}
588
589impl<H: Hal> VirtQueueLayout<H> {
590    /// Allocates a single DMA region containing all parts of the virtqueue, following the layout
591    /// required by legacy interfaces.
592    ///
593    /// Ref: 2.6.2 Legacy Interfaces: A Note on Virtqueue Layout
594    fn allocate_legacy(queue_size: u16) -> Result<Self> {
595        let (desc, avail, used) = queue_part_sizes(queue_size);
596        let size = align_up(desc + avail) + align_up(used);
597        // Allocate contiguous pages.
598        let dma = Dma::new(size / PAGE_SIZE, BufferDirection::Both)?;
599        Ok(Self::Legacy {
600            dma,
601            avail_offset: desc,
602            used_offset: align_up(desc + avail),
603        })
604    }
605
606    /// Allocates separate DMA regions for the the different parts of the virtqueue, as supported by
607    /// non-legacy interfaces.
608    ///
609    /// This is preferred over `allocate_legacy` where possible as it reduces memory fragmentation
610    /// and allows the HAL to know which DMA regions are used in which direction.
611    fn allocate_flexible(queue_size: u16) -> Result<Self> {
612        let (desc, avail, used) = queue_part_sizes(queue_size);
613        let driver_to_device_dma = Dma::new(pages(desc + avail), BufferDirection::DriverToDevice)?;
614        let device_to_driver_dma = Dma::new(pages(used), BufferDirection::DeviceToDriver)?;
615        Ok(Self::Modern {
616            driver_to_device_dma,
617            device_to_driver_dma,
618            avail_offset: desc,
619        })
620    }
621
622    /// Returns the physical address of the descriptor area.
623    fn descriptors_paddr(&self) -> PhysAddr {
624        match self {
625            Self::Legacy { dma, .. } => dma.paddr(),
626            Self::Modern {
627                driver_to_device_dma,
628                ..
629            } => driver_to_device_dma.paddr(),
630        }
631    }
632
633    /// Returns a pointer to the descriptor table (in the descriptor area).
634    fn descriptors_vaddr(&self) -> NonNull<u8> {
635        match self {
636            Self::Legacy { dma, .. } => dma.vaddr(0),
637            Self::Modern {
638                driver_to_device_dma,
639                ..
640            } => driver_to_device_dma.vaddr(0),
641        }
642    }
643
644    /// Returns the physical address of the driver area.
645    fn driver_area_paddr(&self) -> PhysAddr {
646        match self {
647            Self::Legacy {
648                dma, avail_offset, ..
649            } => dma.paddr() + *avail_offset as u64,
650            Self::Modern {
651                driver_to_device_dma,
652                avail_offset,
653                ..
654            } => driver_to_device_dma.paddr() + *avail_offset as u64,
655        }
656    }
657
658    /// Returns a pointer to the available ring (in the driver area).
659    fn avail_vaddr(&self) -> NonNull<u8> {
660        match self {
661            Self::Legacy {
662                dma, avail_offset, ..
663            } => dma.vaddr(*avail_offset),
664            Self::Modern {
665                driver_to_device_dma,
666                avail_offset,
667                ..
668            } => driver_to_device_dma.vaddr(*avail_offset),
669        }
670    }
671
672    /// Returns the physical address of the device area.
673    fn device_area_paddr(&self) -> PhysAddr {
674        match self {
675            Self::Legacy {
676                used_offset, dma, ..
677            } => dma.paddr() + *used_offset as u64,
678            Self::Modern {
679                device_to_driver_dma,
680                ..
681            } => device_to_driver_dma.paddr(),
682        }
683    }
684
685    /// Returns a pointer to the used ring (in the driver area).
686    fn used_vaddr(&self) -> NonNull<u8> {
687        match self {
688            Self::Legacy {
689                dma, used_offset, ..
690            } => dma.vaddr(*used_offset),
691            Self::Modern {
692                device_to_driver_dma,
693                ..
694            } => device_to_driver_dma.vaddr(0),
695        }
696    }
697}
698
699/// Returns the size in bytes of the descriptor table, available ring and used ring for a given
700/// queue size.
701///
702/// Ref: 2.6 Split Virtqueues
703fn queue_part_sizes(queue_size: u16) -> (usize, usize, usize) {
704    assert!(
705        queue_size.is_power_of_two(),
706        "queue size should be a power of 2"
707    );
708    let queue_size = usize::from(queue_size);
709    let desc = size_of::<Descriptor>() * queue_size;
710    let avail = size_of::<u16>() * (3 + queue_size);
711    let used = size_of::<u16>() * 3 + size_of::<UsedElem>() * queue_size;
712    (desc, avail, used)
713}
714
/// A virtqueue buffer descriptor, in the layout mandated by the VirtIO spec (2.6.5).
#[repr(C, align(16))]
#[derive(Clone, Debug, FromBytes, Immutable, IntoBytes, KnownLayout)]
pub(crate) struct Descriptor {
    // Physical (DMA) address of the buffer, as returned by `H::share` in `set_buf`.
    addr: u64,
    // Length of the buffer in bytes.
    len: u32,
    // Descriptor flags (`NEXT`/`WRITE`/`INDIRECT`).
    flags: DescFlags,
    // Index of the next descriptor in the chain; only meaningful when `flags` contains `NEXT`
    // (see `next()`), or when the descriptor is on the free list.
    next: u16,
}
723
impl Descriptor {
    /// Sets the buffer address, length and flags, and shares it with the device.
    ///
    /// The `WRITE` flag is derived from `direction` (set for device-writable buffers); any
    /// `extra_flags` (e.g. `NEXT` or `INDIRECT`) are OR-ed in on top.
    ///
    /// # Safety
    ///
    /// The caller must ensure that the buffer lives at least as long as the descriptor is active.
    ///
    /// # Panics
    ///
    /// Panics if `direction` is `BufferDirection::Both`, or if the buffer length doesn't fit in
    /// a `u32`.
    unsafe fn set_buf<H: Hal>(
        &mut self,
        buf: NonNull<[u8]>,
        direction: BufferDirection,
        extra_flags: DescFlags,
    ) {
        // SAFETY: Our caller promises that the buffer is valid.
        unsafe {
            self.addr = H::share(buf, direction);
        }
        self.len = buf.len().try_into().unwrap();
        self.flags = extra_flags
            | match direction {
                BufferDirection::DeviceToDriver => DescFlags::WRITE,
                BufferDirection::DriverToDevice => DescFlags::empty(),
                BufferDirection::Both => {
                    panic!("Buffer passed to device should never use BufferDirection::Both.")
                }
            };
    }

    /// Sets the buffer address and length to 0.
    ///
    /// This must only be called once the device has finished using the descriptor.
    fn unset_buf(&mut self) {
        self.addr = 0;
        self.len = 0;
    }

    /// Returns the index of the next descriptor in the chain if the `NEXT` flag is set, or `None`
    /// if it is not (and thus this descriptor is the end of the chain).
    fn next(&self) -> Option<u16> {
        if self.flags.contains(DescFlags::NEXT) {
            Some(self.next)
        } else {
            None
        }
    }
}
769
/// Descriptor flags
///
/// Ref: VirtIO spec 2.6.5 (VIRTQ_DESC_F_*).
#[derive(
    Copy, Clone, Debug, Default, Eq, FromBytes, Immutable, IntoBytes, KnownLayout, PartialEq,
)]
#[repr(transparent)]
struct DescFlags(u16);

bitflags! {
    impl DescFlags: u16 {
        /// The `next` field points to a further descriptor in the chain.
        const NEXT = 1;
        /// The buffer is written by the device (device-to-driver).
        const WRITE = 2;
        /// The descriptor points to an indirect descriptor table rather than a buffer.
        const INDIRECT = 4;
    }
}
784
/// The driver uses the available ring to offer buffers to the device:
/// each ring entry refers to the head of a descriptor chain.
/// It is only written by the driver and read by the device.
#[repr(C)]
#[derive(Debug)]
struct AvailRing<const SIZE: usize> {
    // Bit 0 set requests that the device suppress used-buffer notifications; only consulted
    // when `VIRTIO_F_EVENT_IDX` is not negotiated (see `set_dev_notify`).
    flags: AtomicU16,
    /// A driver MUST NOT decrement the idx.
    // Free-running index; `idx % SIZE` selects the next slot of `ring` to fill.
    idx: AtomicU16,
    // Descriptor-chain head indices offered to the device.
    ring: [u16; SIZE],
    /// Only used if `VIRTIO_F_EVENT_IDX` is negotiated.
    // The driver stores its `last_used_idx` here so the device can suppress notifications
    // for used entries the driver has already consumed (see `pop_used`).
    used_event: AtomicU16,
}
798
/// The used ring is where the device returns buffers once it is done with them:
/// it is only written to by the device, and read by the driver.
#[repr(C)]
#[derive(Debug)]
struct UsedRing<const SIZE: usize> {
    // Bit 0 set means the device asks the driver not to notify it after adding buffers;
    // only consulted when `VIRTIO_F_EVENT_IDX` is not negotiated (see `should_notify`).
    flags: AtomicU16,
    // Free-running index of the next slot the device will write; compared against
    // `last_used_idx` in `can_pop`.
    idx: AtomicU16,
    // Completed elements, one per used descriptor chain.
    ring: [UsedElem; SIZE],
    /// Only used if `VIRTIO_F_EVENT_IDX` is negotiated.
    // The device stores the available-ring index after which it wants a notification
    // (read in `should_notify`).
    avail_event: AtomicU16,
}
810
/// One entry of the used ring, written by the device.
#[repr(C)]
#[derive(Debug)]
struct UsedElem {
    // Index of the head descriptor of the completed chain (the token returned by `add`).
    id: u32,
    // Total number of bytes the device wrote into the chain's buffers.
    len: u32,
}
817
/// Iterator over a set of driver-to-device input buffers followed by device-to-driver output
/// buffers, yielding each buffer as a raw pointer paired with its transfer direction.
struct InputOutputIter<'a, 'b> {
    // Remaining input (driver-to-device) buffers; consumed from the front first.
    inputs: &'a [&'b [u8]],
    // Remaining output (device-to-driver) buffers; consumed once `inputs` is exhausted.
    outputs: &'a mut [&'b mut [u8]],
}

impl<'a, 'b> InputOutputIter<'a, 'b> {
    /// Creates an iterator over `inputs` then `outputs`, in order.
    fn new(inputs: &'a [&'b [u8]], outputs: &'a mut [&'b mut [u8]]) -> Self {
        Self { inputs, outputs }
    }
}
828
829impl Iterator for InputOutputIter<'_, '_> {
830    type Item = (NonNull<[u8]>, BufferDirection);
831
832    fn next(&mut self) -> Option<Self::Item> {
833        if let Some(input) = take_first(&mut self.inputs) {
834            Some(((*input).into(), BufferDirection::DriverToDevice))
835        } else {
836            let output = take_first_mut(&mut self.outputs)?;
837            Some(((*output).into(), BufferDirection::DeviceToDriver))
838        }
839    }
840}
841
// TODO: Use `slice::take_first` once it is stable
// (https://github.com/rust-lang/rust/issues/62280).
/// Removes the first element from the front of `*slice` and returns a reference to it,
/// or `None` if the slice is empty.
fn take_first<'a, T>(slice: &mut &'a [T]) -> Option<&'a T> {
    match slice.split_first() {
        Some((head, tail)) => {
            *slice = tail;
            Some(head)
        }
        None => None,
    }
}
849
// TODO: Use `slice::take_first_mut` once it is stable
// (https://github.com/rust-lang/rust/issues/62280).
/// Removes the first element from the front of `*slice` and returns a mutable reference to
/// it, or `None` if the slice is empty. `take` detaches the slice from the borrow so the
/// returned reference can outlive `slice` itself.
fn take_first_mut<'a, T>(slice: &mut &'a mut [T]) -> Option<&'a mut T> {
    match take(slice).split_first_mut() {
        Some((head, tail)) => {
            *slice = tail;
            Some(head)
        }
        None => None,
    }
}
857
/// Simulates the device reading from a VirtIO queue and writing a response back, for use in tests.
///
/// The fake device always uses descriptors in order.
///
/// Returns true if a descriptor chain was available and processed, or false if no descriptors were
/// available.
#[cfg(test)]
pub(crate) fn fake_read_write_queue<const QUEUE_SIZE: usize>(
    descriptors: *const [Descriptor; QUEUE_SIZE],
    queue_driver_area: *const u8,
    queue_device_area: *mut u8,
    handler: impl FnOnce(Vec<u8>) -> Vec<u8>,
) -> bool {
    use core::{ops::Deref, slice};

    // The driver area begins with the available ring and the device area with the used ring.
    let available_ring = queue_driver_area as *const AvailRing<QUEUE_SIZE>;
    let used_ring = queue_device_area as *mut UsedRing<QUEUE_SIZE>;

    // Safe because the various pointers are properly aligned, dereferenceable, initialised, and
    // nothing else accesses them during this block.
    unsafe {
        // Make sure there is actually at least one descriptor available to read from.
        if (*available_ring).idx.load(Ordering::Acquire) == (*used_ring).idx.load(Ordering::Acquire)
        {
            return false;
        }
        // The fake device always uses descriptors in order, like VIRTIO_F_IN_ORDER, so
        // `used_ring.idx` marks the next descriptor we should take from the available ring.
        // The indices are free-running, so mask by the queue size (assumed a power of two)
        // to get the ring slot.
        let next_slot = (*used_ring).idx.load(Ordering::Acquire) & (QUEUE_SIZE as u16 - 1);
        let head_descriptor_index = (*available_ring).ring[next_slot as usize];
        let mut descriptor = &(*descriptors)[head_descriptor_index as usize];

        // Total bytes read from the chain's device-readable buffers.
        let input_length;
        // The response produced by the test's handler.
        let output;
        if descriptor.flags.contains(DescFlags::INDIRECT) {
            // The descriptor shouldn't have any other flags if it is indirect.
            assert_eq!(descriptor.flags, DescFlags::INDIRECT);

            // Loop through all input descriptors in the indirect descriptor list, reading data from
            // them.
            let indirect_descriptor_list: &[Descriptor] = zerocopy::Ref::into_ref(
                zerocopy::Ref::<_, [Descriptor]>::from_bytes(slice::from_raw_parts(
                    descriptor.addr as *const u8,
                    descriptor.len as usize,
                ))
                .unwrap(),
            );
            let mut input = Vec::new();
            let mut indirect_descriptor_index = 0;
            // Device-readable descriptors come first; stop at the first device-writable one.
            while indirect_descriptor_index < indirect_descriptor_list.len() {
                let indirect_descriptor = &indirect_descriptor_list[indirect_descriptor_index];
                if indirect_descriptor.flags.contains(DescFlags::WRITE) {
                    break;
                }

                input.extend_from_slice(slice::from_raw_parts(
                    indirect_descriptor.addr as *const u8,
                    indirect_descriptor.len as usize,
                ));

                indirect_descriptor_index += 1;
            }
            input_length = input.len();

            // Let the test handle the request.
            output = handler(input);

            // Write the response to the remaining descriptors.
            let mut remaining_output = output.deref();
            while indirect_descriptor_index < indirect_descriptor_list.len() {
                let indirect_descriptor = &indirect_descriptor_list[indirect_descriptor_index];
                assert!(indirect_descriptor.flags.contains(DescFlags::WRITE));

                let length_to_write = min(remaining_output.len(), indirect_descriptor.len as usize);
                ptr::copy(
                    remaining_output.as_ptr(),
                    indirect_descriptor.addr as *mut u8,
                    length_to_write,
                );
                remaining_output = &remaining_output[length_to_write..];

                indirect_descriptor_index += 1;
            }
            // The handler's entire response must fit in the writable buffers.
            assert_eq!(remaining_output.len(), 0);
        } else {
            // Loop through all input descriptors in the chain, reading data from them.
            let mut input = Vec::new();
            while !descriptor.flags.contains(DescFlags::WRITE) {
                input.extend_from_slice(slice::from_raw_parts(
                    descriptor.addr as *const u8,
                    descriptor.len as usize,
                ));

                if let Some(next) = descriptor.next() {
                    descriptor = &(*descriptors)[next as usize];
                } else {
                    break;
                }
            }
            input_length = input.len();

            // Let the test handle the request.
            output = handler(input);

            // Write the response to the remaining descriptors. `descriptor` now points at the
            // first device-writable descriptor, if the chain has any.
            let mut remaining_output = output.deref();
            if descriptor.flags.contains(DescFlags::WRITE) {
                loop {
                    assert!(descriptor.flags.contains(DescFlags::WRITE));

                    let length_to_write = min(remaining_output.len(), descriptor.len as usize);
                    ptr::copy(
                        remaining_output.as_ptr(),
                        descriptor.addr as *mut u8,
                        length_to_write,
                    );
                    remaining_output = &remaining_output[length_to_write..];

                    if let Some(next) = descriptor.next() {
                        descriptor = &(*descriptors)[next as usize];
                    } else {
                        break;
                    }
                }
            }
            // The handler's entire response must fit in the writable buffers.
            assert_eq!(remaining_output.len(), 0);
        }

        // Mark the buffer as used: record the head of the chain and the total bytes
        // processed (this fake counts bytes read plus bytes written).
        (*used_ring).ring[next_slot as usize].id = head_descriptor_index.into();
        (*used_ring).ring[next_slot as usize].len = (input_length + output.len()) as u32;
        // Publish the entry; the fence in fetch_add orders the ring writes before the idx bump.
        (*used_ring).idx.fetch_add(1, Ordering::AcqRel);

        true
    }
}
994
995#[cfg(test)]
996mod tests {
997    use super::*;
998    use crate::{
999        device::common::Feature,
1000        hal::fake::FakeHal,
1001        transport::{
1002            DeviceType,
1003            fake::{FakeTransport, QueueStatus, State},
1004            mmio::{MODERN_VERSION, MmioTransport, VirtIOHeader},
1005        },
1006    };
1007    use safe_mmio::UniqueMmioPointer;
1008    use std::sync::{Arc, Mutex};
1009
1010    #[test]
1011    fn queue_too_big() {
1012        let mut header = VirtIOHeader::make_fake_header(MODERN_VERSION, 1, 0, 0, 4);
1013        let mut transport = MmioTransport::new_from_unique(
1014            UniqueMmioPointer::from(&mut header),
1015            UniqueMmioPointer::from([].as_mut_slice()),
1016        )
1017        .unwrap();
1018        assert_eq!(
1019            VirtQueue::<FakeHal, 8>::new(&mut transport, 0, false, false).unwrap_err(),
1020            Error::InvalidParam
1021        );
1022    }
1023
1024    #[test]
1025    fn queue_already_used() {
1026        let mut header = VirtIOHeader::make_fake_header(MODERN_VERSION, 1, 0, 0, 4);
1027        let mut transport = MmioTransport::new_from_unique(
1028            UniqueMmioPointer::from(&mut header),
1029            UniqueMmioPointer::from([].as_mut_slice()),
1030        )
1031        .unwrap();
1032        VirtQueue::<FakeHal, 4>::new(&mut transport, 0, false, false).unwrap();
1033        assert_eq!(
1034            VirtQueue::<FakeHal, 4>::new(&mut transport, 0, false, false).unwrap_err(),
1035            Error::AlreadyUsed
1036        );
1037    }
1038
1039    #[test]
1040    fn add_empty() {
1041        let mut header = VirtIOHeader::make_fake_header(MODERN_VERSION, 1, 0, 0, 4);
1042        let mut transport = MmioTransport::new_from_unique(
1043            UniqueMmioPointer::from(&mut header),
1044            UniqueMmioPointer::from([].as_mut_slice()),
1045        )
1046        .unwrap();
1047        let mut queue = VirtQueue::<FakeHal, 4>::new(&mut transport, 0, false, false).unwrap();
1048        assert_eq!(
1049            unsafe { queue.add(&[], &mut []) }.unwrap_err(),
1050            Error::InvalidParam
1051        );
1052    }
1053
1054    #[test]
1055    fn add_too_many() {
1056        let mut header = VirtIOHeader::make_fake_header(MODERN_VERSION, 1, 0, 0, 4);
1057        let mut transport = MmioTransport::new_from_unique(
1058            UniqueMmioPointer::from(&mut header),
1059            UniqueMmioPointer::from([].as_mut_slice()),
1060        )
1061        .unwrap();
1062        let mut queue = VirtQueue::<FakeHal, 4>::new(&mut transport, 0, false, false).unwrap();
1063        assert_eq!(queue.available_desc(), 4);
1064        assert_eq!(
1065            unsafe { queue.add(&[&[], &[], &[]], &mut [&mut [], &mut []]) }.unwrap_err(),
1066            Error::QueueFull
1067        );
1068    }
1069
    #[test]
    fn add_buffers() {
        let mut header = VirtIOHeader::make_fake_header(MODERN_VERSION, 1, 0, 0, 4);
        let mut transport = MmioTransport::new_from_unique(
            UniqueMmioPointer::from(&mut header),
            UniqueMmioPointer::from([].as_mut_slice()),
        )
        .unwrap();
        let mut queue = VirtQueue::<FakeHal, 4>::new(&mut transport, 0, false, false).unwrap();
        assert_eq!(queue.available_desc(), 4);

        // Add a buffer chain consisting of two device-readable parts followed by two
        // device-writable parts.
        let token = unsafe { queue.add(&[&[1, 2], &[3]], &mut [&mut [0, 0], &mut [0]]) }.unwrap();

        // All 4 descriptors are now in use, and the device hasn't returned anything yet.
        assert_eq!(queue.available_desc(), 0);
        assert!(!queue.can_pop());

        // Safe because the various parts of the queue are properly aligned, dereferenceable and
        // initialised, and nothing else is accessing them at the same time.
        unsafe {
            // The available ring's first slot must point at the head of the chain we added.
            let first_descriptor_index = (*queue.avail.as_ptr()).ring[0];
            assert_eq!(first_descriptor_index, token);
            // 1st descriptor: the 2-byte readable buffer, chained onwards (no WRITE flag).
            assert_eq!(
                (*queue.desc.as_ptr())[first_descriptor_index as usize].len,
                2
            );
            assert_eq!(
                (*queue.desc.as_ptr())[first_descriptor_index as usize].flags,
                DescFlags::NEXT
            );
            // 2nd descriptor: the 1-byte readable buffer, still chained onwards.
            let second_descriptor_index =
                (*queue.desc.as_ptr())[first_descriptor_index as usize].next;
            assert_eq!(
                (*queue.desc.as_ptr())[second_descriptor_index as usize].len,
                1
            );
            assert_eq!(
                (*queue.desc.as_ptr())[second_descriptor_index as usize].flags,
                DescFlags::NEXT
            );
            // 3rd descriptor: the 2-byte writable buffer, chained onwards.
            let third_descriptor_index =
                (*queue.desc.as_ptr())[second_descriptor_index as usize].next;
            assert_eq!(
                (*queue.desc.as_ptr())[third_descriptor_index as usize].len,
                2
            );
            assert_eq!(
                (*queue.desc.as_ptr())[third_descriptor_index as usize].flags,
                DescFlags::NEXT | DescFlags::WRITE
            );
            // 4th descriptor: the 1-byte writable buffer, end of the chain (no NEXT flag).
            let fourth_descriptor_index =
                (*queue.desc.as_ptr())[third_descriptor_index as usize].next;
            assert_eq!(
                (*queue.desc.as_ptr())[fourth_descriptor_index as usize].len,
                1
            );
            assert_eq!(
                (*queue.desc.as_ptr())[fourth_descriptor_index as usize].flags,
                DescFlags::WRITE
            );
        }
    }
1133
    #[cfg(feature = "alloc")]
    #[test]
    fn add_buffers_indirect() {
        use core::ptr::slice_from_raw_parts;

        let mut header = VirtIOHeader::make_fake_header(MODERN_VERSION, 1, 0, 0, 4);
        let mut transport = MmioTransport::new_from_unique(
            UniqueMmioPointer::from(&mut header),
            UniqueMmioPointer::from([].as_mut_slice()),
        )
        .unwrap();
        // `indirect = true`: the whole chain should be packed into a single indirect list.
        let mut queue = VirtQueue::<FakeHal, 4>::new(&mut transport, 0, true, false).unwrap();
        assert_eq!(queue.available_desc(), 4);

        // Add a buffer chain consisting of two device-readable parts followed by two
        // device-writable parts.
        let token = unsafe { queue.add(&[&[1, 2], &[3]], &mut [&mut [0, 0], &mut [0]]) }.unwrap();

        // Only one table descriptor (the indirect one) is consumed, so all 4 still read as free.
        assert_eq!(queue.available_desc(), 4);
        assert!(!queue.can_pop());

        // Safe because the various parts of the queue are properly aligned, dereferenceable and
        // initialised, and nothing else is accessing them at the same time.
        unsafe {
            // The single table descriptor points at a list of 4 indirect descriptors.
            let indirect_descriptor_index = (*queue.avail.as_ptr()).ring[0];
            assert_eq!(indirect_descriptor_index, token);
            assert_eq!(
                (*queue.desc.as_ptr())[indirect_descriptor_index as usize].len as usize,
                4 * size_of::<Descriptor>()
            );
            assert_eq!(
                (*queue.desc.as_ptr())[indirect_descriptor_index as usize].flags,
                DescFlags::INDIRECT
            );

            // The indirect list mirrors the chain layout: two readable then two writable
            // descriptors, linked 0 -> 1 -> 2 -> 3.
            let indirect_descriptors = slice_from_raw_parts(
                (*queue.desc.as_ptr())[indirect_descriptor_index as usize].addr
                    as *const Descriptor,
                4,
            );
            assert_eq!((*indirect_descriptors)[0].len, 2);
            assert_eq!((*indirect_descriptors)[0].flags, DescFlags::NEXT);
            assert_eq!((*indirect_descriptors)[0].next, 1);
            assert_eq!((*indirect_descriptors)[1].len, 1);
            assert_eq!((*indirect_descriptors)[1].flags, DescFlags::NEXT);
            assert_eq!((*indirect_descriptors)[1].next, 2);
            assert_eq!((*indirect_descriptors)[2].len, 2);
            assert_eq!(
                (*indirect_descriptors)[2].flags,
                DescFlags::NEXT | DescFlags::WRITE
            );
            assert_eq!((*indirect_descriptors)[2].next, 3);
            assert_eq!((*indirect_descriptors)[3].len, 1);
            assert_eq!((*indirect_descriptors)[3].flags, DescFlags::WRITE);
        }
    }
1190
1191    /// Tests that the queue advises the device that notifications are needed.
1192    #[test]
1193    fn set_dev_notify() {
1194        let state = Arc::new(Mutex::new(State::new(vec![QueueStatus::default()], ())));
1195        let mut transport = FakeTransport {
1196            device_type: DeviceType::Block,
1197            max_queue_size: 4,
1198            device_features: 0,
1199            state: state.clone(),
1200        };
1201        let mut queue = VirtQueue::<FakeHal, 4>::new(&mut transport, 0, false, false).unwrap();
1202
1203        // Check that the avail ring's flag is zero by default.
1204        assert_eq!(
1205            unsafe { (*queue.avail.as_ptr()).flags.load(Ordering::Acquire) },
1206            0x0
1207        );
1208
1209        queue.set_dev_notify(false);
1210
1211        // Check that the avail ring's flag is 1 after `disable_dev_notify`.
1212        assert_eq!(
1213            unsafe { (*queue.avail.as_ptr()).flags.load(Ordering::Acquire) },
1214            0x1
1215        );
1216
1217        queue.set_dev_notify(true);
1218
1219        // Check that the avail ring's flag is 0 after `enable_dev_notify`.
1220        assert_eq!(
1221            unsafe { (*queue.avail.as_ptr()).flags.load(Ordering::Acquire) },
1222            0x0
1223        );
1224    }
1225
1226    /// Tests that the queue notifies the device about added buffers, if it hasn't suppressed
1227    /// notifications.
1228    #[test]
1229    fn add_notify() {
1230        let state = Arc::new(Mutex::new(State::new(vec![QueueStatus::default()], ())));
1231        let mut transport = FakeTransport {
1232            device_type: DeviceType::Block,
1233            max_queue_size: 4,
1234            device_features: 0,
1235            state: state.clone(),
1236        };
1237        let mut queue = VirtQueue::<FakeHal, 4>::new(&mut transport, 0, false, false).unwrap();
1238
1239        // Add a buffer chain with a single device-readable part.
1240        unsafe { queue.add(&[&[42]], &mut []) }.unwrap();
1241
1242        // Check that the transport would be notified.
1243        assert_eq!(queue.should_notify(), true);
1244
1245        // SAFETY: the various parts of the queue are properly aligned, dereferenceable and
1246        // initialised, and nothing else is accessing them at the same time.
1247        unsafe {
1248            // Suppress notifications.
1249            (*queue.used.as_ptr()).flags.store(0x01, Ordering::Release);
1250        }
1251
1252        // Check that the transport would not be notified.
1253        assert_eq!(queue.should_notify(), false);
1254    }
1255
1256    /// Tests that the queue notifies the device about added buffers, if it hasn't suppressed
1257    /// notifications with the `avail_event` index.
1258    #[test]
1259    fn add_notify_event_idx() {
1260        let state = Arc::new(Mutex::new(State::new(vec![QueueStatus::default()], ())));
1261        let mut transport = FakeTransport {
1262            device_type: DeviceType::Block,
1263            max_queue_size: 4,
1264            device_features: Feature::RING_EVENT_IDX.bits(),
1265            state: state.clone(),
1266        };
1267        let mut queue = VirtQueue::<FakeHal, 4>::new(&mut transport, 0, false, true).unwrap();
1268
1269        // Add a buffer chain with a single device-readable part.
1270        assert_eq!(unsafe { queue.add(&[&[42]], &mut []) }.unwrap(), 0);
1271
1272        // Check that the transport would be notified.
1273        assert_eq!(queue.should_notify(), true);
1274
1275        // SAFETY: the various parts of the queue are properly aligned, dereferenceable and
1276        // initialised, and nothing else is accessing them at the same time.
1277        unsafe {
1278            // Suppress notifications.
1279            (*queue.used.as_ptr())
1280                .avail_event
1281                .store(1, Ordering::Release);
1282        }
1283
1284        // Check that the transport would not be notified.
1285        assert_eq!(queue.should_notify(), false);
1286
1287        // Add another buffer chain.
1288        assert_eq!(unsafe { queue.add(&[&[42]], &mut []) }.unwrap(), 1);
1289
1290        // Check that the transport should be notified again now.
1291        assert_eq!(queue.should_notify(), true);
1292    }
1293}