1#![deny(unsafe_op_in_unsafe_fn)]
2
3#[cfg(feature = "alloc")]
4pub mod owning;
5
6use crate::hal::{BufferDirection, Dma, Hal, PhysAddr};
7use crate::transport::Transport;
8use crate::{align_up, pages, Error, Result, PAGE_SIZE};
9#[cfg(feature = "alloc")]
10use alloc::boxed::Box;
11use bitflags::bitflags;
12#[cfg(test)]
13use core::cmp::min;
14use core::convert::TryInto;
15use core::hint::spin_loop;
16use core::mem::{size_of, take};
17#[cfg(test)]
18use core::ptr;
19use core::ptr::NonNull;
20use core::sync::atomic::{fence, AtomicU16, Ordering};
21use zerocopy::{FromBytes, FromZeros, Immutable, IntoBytes, KnownLayout};
22
/// A split virtqueue: a descriptor table, an available ring and a used ring in memory shared
/// with the device, plus the driver-side bookkeeping needed to manage them.
///
/// `SIZE` is the number of descriptors and ring entries; `VirtQueue::new` asserts at compile
/// time that it is a power of two no larger than `u16::MAX`.
#[derive(Debug)]
pub struct VirtQueue<H: Hal, const SIZE: usize> {
    /// The DMA allocation(s) backing the descriptor table and both rings.
    layout: VirtQueueLayout<H>,
    /// The descriptor table inside `layout`'s memory, visible to the device.
    ///
    /// After construction, entries are only updated by copying from `desc_shadow`
    /// (`write_desc`).
    desc: NonNull<[Descriptor]>,
    /// The available (driver) ring inside `layout`'s memory.
    avail: NonNull<AvailRing<SIZE>>,
    /// The used (device) ring inside `layout`'s memory.
    used: NonNull<UsedRing<SIZE>>,

    /// Index of this queue on the transport.
    queue_idx: u16,
    /// Number of descriptors currently taken from the free list.
    num_used: u16,
    /// Index of the first descriptor in the free list (linked through `Descriptor::next`).
    free_head: u16,
    /// Driver-owned copy of the descriptor table.
    ///
    /// The driver reads descriptor state (free list, chains to recycle) from this copy rather
    /// than from the shared table.
    desc_shadow: [Descriptor; SIZE],
    /// Driver's copy of the available ring index (free-running, wraps at `u16::MAX`).
    avail_idx: u16,
    /// Index of the next used ring entry to consume (free-running, wraps at `u16::MAX`).
    last_used_idx: u16,
    /// Whether notification suppression uses the event index fields (`used_event` /
    /// `avail_event`) instead of the ring `flags`.
    event_idx: bool,
    /// Whether requests with more than one buffer use an indirect descriptor list.
    #[cfg(feature = "alloc")]
    indirect: bool,
    /// For each head descriptor index, the heap-allocated indirect list it currently points to,
    /// if any. Freed by `recycle_descriptors`.
    #[cfg(feature = "alloc")]
    indirect_lists: [Option<NonNull<[Descriptor]>>; SIZE],
}
66
67impl<H: Hal, const SIZE: usize> VirtQueue<H, SIZE> {
68    const SIZE_OK: () = assert!(SIZE.is_power_of_two() && SIZE <= u16::MAX as usize);
69
70    pub fn new<T: Transport>(
78        transport: &mut T,
79        idx: u16,
80        indirect: bool,
81        event_idx: bool,
82    ) -> Result<Self> {
83        #[allow(clippy::let_unit_value)]
84        let _ = Self::SIZE_OK;
85
86        if transport.queue_used(idx) {
87            return Err(Error::AlreadyUsed);
88        }
89        if transport.max_queue_size(idx) < SIZE as u32 {
90            return Err(Error::InvalidParam);
91        }
92        let size = SIZE as u16;
93
94        let layout = if transport.requires_legacy_layout() {
95            VirtQueueLayout::allocate_legacy(size)?
96        } else {
97            VirtQueueLayout::allocate_flexible(size)?
98        };
99
100        transport.queue_set(
101            idx,
102            size.into(),
103            layout.descriptors_paddr(),
104            layout.driver_area_paddr(),
105            layout.device_area_paddr(),
106        );
107
108        let desc =
109            NonNull::slice_from_raw_parts(layout.descriptors_vaddr().cast::<Descriptor>(), SIZE);
110        let avail = layout.avail_vaddr().cast();
111        let used = layout.used_vaddr().cast();
112
113        let mut desc_shadow: [Descriptor; SIZE] = FromZeros::new_zeroed();
114        for i in 0..(size - 1) {
116            desc_shadow[i as usize].next = i + 1;
117            unsafe {
120                (*desc.as_ptr())[i as usize].next = i + 1;
121            }
122        }
123
124        #[cfg(feature = "alloc")]
125        const NONE: Option<NonNull<[Descriptor]>> = None;
126        Ok(VirtQueue {
127            layout,
128            desc,
129            avail,
130            used,
131            queue_idx: idx,
132            num_used: 0,
133            free_head: 0,
134            desc_shadow,
135            avail_idx: 0,
136            last_used_idx: 0,
137            event_idx,
138            #[cfg(feature = "alloc")]
139            indirect,
140            #[cfg(feature = "alloc")]
141            indirect_lists: [NONE; SIZE],
142        })
143    }
144
145    pub unsafe fn add<'a, 'b>(
156        &mut self,
157        inputs: &'a [&'b [u8]],
158        outputs: &'a mut [&'b mut [u8]],
159    ) -> Result<u16> {
160        if inputs.is_empty() && outputs.is_empty() {
161            return Err(Error::InvalidParam);
162        }
163        let descriptors_needed = inputs.len() + outputs.len();
164        #[cfg(feature = "alloc")]
167        if self.num_used as usize + 1 > SIZE
168            || descriptors_needed > SIZE
169            || (!self.indirect && self.num_used as usize + descriptors_needed > SIZE)
170        {
171            return Err(Error::QueueFull);
172        }
173        #[cfg(not(feature = "alloc"))]
174        if self.num_used as usize + descriptors_needed > SIZE {
175            return Err(Error::QueueFull);
176        }
177
178        #[cfg(feature = "alloc")]
179        let head = if self.indirect && descriptors_needed > 1 {
180            self.add_indirect(inputs, outputs)
181        } else {
182            self.add_direct(inputs, outputs)
183        };
184        #[cfg(not(feature = "alloc"))]
185        let head = self.add_direct(inputs, outputs);
186
187        let avail_slot = self.avail_idx & (SIZE as u16 - 1);
188        unsafe {
190            (*self.avail.as_ptr()).ring[avail_slot as usize] = head;
191        }
192
193        fence(Ordering::SeqCst);
196
197        self.avail_idx = self.avail_idx.wrapping_add(1);
199        unsafe {
201            (*self.avail.as_ptr())
202                .idx
203                .store(self.avail_idx, Ordering::Release);
204        }
205
206        Ok(head)
207    }
208
209    fn add_direct<'a, 'b>(
210        &mut self,
211        inputs: &'a [&'b [u8]],
212        outputs: &'a mut [&'b mut [u8]],
213    ) -> u16 {
214        let head = self.free_head;
216        let mut last = self.free_head;
217
218        for (buffer, direction) in InputOutputIter::new(inputs, outputs) {
219            assert_ne!(buffer.len(), 0);
220
221            let desc = &mut self.desc_shadow[usize::from(self.free_head)];
223            unsafe {
226                desc.set_buf::<H>(buffer, direction, DescFlags::NEXT);
227            }
228            last = self.free_head;
229            self.free_head = desc.next;
230
231            self.write_desc(last);
232        }
233
234        self.desc_shadow[usize::from(last)]
236            .flags
237            .remove(DescFlags::NEXT);
238        self.write_desc(last);
239
240        self.num_used += (inputs.len() + outputs.len()) as u16;
241
242        head
243    }
244
245    #[cfg(feature = "alloc")]
246    fn add_indirect<'a, 'b>(
247        &mut self,
248        inputs: &'a [&'b [u8]],
249        outputs: &'a mut [&'b mut [u8]],
250    ) -> u16 {
251        let head = self.free_head;
252
253        let mut indirect_list =
255            <[Descriptor]>::new_box_zeroed_with_elems(inputs.len() + outputs.len()).unwrap();
256        for (i, (buffer, direction)) in InputOutputIter::new(inputs, outputs).enumerate() {
257            let desc = &mut indirect_list[i];
258            unsafe {
261                desc.set_buf::<H>(buffer, direction, DescFlags::NEXT);
262            }
263            desc.next = (i + 1) as u16;
264        }
265        indirect_list
266            .last_mut()
267            .unwrap()
268            .flags
269            .remove(DescFlags::NEXT);
270
271        assert!(self.indirect_lists[usize::from(head)].is_none());
274        self.indirect_lists[usize::from(head)] = Some(indirect_list.as_mut().into());
275
276        let direct_desc = &mut self.desc_shadow[usize::from(head)];
280        self.free_head = direct_desc.next;
281
282        unsafe {
287            direct_desc.set_buf::<H>(
288                Box::leak(indirect_list).as_bytes().into(),
289                BufferDirection::DriverToDevice,
290                DescFlags::INDIRECT,
291            );
292        }
293        self.write_desc(head);
294        self.num_used += 1;
295
296        head
297    }
298
299    pub fn add_notify_wait_pop<'a>(
306        &mut self,
307        inputs: &'a [&'a [u8]],
308        outputs: &'a mut [&'a mut [u8]],
309        transport: &mut impl Transport,
310    ) -> Result<u32> {
311        let token = unsafe { self.add(inputs, outputs) }?;
314
315        if self.should_notify() {
317            transport.notify(self.queue_idx);
318        }
319
320        while !self.can_pop() {
322            spin_loop();
323        }
324
325        unsafe { self.pop_used(token, inputs, outputs) }
327    }
328
329    pub fn set_dev_notify(&mut self, enable: bool) {
333        let avail_ring_flags = if enable { 0x0000 } else { 0x0001 };
334        if !self.event_idx {
335            unsafe {
338                (*self.avail.as_ptr())
339                    .flags
340                    .store(avail_ring_flags, Ordering::Release)
341            }
342        }
343    }
344
345    pub fn should_notify(&self) -> bool {
350        if self.event_idx {
351            let avail_event = unsafe { (*self.used.as_ptr()).avail_event.load(Ordering::Acquire) };
354            self.avail_idx >= avail_event.wrapping_add(1)
355        } else {
356            unsafe { (*self.used.as_ptr()).flags.load(Ordering::Acquire) & 0x0001 == 0 }
359        }
360    }
361
362    fn write_desc(&mut self, index: u16) {
365        let index = usize::from(index);
366        unsafe {
369            (*self.desc.as_ptr())[index] = self.desc_shadow[index].clone();
370        }
371    }
372
373    pub fn can_pop(&self) -> bool {
375        self.last_used_idx != unsafe { (*self.used.as_ptr()).idx.load(Ordering::Acquire) }
378    }
379
380    pub fn peek_used(&self) -> Option<u16> {
383        if self.can_pop() {
384            let last_used_slot = self.last_used_idx & (SIZE as u16 - 1);
385            Some(unsafe { (*self.used.as_ptr()).ring[last_used_slot as usize].id as u16 })
388        } else {
389            None
390        }
391    }
392
393    pub fn available_desc(&self) -> usize {
395        #[cfg(feature = "alloc")]
396        if self.indirect {
397            return if usize::from(self.num_used) == SIZE {
398                0
399            } else {
400                SIZE
401            };
402        }
403
404        SIZE - usize::from(self.num_used)
405    }
406
407    unsafe fn recycle_descriptors<'a>(
418        &mut self,
419        head: u16,
420        inputs: &'a [&'a [u8]],
421        outputs: &'a mut [&'a mut [u8]],
422    ) {
423        let original_free_head = self.free_head;
424        self.free_head = head;
425
426        let head_desc = &mut self.desc_shadow[usize::from(head)];
427        if head_desc.flags.contains(DescFlags::INDIRECT) {
428            #[cfg(feature = "alloc")]
429            {
430                let indirect_list = self.indirect_lists[usize::from(head)].take().unwrap();
433                let mut indirect_list = unsafe { Box::from_raw(indirect_list.as_ptr()) };
436                let paddr = head_desc.addr;
437                head_desc.unset_buf();
438                self.num_used -= 1;
439                head_desc.next = original_free_head;
440
441                unsafe {
445                    H::unshare(
446                        paddr as usize,
447                        indirect_list.as_mut_bytes().into(),
448                        BufferDirection::DriverToDevice,
449                    );
450                }
451
452                assert_eq!(indirect_list.len(), inputs.len() + outputs.len());
454                for (i, (buffer, direction)) in InputOutputIter::new(inputs, outputs).enumerate() {
455                    assert_ne!(buffer.len(), 0);
456
457                    unsafe {
460                        H::unshare(indirect_list[i].addr as usize, buffer, direction);
463                    }
464                }
465                drop(indirect_list);
466            }
467        } else {
468            let mut next = Some(head);
469
470            for (buffer, direction) in InputOutputIter::new(inputs, outputs) {
471                assert_ne!(buffer.len(), 0);
472
473                let desc_index = next.expect("Descriptor chain was shorter than expected.");
474                let desc = &mut self.desc_shadow[usize::from(desc_index)];
475
476                let paddr = desc.addr;
477                desc.unset_buf();
478                self.num_used -= 1;
479                next = desc.next();
480                if next.is_none() {
481                    desc.next = original_free_head;
482                }
483
484                self.write_desc(desc_index);
485
486                unsafe {
489                    H::unshare(paddr as usize, buffer, direction);
491                }
492            }
493
494            if next.is_some() {
495                panic!("Descriptor chain was longer than expected.");
496            }
497        }
498    }
499
500    pub unsafe fn pop_used<'a>(
510        &mut self,
511        token: u16,
512        inputs: &'a [&'a [u8]],
513        outputs: &'a mut [&'a mut [u8]],
514    ) -> Result<u32> {
515        if !self.can_pop() {
516            return Err(Error::NotReady);
517        }
518
519        let last_used_slot = self.last_used_idx & (SIZE as u16 - 1);
521        let index;
522        let len;
523        unsafe {
526            index = (*self.used.as_ptr()).ring[last_used_slot as usize].id as u16;
527            len = (*self.used.as_ptr()).ring[last_used_slot as usize].len;
528        }
529
530        if index != token {
531            return Err(Error::WrongToken);
533        }
534
535        unsafe {
537            self.recycle_descriptors(index, inputs, outputs);
538        }
539        self.last_used_idx = self.last_used_idx.wrapping_add(1);
540
541        if self.event_idx {
542            unsafe {
545                (*self.avail.as_ptr())
546                    .used_event
547                    .store(self.last_used_idx, Ordering::Release);
548            }
549        }
550
551        Ok(len)
552    }
553}
554
// SAFETY: The queue's raw pointers point into DMA memory owned by its own `layout` field and,
// with `alloc`, into indirect lists it leaked itself. None of this is tied to the thread that
// created it. (Assumes `Dma<H>` itself may move between threads — TODO confirm.)
unsafe impl<H: Hal, const SIZE: usize> Send for VirtQueue<H, SIZE> {}

// SAFETY: Every method that writes through the queue's pointers takes `&mut self`. The `&self`
// methods (`should_notify`, `can_pop`, `peek_used`, `available_desc`) only read, so sharing a
// `&VirtQueue` between threads cannot cause a data race on the driver side.
unsafe impl<H: Hal, const SIZE: usize> Sync for VirtQueue<H, SIZE> {}
561
/// The DMA memory backing a virtqueue's descriptor table, available ring and used ring.
///
/// In both variants the descriptor table starts at offset 0 of the first (or only) region.
#[derive(Debug)]
enum VirtQueueLayout<H: Hal> {
    /// All three parts in one contiguous region shared in both directions. Used when the
    /// transport requires the legacy layout.
    Legacy {
        dma: Dma<H>,
        /// Byte offset of the available ring, directly after the descriptor table.
        avail_offset: usize,
        /// Byte offset of the used ring, rounded up with `align_up`.
        used_offset: usize,
    },
    /// Separate regions for the parts the driver writes and the part the device writes.
    Modern {
        /// Holds the descriptor table followed by the available ring.
        driver_to_device_dma: Dma<H>,
        /// Holds the used ring, starting at offset 0.
        device_to_driver_dma: Dma<H>,
        /// Byte offset of the available ring within `driver_to_device_dma`.
        avail_offset: usize,
    },
}
582
583impl<H: Hal> VirtQueueLayout<H> {
584    fn allocate_legacy(queue_size: u16) -> Result<Self> {
589        let (desc, avail, used) = queue_part_sizes(queue_size);
590        let size = align_up(desc + avail) + align_up(used);
591        let dma = Dma::new(size / PAGE_SIZE, BufferDirection::Both)?;
593        Ok(Self::Legacy {
594            dma,
595            avail_offset: desc,
596            used_offset: align_up(desc + avail),
597        })
598    }
599
600    fn allocate_flexible(queue_size: u16) -> Result<Self> {
606        let (desc, avail, used) = queue_part_sizes(queue_size);
607        let driver_to_device_dma = Dma::new(pages(desc + avail), BufferDirection::DriverToDevice)?;
608        let device_to_driver_dma = Dma::new(pages(used), BufferDirection::DeviceToDriver)?;
609        Ok(Self::Modern {
610            driver_to_device_dma,
611            device_to_driver_dma,
612            avail_offset: desc,
613        })
614    }
615
616    fn descriptors_paddr(&self) -> PhysAddr {
618        match self {
619            Self::Legacy { dma, .. } => dma.paddr(),
620            Self::Modern {
621                driver_to_device_dma,
622                ..
623            } => driver_to_device_dma.paddr(),
624        }
625    }
626
627    fn descriptors_vaddr(&self) -> NonNull<u8> {
629        match self {
630            Self::Legacy { dma, .. } => dma.vaddr(0),
631            Self::Modern {
632                driver_to_device_dma,
633                ..
634            } => driver_to_device_dma.vaddr(0),
635        }
636    }
637
638    fn driver_area_paddr(&self) -> PhysAddr {
640        match self {
641            Self::Legacy {
642                dma, avail_offset, ..
643            } => dma.paddr() + avail_offset,
644            Self::Modern {
645                driver_to_device_dma,
646                avail_offset,
647                ..
648            } => driver_to_device_dma.paddr() + avail_offset,
649        }
650    }
651
652    fn avail_vaddr(&self) -> NonNull<u8> {
654        match self {
655            Self::Legacy {
656                dma, avail_offset, ..
657            } => dma.vaddr(*avail_offset),
658            Self::Modern {
659                driver_to_device_dma,
660                avail_offset,
661                ..
662            } => driver_to_device_dma.vaddr(*avail_offset),
663        }
664    }
665
666    fn device_area_paddr(&self) -> PhysAddr {
668        match self {
669            Self::Legacy {
670                used_offset, dma, ..
671            } => dma.paddr() + used_offset,
672            Self::Modern {
673                device_to_driver_dma,
674                ..
675            } => device_to_driver_dma.paddr(),
676        }
677    }
678
679    fn used_vaddr(&self) -> NonNull<u8> {
681        match self {
682            Self::Legacy {
683                dma, used_offset, ..
684            } => dma.vaddr(*used_offset),
685            Self::Modern {
686                device_to_driver_dma,
687                ..
688            } => device_to_driver_dma.vaddr(0),
689        }
690    }
691}
692
693fn queue_part_sizes(queue_size: u16) -> (usize, usize, usize) {
698    assert!(
699        queue_size.is_power_of_two(),
700        "queue size should be a power of 2"
701    );
702    let queue_size = queue_size as usize;
703    let desc = size_of::<Descriptor>() * queue_size;
704    let avail = size_of::<u16>() * (3 + queue_size);
705    let used = size_of::<u16>() * 3 + size_of::<UsedElem>() * queue_size;
706    (desc, avail, used)
707}
708
/// A virtqueue descriptor, used both in the device-visible descriptor table and in indirect
/// descriptor lists.
///
/// The `repr(C, align(16))` layout is what the device reads. The zerocopy derives allow
/// zero-initialising descriptors and viewing a list of them as bytes.
#[repr(C, align(16))]
#[derive(Clone, Debug, FromBytes, Immutable, IntoBytes, KnownLayout)]
pub(crate) struct Descriptor {
    /// Address returned by `H::share` for the buffer (see `set_buf`), or 0 once unset.
    addr: u64,
    /// Buffer length in bytes, or 0 once unset.
    len: u32,
    flags: DescFlags,
    /// Index of the next descriptor in the chain (if `NEXT` is set) or in the free list.
    next: u16,
}
717
718impl Descriptor {
719    unsafe fn set_buf<H: Hal>(
725        &mut self,
726        buf: NonNull<[u8]>,
727        direction: BufferDirection,
728        extra_flags: DescFlags,
729    ) {
730        unsafe {
732            self.addr = H::share(buf, direction) as u64;
733        }
734        self.len = buf.len().try_into().unwrap();
735        self.flags = extra_flags
736            | match direction {
737                BufferDirection::DeviceToDriver => DescFlags::WRITE,
738                BufferDirection::DriverToDevice => DescFlags::empty(),
739                BufferDirection::Both => {
740                    panic!("Buffer passed to device should never use BufferDirection::Both.")
741                }
742            };
743    }
744
745    fn unset_buf(&mut self) {
749        self.addr = 0;
750        self.len = 0;
751    }
752
753    fn next(&self) -> Option<u16> {
756        if self.flags.contains(DescFlags::NEXT) {
757            Some(self.next)
758        } else {
759            None
760        }
761    }
762}
763
/// Flags of a virtqueue descriptor.
///
/// This is a transparent wrapper around the raw `u16`, so it can be embedded in `Descriptor`
/// and read or written with zerocopy.
#[derive(
    Copy, Clone, Debug, Default, Eq, FromBytes, Immutable, IntoBytes, KnownLayout, PartialEq,
)]
#[repr(transparent)]
struct DescFlags(u16);
770
bitflags! {
    impl DescFlags: u16 {
        /// The chain continues at the descriptor whose index is in `next`.
        const NEXT = 1;
        /// The device writes into this buffer (set for `BufferDirection::DeviceToDriver`).
        const WRITE = 2;
        /// The buffer is a list of descriptors rather than data (see `add_indirect`).
        const INDIRECT = 4;
    }
}
778
/// The available ring (driver area), written by the driver and read by the device.
///
/// `#[repr(C)]` keeps the field order that matches the shared memory layout.
#[repr(C)]
#[derive(Debug)]
struct AvailRing<const SIZE: usize> {
    /// Bit 0 is set by `set_dev_notify(false)` to ask the device not to send used-buffer
    /// notifications (only when event indices are not in use).
    flags: AtomicU16,
    /// Free-running count of entries the driver has published. Slot `idx % SIZE` is the next
    /// one to be filled.
    idx: AtomicU16,
    /// Head descriptor indices of the chains made available to the device.
    ring: [u16; SIZE],
    /// With event indices enabled, the driver stores its `last_used_idx` here after each
    /// `pop_used`.
    used_event: AtomicU16,
}
792
/// The used ring (device area), written by the device and read by the driver.
///
/// `#[repr(C)]` keeps the field order that matches the shared memory layout.
#[repr(C)]
#[derive(Debug)]
struct UsedRing<const SIZE: usize> {
    /// The driver only notifies the device while bit 0 is clear (see `should_notify`).
    flags: AtomicU16,
    /// Free-running count of entries the device has returned. `can_pop` compares it with
    /// `last_used_idx`.
    idx: AtomicU16,
    /// Completed requests, in completion order.
    ring: [UsedElem; SIZE],
    /// With event indices enabled, `should_notify` reads this to decide whether to notify.
    avail_event: AtomicU16,
}
804
/// One entry of the used ring.
#[repr(C)]
#[derive(Debug)]
struct UsedElem {
    /// Head descriptor index of the completed chain; `pop_used` compares it with the token.
    id: u32,
    /// Length reported by the device; returned from `pop_used`.
    len: u32,
}
811
/// Iterator over a request's buffers: all of `inputs` (driver to device), then all of `outputs`
/// (device to driver).
///
/// Consumed slices are shortened in place as the iterator advances.
struct InputOutputIter<'a, 'b> {
    inputs: &'a [&'b [u8]],
    outputs: &'a mut [&'b mut [u8]],
}
816
817impl<'a, 'b> InputOutputIter<'a, 'b> {
818    fn new(inputs: &'a [&'b [u8]], outputs: &'a mut [&'b mut [u8]]) -> Self {
819        Self { inputs, outputs }
820    }
821}
822
823impl Iterator for InputOutputIter<'_, '_> {
824    type Item = (NonNull<[u8]>, BufferDirection);
825
826    fn next(&mut self) -> Option<Self::Item> {
827        if let Some(input) = take_first(&mut self.inputs) {
828            Some(((*input).into(), BufferDirection::DriverToDevice))
829        } else {
830            let output = take_first_mut(&mut self.outputs)?;
831            Some(((*output).into(), BufferDirection::DeviceToDriver))
832        }
833    }
834}
835
/// Removes the first element from the front of `slice` and returns it, or returns `None` (and
/// leaves `slice` unchanged) if it is empty.
fn take_first<'a, T>(slice: &mut &'a [T]) -> Option<&'a T> {
    // Copy the shared slice out so the returned reference keeps the full lifetime `'a`.
    let whole: &'a [T] = *slice;
    let first = whole.first()?;
    *slice = &whole[1..];
    Some(first)
}
843
/// Removes the first element from the front of the mutable `slice` and returns it.
///
/// Returns `None` if the slice is empty (it then stays empty).
fn take_first_mut<'a, T>(slice: &mut &'a mut [T]) -> Option<&'a mut T> {
    // Move the slice out, leaving an empty one behind, so both halves keep lifetime `'a`.
    let whole: &'a mut [T] = take(slice);
    let Some((first, rest)) = whole.split_first_mut() else {
        return None;
    };
    *slice = rest;
    Some(first)
}
851
/// Test helper that plays the device for a single request.
///
/// If the driver has published an available entry that has no used entry yet, this function:
/// 1. concatenates the driver-to-device (non-`WRITE`) buffers of that chain;
/// 2. passes them to `handler`;
/// 3. copies the handler's result into the chain's `WRITE` buffers;
/// 4. appends a used ring entry.
///
/// Both direct chains and a single `INDIRECT` descriptor are supported.
///
/// Returns `false` without doing anything if no request is pending, otherwise `true`.
///
/// The used entry's `len` is set to the input length plus the handler's output length.
///
/// # Panics
///
/// Panics if an indirect descriptor has flags other than `INDIRECT`, or if the handler returns
/// more data than fits in the chain's `WRITE` buffers.
///
/// NOTE(review): this is a safe fn that dereferences raw pointers, including descriptor `addr`
/// values used directly as virtual addresses. Presumably that is only valid with a HAL whose
/// "physical" addresses are virtual ones (e.g. `FakeHal`) — TODO confirm. Callers must pass
/// pointers to a live queue.
#[cfg(test)]
pub(crate) fn fake_read_write_queue<const QUEUE_SIZE: usize>(
    descriptors: *const [Descriptor; QUEUE_SIZE],
    queue_driver_area: *const u8,
    queue_device_area: *mut u8,
    handler: impl FnOnce(Vec<u8>) -> Vec<u8>,
) -> bool {
    use core::{ops::Deref, slice};

    let available_ring = queue_driver_area as *const AvailRing<QUEUE_SIZE>;
    let used_ring = queue_device_area as *mut UsedRing<QUEUE_SIZE>;

    unsafe {
        // Nothing to do if every available entry already has a used entry.
        if (*available_ring).idx.load(Ordering::Acquire) == (*used_ring).idx.load(Ordering::Acquire)
        {
            return false;
        }
        // The next unprocessed available entry is at the slot the used ring will fill next.
        let next_slot = (*used_ring).idx.load(Ordering::Acquire) & (QUEUE_SIZE as u16 - 1);
        let head_descriptor_index = (*available_ring).ring[next_slot as usize];
        let mut descriptor = &(*descriptors)[head_descriptor_index as usize];

        let input_length;
        let output;
        if descriptor.flags.contains(DescFlags::INDIRECT) {
            // An indirect descriptor must not also carry NEXT or WRITE.
            assert_eq!(descriptor.flags, DescFlags::INDIRECT);

            // Reinterpret the indirect buffer as a list of descriptors.
            let indirect_descriptor_list: &[Descriptor] = zerocopy::Ref::into_ref(
                zerocopy::Ref::<_, [Descriptor]>::from_bytes(slice::from_raw_parts(
                    descriptor.addr as *const u8,
                    descriptor.len as usize,
                ))
                .unwrap(),
            );
            // Gather the leading driver-to-device buffers.
            let mut input = Vec::new();
            let mut indirect_descriptor_index = 0;
            while indirect_descriptor_index < indirect_descriptor_list.len() {
                let indirect_descriptor = &indirect_descriptor_list[indirect_descriptor_index];
                if indirect_descriptor.flags.contains(DescFlags::WRITE) {
                    break;
                }

                input.extend_from_slice(slice::from_raw_parts(
                    indirect_descriptor.addr as *const u8,
                    indirect_descriptor.len as usize,
                ));

                indirect_descriptor_index += 1;
            }
            input_length = input.len();

            output = handler(input);

            // Scatter the output over the remaining (device-writable) buffers.
            let mut remaining_output = output.deref();
            while indirect_descriptor_index < indirect_descriptor_list.len() {
                let indirect_descriptor = &indirect_descriptor_list[indirect_descriptor_index];
                assert!(indirect_descriptor.flags.contains(DescFlags::WRITE));

                let length_to_write = min(remaining_output.len(), indirect_descriptor.len as usize);
                ptr::copy(
                    remaining_output.as_ptr(),
                    indirect_descriptor.addr as *mut u8,
                    length_to_write,
                );
                remaining_output = &remaining_output[length_to_write..];

                indirect_descriptor_index += 1;
            }
            assert_eq!(remaining_output.len(), 0);
        } else {
            // Gather driver-to-device buffers by following the chain until the first WRITE
            // descriptor (or the end of the chain).
            let mut input = Vec::new();
            while !descriptor.flags.contains(DescFlags::WRITE) {
                input.extend_from_slice(slice::from_raw_parts(
                    descriptor.addr as *const u8,
                    descriptor.len as usize,
                ));

                if let Some(next) = descriptor.next() {
                    descriptor = &(*descriptors)[next as usize];
                } else {
                    break;
                }
            }
            input_length = input.len();

            output = handler(input);

            // If the chain has WRITE descriptors, `descriptor` is now the first of them.
            let mut remaining_output = output.deref();
            if descriptor.flags.contains(DescFlags::WRITE) {
                loop {
                    assert!(descriptor.flags.contains(DescFlags::WRITE));

                    let length_to_write = min(remaining_output.len(), descriptor.len as usize);
                    ptr::copy(
                        remaining_output.as_ptr(),
                        descriptor.addr as *mut u8,
                        length_to_write,
                    );
                    remaining_output = &remaining_output[length_to_write..];

                    if let Some(next) = descriptor.next() {
                        descriptor = &(*descriptors)[next as usize];
                    } else {
                        break;
                    }
                }
            }
            assert_eq!(remaining_output.len(), 0);
        }

        // Publish the used entry, then bump the used index so the driver can pop it.
        (*used_ring).ring[next_slot as usize].id = head_descriptor_index.into();
        (*used_ring).ring[next_slot as usize].len = (input_length + output.len()) as u32;
        (*used_ring).idx.fetch_add(1, Ordering::AcqRel);

        true
    }
}
988
989#[cfg(test)]
990mod tests {
991    use super::*;
992    use crate::{
993        device::common::Feature,
994        hal::fake::FakeHal,
995        transport::{
996            fake::{FakeTransport, QueueStatus, State},
997            mmio::{MmioTransport, VirtIOHeader, MODERN_VERSION},
998            DeviceType,
999        },
1000    };
1001    use safe_mmio::UniqueMmioPointer;
1002    use std::sync::{Arc, Mutex};
1003
1004    #[test]
1005    fn queue_too_big() {
1006        let mut header = VirtIOHeader::make_fake_header(MODERN_VERSION, 1, 0, 0, 4);
1007        let mut transport = MmioTransport::new_from_unique(
1008            UniqueMmioPointer::from(&mut header),
1009            UniqueMmioPointer::from([].as_mut_slice()),
1010        )
1011        .unwrap();
1012        assert_eq!(
1013            VirtQueue::<FakeHal, 8>::new(&mut transport, 0, false, false).unwrap_err(),
1014            Error::InvalidParam
1015        );
1016    }
1017
1018    #[test]
1019    fn queue_already_used() {
1020        let mut header = VirtIOHeader::make_fake_header(MODERN_VERSION, 1, 0, 0, 4);
1021        let mut transport = MmioTransport::new_from_unique(
1022            UniqueMmioPointer::from(&mut header),
1023            UniqueMmioPointer::from([].as_mut_slice()),
1024        )
1025        .unwrap();
1026        VirtQueue::<FakeHal, 4>::new(&mut transport, 0, false, false).unwrap();
1027        assert_eq!(
1028            VirtQueue::<FakeHal, 4>::new(&mut transport, 0, false, false).unwrap_err(),
1029            Error::AlreadyUsed
1030        );
1031    }
1032
1033    #[test]
1034    fn add_empty() {
1035        let mut header = VirtIOHeader::make_fake_header(MODERN_VERSION, 1, 0, 0, 4);
1036        let mut transport = MmioTransport::new_from_unique(
1037            UniqueMmioPointer::from(&mut header),
1038            UniqueMmioPointer::from([].as_mut_slice()),
1039        )
1040        .unwrap();
1041        let mut queue = VirtQueue::<FakeHal, 4>::new(&mut transport, 0, false, false).unwrap();
1042        assert_eq!(
1043            unsafe { queue.add(&[], &mut []) }.unwrap_err(),
1044            Error::InvalidParam
1045        );
1046    }
1047
1048    #[test]
1049    fn add_too_many() {
1050        let mut header = VirtIOHeader::make_fake_header(MODERN_VERSION, 1, 0, 0, 4);
1051        let mut transport = MmioTransport::new_from_unique(
1052            UniqueMmioPointer::from(&mut header),
1053            UniqueMmioPointer::from([].as_mut_slice()),
1054        )
1055        .unwrap();
1056        let mut queue = VirtQueue::<FakeHal, 4>::new(&mut transport, 0, false, false).unwrap();
1057        assert_eq!(queue.available_desc(), 4);
1058        assert_eq!(
1059            unsafe { queue.add(&[&[], &[], &[]], &mut [&mut [], &mut []]) }.unwrap_err(),
1060            Error::QueueFull
1061        );
1062    }
1063
1064    #[test]
1065    fn add_buffers() {
1066        let mut header = VirtIOHeader::make_fake_header(MODERN_VERSION, 1, 0, 0, 4);
1067        let mut transport = MmioTransport::new_from_unique(
1068            UniqueMmioPointer::from(&mut header),
1069            UniqueMmioPointer::from([].as_mut_slice()),
1070        )
1071        .unwrap();
1072        let mut queue = VirtQueue::<FakeHal, 4>::new(&mut transport, 0, false, false).unwrap();
1073        assert_eq!(queue.available_desc(), 4);
1074
1075        let token = unsafe { queue.add(&[&[1, 2], &[3]], &mut [&mut [0, 0], &mut [0]]) }.unwrap();
1078
1079        assert_eq!(queue.available_desc(), 0);
1080        assert!(!queue.can_pop());
1081
1082        unsafe {
1085            let first_descriptor_index = (*queue.avail.as_ptr()).ring[0];
1086            assert_eq!(first_descriptor_index, token);
1087            assert_eq!(
1088                (*queue.desc.as_ptr())[first_descriptor_index as usize].len,
1089                2
1090            );
1091            assert_eq!(
1092                (*queue.desc.as_ptr())[first_descriptor_index as usize].flags,
1093                DescFlags::NEXT
1094            );
1095            let second_descriptor_index =
1096                (*queue.desc.as_ptr())[first_descriptor_index as usize].next;
1097            assert_eq!(
1098                (*queue.desc.as_ptr())[second_descriptor_index as usize].len,
1099                1
1100            );
1101            assert_eq!(
1102                (*queue.desc.as_ptr())[second_descriptor_index as usize].flags,
1103                DescFlags::NEXT
1104            );
1105            let third_descriptor_index =
1106                (*queue.desc.as_ptr())[second_descriptor_index as usize].next;
1107            assert_eq!(
1108                (*queue.desc.as_ptr())[third_descriptor_index as usize].len,
1109                2
1110            );
1111            assert_eq!(
1112                (*queue.desc.as_ptr())[third_descriptor_index as usize].flags,
1113                DescFlags::NEXT | DescFlags::WRITE
1114            );
1115            let fourth_descriptor_index =
1116                (*queue.desc.as_ptr())[third_descriptor_index as usize].next;
1117            assert_eq!(
1118                (*queue.desc.as_ptr())[fourth_descriptor_index as usize].len,
1119                1
1120            );
1121            assert_eq!(
1122                (*queue.desc.as_ptr())[fourth_descriptor_index as usize].flags,
1123                DescFlags::WRITE
1124            );
1125        }
1126    }
1127
1128    #[cfg(feature = "alloc")]
1129    #[test]
1130    fn add_buffers_indirect() {
1131        use core::ptr::slice_from_raw_parts;
1132
1133        let mut header = VirtIOHeader::make_fake_header(MODERN_VERSION, 1, 0, 0, 4);
1134        let mut transport = MmioTransport::new_from_unique(
1135            UniqueMmioPointer::from(&mut header),
1136            UniqueMmioPointer::from([].as_mut_slice()),
1137        )
1138        .unwrap();
1139        let mut queue = VirtQueue::<FakeHal, 4>::new(&mut transport, 0, true, false).unwrap();
1140        assert_eq!(queue.available_desc(), 4);
1141
1142        let token = unsafe { queue.add(&[&[1, 2], &[3]], &mut [&mut [0, 0], &mut [0]]) }.unwrap();
1145
1146        assert_eq!(queue.available_desc(), 4);
1147        assert!(!queue.can_pop());
1148
1149        unsafe {
1152            let indirect_descriptor_index = (*queue.avail.as_ptr()).ring[0];
1153            assert_eq!(indirect_descriptor_index, token);
1154            assert_eq!(
1155                (*queue.desc.as_ptr())[indirect_descriptor_index as usize].len as usize,
1156                4 * size_of::<Descriptor>()
1157            );
1158            assert_eq!(
1159                (*queue.desc.as_ptr())[indirect_descriptor_index as usize].flags,
1160                DescFlags::INDIRECT
1161            );
1162
1163            let indirect_descriptors = slice_from_raw_parts(
1164                (*queue.desc.as_ptr())[indirect_descriptor_index as usize].addr
1165                    as *const Descriptor,
1166                4,
1167            );
1168            assert_eq!((*indirect_descriptors)[0].len, 2);
1169            assert_eq!((*indirect_descriptors)[0].flags, DescFlags::NEXT);
1170            assert_eq!((*indirect_descriptors)[0].next, 1);
1171            assert_eq!((*indirect_descriptors)[1].len, 1);
1172            assert_eq!((*indirect_descriptors)[1].flags, DescFlags::NEXT);
1173            assert_eq!((*indirect_descriptors)[1].next, 2);
1174            assert_eq!((*indirect_descriptors)[2].len, 2);
1175            assert_eq!(
1176                (*indirect_descriptors)[2].flags,
1177                DescFlags::NEXT | DescFlags::WRITE
1178            );
1179            assert_eq!((*indirect_descriptors)[2].next, 3);
1180            assert_eq!((*indirect_descriptors)[3].len, 1);
1181            assert_eq!((*indirect_descriptors)[3].flags, DescFlags::WRITE);
1182        }
1183    }
1184
1185    #[test]
1187    fn set_dev_notify() {
1188        let state = Arc::new(Mutex::new(State::new(vec![QueueStatus::default()], ())));
1189        let mut transport = FakeTransport {
1190            device_type: DeviceType::Block,
1191            max_queue_size: 4,
1192            device_features: 0,
1193            state: state.clone(),
1194        };
1195        let mut queue = VirtQueue::<FakeHal, 4>::new(&mut transport, 0, false, false).unwrap();
1196
1197        assert_eq!(
1199            unsafe { (*queue.avail.as_ptr()).flags.load(Ordering::Acquire) },
1200            0x0
1201        );
1202
1203        queue.set_dev_notify(false);
1204
1205        assert_eq!(
1207            unsafe { (*queue.avail.as_ptr()).flags.load(Ordering::Acquire) },
1208            0x1
1209        );
1210
1211        queue.set_dev_notify(true);
1212
1213        assert_eq!(
1215            unsafe { (*queue.avail.as_ptr()).flags.load(Ordering::Acquire) },
1216            0x0
1217        );
1218    }
1219
1220    #[test]
1223    fn add_notify() {
1224        let state = Arc::new(Mutex::new(State::new(vec![QueueStatus::default()], ())));
1225        let mut transport = FakeTransport {
1226            device_type: DeviceType::Block,
1227            max_queue_size: 4,
1228            device_features: 0,
1229            state: state.clone(),
1230        };
1231        let mut queue = VirtQueue::<FakeHal, 4>::new(&mut transport, 0, false, false).unwrap();
1232
1233        unsafe { queue.add(&[&[42]], &mut []) }.unwrap();
1235
1236        assert_eq!(queue.should_notify(), true);
1238
1239        unsafe {
1242            (*queue.used.as_ptr()).flags.store(0x01, Ordering::Release);
1244        }
1245
1246        assert_eq!(queue.should_notify(), false);
1248    }
1249
1250    #[test]
1253    fn add_notify_event_idx() {
1254        let state = Arc::new(Mutex::new(State::new(vec![QueueStatus::default()], ())));
1255        let mut transport = FakeTransport {
1256            device_type: DeviceType::Block,
1257            max_queue_size: 4,
1258            device_features: Feature::RING_EVENT_IDX.bits(),
1259            state: state.clone(),
1260        };
1261        let mut queue = VirtQueue::<FakeHal, 4>::new(&mut transport, 0, false, true).unwrap();
1262
1263        assert_eq!(unsafe { queue.add(&[&[42]], &mut []) }.unwrap(), 0);
1265
1266        assert_eq!(queue.should_notify(), true);
1268
1269        unsafe {
1272            (*queue.used.as_ptr())
1274                .avail_event
1275                .store(1, Ordering::Release);
1276        }
1277
1278        assert_eq!(queue.should_notify(), false);
1280
1281        assert_eq!(unsafe { queue.add(&[&[42]], &mut []) }.unwrap(), 1);
1283
1284        assert_eq!(queue.should_notify(), true);
1286    }
1287}