virtio_driver/virtqueue.rs

// SPDX-License-Identifier: (MIT OR Apache-2.0)

//! A virtqueue implementation to be used internally by virtio device drivers.

mod packed;
mod split;

use crate::{Iova, IovaTranslator, Le16, VirtioFeatureFlags};
use bitflags::bitflags;
use libc::iovec;
use packed::VirtqueuePacked;
use split::VirtqueueSplit;
use std::io::{Error, ErrorKind};
use std::mem;

bitflags! {
    struct VirtqueueDescriptorFlags: u16 {
        /// The descriptor continues the chain (`VIRTQ_DESC_F_NEXT`).
        const NEXT = 0x1;
        /// The buffer is write-only for the device (`VIRTQ_DESC_F_WRITE`).
        const WRITE = 0x2;
        /// The buffer contains a table of indirect descriptors (`VIRTQ_DESC_F_INDIRECT`).
        const INDIRECT = 0x4;
    }
}

/// A description of how the memory passed for each virtqueue is split into individual regions.
///
/// * The Virtqueue Descriptor Area starts at offset 0
/// * The Virtqueue Driver Area (Available Ring for split virtqueue) starts at `driver_area_offset`
/// * The Virtqueue Device Area (Used Ring for split virtqueue) starts at `device_area_offset`
/// * Driver-specific per-request data that needs to be shared with the device (e.g. request
///   headers or status bytes) starts at `req_offset`
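///
/// # Example
///
/// A sketch of sizing one contiguous allocation for all queues (the `ReqHdr` type and the
/// `alloc_shared_memory` helper are hypothetical):
///
/// ```ignore
/// let layout = VirtqueueLayout::new::<ReqHdr>(2, 256, features)?;
/// // Queues can be placed back to back because `end_offset` preserves the 16-byte
/// // descriptor table alignment.
/// let buf = alloc_shared_memory(layout.num_queues * layout.end_offset)?;
/// ```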
pub struct VirtqueueLayout {
    pub num_queues: usize,
    pub driver_area_offset: usize,
    pub device_area_offset: usize,
    pub req_offset: usize,
    pub end_offset: usize,
}

impl VirtqueueLayout {
    /// Computes the layout for `num_queues` virtqueues with `queue_size` entries each and
    /// per-request data of type `R`.
    pub fn new<R>(
        num_queues: usize,
        queue_size: usize,
        features: VirtioFeatureFlags,
    ) -> Result<Self, Error> {
        if features.contains(VirtioFeatureFlags::RING_PACKED) {
            let desc_bytes = mem::size_of::<packed::VirtqueueDescriptor>() * queue_size;
            let event_suppress_bytes = mem::size_of::<packed::VirtqueueEventSuppress>();

            Self::new_layout::<R>(
                num_queues,
                queue_size,
                desc_bytes,
                event_suppress_bytes,
                event_suppress_bytes,
            )
        } else {
            let desc_bytes = mem::size_of::<split::VirtqueueDescriptor>() * queue_size;
            let avail_bytes = 8 + mem::size_of::<Le16>() * queue_size;
            let used_bytes = 8 + mem::size_of::<split::VirtqueueUsedElem>() * queue_size;

            // Check queue size requirements (see 2.6 in the VIRTIO 1.1 spec)
            if !queue_size.is_power_of_two() || queue_size > 32768 {
                return Err(Error::new(ErrorKind::InvalidInput, "Invalid queue size"));
            }

            // The used ring requires an alignment of 4 (see 2.6 in the VIRTIO 1.1 spec)
            let avail_bytes = (avail_bytes + 3) & !0x3;

            Self::new_layout::<R>(num_queues, queue_size, desc_bytes, avail_bytes, used_bytes)
        }
    }

    fn new_layout<R>(
        num_queues: usize,
        queue_size: usize,
        desc_bytes: usize,
        driver_area_bytes: usize,
        device_area_bytes: usize,
    ) -> Result<Self, Error> {
        let req_bytes = mem::size_of::<R>() * queue_size;

        // Consider the required alignment of R
        let req_align = mem::align_of::<R>();
        let req_offset = desc_bytes + driver_area_bytes + device_area_bytes;
        let req_offset_aligned = (req_offset + req_align - 1) & !(req_align - 1);
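        // E.g. an unaligned req_offset of 0x1234 with align_of::<R>() == 8 is rounded up to
        // req_offset_aligned == 0x1238.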

        // Maintain 16-byte descriptor table alignment (see 2.7 in the VIRTIO 1.1 spec) in
        // contiguous virtqueue arrays (useful for allocating memory for several queues at once)
        let end_offset = (req_offset_aligned + req_bytes + 15) & !15;

        Ok(VirtqueueLayout {
            num_queues,
            driver_area_offset: desc_bytes,
            device_area_offset: desc_bytes + driver_area_bytes,
            req_offset: req_offset_aligned,
            end_offset,
        })
    }
}

/// An interface for the virtqueue formats supported by the VIRTIO specification.
trait VirtqueueFormat {
    /// Returns the number of entries of the descriptor table.
    fn queue_size(&self) -> u16;

    /// Returns a raw pointer to the start of the descriptor table.
    fn desc_table_ptr(&self) -> *const u8;

    /// Returns a raw pointer to the start of the driver area.
    fn driver_area_ptr(&self) -> *const u8;

    /// Returns a raw pointer to the start of the device area.
    fn device_area_ptr(&self) -> *const u8;

    /// Starts a new descriptor chain and returns its identifier, or `None` if no free
    /// descriptors are left.
    fn avail_start_chain(&mut self) -> Option<u16>;

    /// Rewinds the last chain if there were any errors during building.
    ///
    /// `chain_id` is the identifier returned by `avail_start_chain()`.
    fn avail_rewind_chain(&mut self, chain_id: u16);

    /// Adds a descriptor to the current chain and returns its index in
    /// the descriptor table.
    fn avail_add_desc_chain(
        &mut self,
        addr: u64,
        len: u32,
        flags: VirtqueueDescriptorFlags,
    ) -> Result<u16, Error>;

    /// Exposes the available descriptor chain to the device.
    ///
    /// `chain_id` is the identifier returned by `avail_start_chain()`.
    /// `last_desc_idx` is the index returned by `avail_add_desc_chain()` for the
    /// last descriptor added to the chain.
    fn avail_publish(&mut self, chain_id: u16, last_desc_idx: u16);

    /// Returns `true` if there are used chains available.
    fn used_has_next(&self) -> bool;

    /// Returns the identifier of the next chain used by the device, if any.
    fn used_next(&mut self) -> Option<u16>;

    /// Returns lower and upper bounds on the number of remaining used chains, like
    /// `Iterator::size_hint()`.
    fn used_size_hint(&self) -> (usize, Option<usize>);

    /// Returns `true` if the device needs to be notified of new available descriptors.
    fn avail_notif_needed(&mut self) -> bool;

    /// Enables or disables used notifications.
    fn set_used_notif_enabled(&mut self, enabled: bool);
}

/// A virtqueue of a virtio device.
///
/// `R` is used to store device-specific per-request data (like the request header or status byte)
/// in memory shared with the device and is copied on completion. Don't put things there that the
/// device doesn't have to access, in the interest of both security and performance.
pub struct Virtqueue<'a, R: Copy> {
    iova_translator: Box<dyn IovaTranslator>,
    format: Box<dyn VirtqueueFormat + 'a>,
    req: *mut R,
    layout: VirtqueueLayout,
}

// `Send` and `Sync` are not implemented automatically because of the `req` raw pointer and the
// boxed trait objects.
unsafe impl<R: Copy> Send for Virtqueue<'_, R> {}
unsafe impl<R: Copy> Sync for Virtqueue<'_, R> {}

/// The result of a completed request.
pub struct VirtqueueCompletion<R> {
    /// The identifier of the descriptor chain for the request as returned by [`add_request`].
    ///
    /// [`add_request`]: Virtqueue::add_request
    pub id: u16,

    /// Device-specific per-request data like the request header or status byte.
    pub req: R,
}

impl<'a, R: Copy> Virtqueue<'a, R> {
    /// Creates a new virtqueue in the passed memory buffer.
    ///
    /// `buf` has to be memory that is visible to the device. It is used to store all descriptors,
    /// rings and device-specific per-request data for the queue.
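    ///
    /// # Example
    ///
    /// A sketch of queue setup (the `ReqHdr` type, the `translator` and the
    /// `alloc_shared_memory` helper are hypothetical):
    ///
    /// ```ignore
    /// let layout = VirtqueueLayout::new::<ReqHdr>(1, 256, features)?;
    /// let buf = alloc_shared_memory(layout.end_offset)?;
    /// let vq = Virtqueue::<ReqHdr>::new(Box::new(translator), buf, 256, features)?;
    /// ```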
    pub fn new(
        iova_translator: Box<dyn IovaTranslator>,
        buf: &'a mut [u8],
        queue_size: u16,
        features: VirtioFeatureFlags,
    ) -> Result<Self, Error> {
        let layout = VirtqueueLayout::new::<R>(1, queue_size as usize, features)?;
        let event_idx_enabled = features.contains(VirtioFeatureFlags::RING_EVENT_IDX);
        let (format, req_mem) = if features.contains(VirtioFeatureFlags::RING_PACKED) {
            let mem = buf.get_mut(0..layout.end_offset).ok_or_else(|| {
                Error::new(ErrorKind::InvalidInput, "Incorrectly sized queue buffer")
            })?;

            let (mem, req_mem) = mem.split_at_mut(layout.req_offset);
            let (mem, device_es_mem) = mem.split_at_mut(layout.device_area_offset);
            let (desc_mem, driver_es_mem) = mem.split_at_mut(layout.driver_area_offset);

            let format: Box<dyn VirtqueueFormat + 'a> = Box::new(VirtqueuePacked::new(
                desc_mem,
                driver_es_mem,
                device_es_mem,
                queue_size,
                event_idx_enabled,
            )?);

            (format, req_mem)
        } else {
            let mem = buf.get_mut(0..layout.end_offset).ok_or_else(|| {
                Error::new(ErrorKind::InvalidInput, "Incorrectly sized queue buffer")
            })?;

            let (mem, req_mem) = mem.split_at_mut(layout.req_offset);
            let (mem, used_mem) = mem.split_at_mut(layout.device_area_offset);
            let (desc_mem, avail_mem) = mem.split_at_mut(layout.driver_area_offset);

            let format: Box<dyn VirtqueueFormat + 'a> = Box::new(VirtqueueSplit::new(
                avail_mem,
                used_mem,
                desc_mem,
                queue_size,
                event_idx_enabled,
            )?);

            (format, req_mem)
        };

        let req = req_mem.as_mut_ptr() as *mut R;
        if req.align_offset(mem::align_of::<R>()) != 0 {
            return Err(Error::new(
                ErrorKind::InvalidInput,
                "Insufficient memory alignment",
            ));
        }

        Ok(Virtqueue {
            iova_translator,
            format,
            req,
            layout,
        })
    }

    /// Returns the number of entries in each of the descriptor table and rings.
    pub fn queue_size(&self) -> u16 {
        self.format.queue_size()
    }

    /// Returns the virtqueue memory layout.
    pub fn layout(&self) -> &VirtqueueLayout {
        &self.layout
    }

    /// Returns a raw pointer to the start of the descriptor table.
    pub fn desc_table_ptr(&self) -> *const u8 {
        self.format.desc_table_ptr()
    }

    /// Returns a raw pointer to the start of the driver area.
    pub fn driver_area_ptr(&self) -> *const u8 {
        self.format.driver_area_ptr()
    }

    /// Returns a raw pointer to the start of the device area.
    pub fn device_area_ptr(&self) -> *const u8 {
        self.format.device_area_ptr()
    }

    /// Enqueues a new request.
    ///
    /// `prepare` is a function or closure that gets a reference to the device-specific per-request
    /// data in its final location in the virtqueue memory and an `FnMut` to add virtio descriptors
    /// to the request. It can set up the per-request data as necessary and must add all
    /// descriptors needed for the request.
    ///
    /// The parameters of the `FnMut` it receives are the `iovec` describing the buffer to be added
    /// and a boolean `from_dev` that is `true` if this buffer is written by the device and `false`
    /// if it is read by the device.
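    ///
    /// # Example
    ///
    /// A sketch of a request with a single device-written buffer (the `ReqHdr` type and the
    /// `data` buffer are hypothetical):
    ///
    /// ```ignore
    /// let chain_id = vq.add_request(|req: &mut ReqHdr, add_desc| {
    ///     *req = ReqHdr::default();
    ///     add_desc(
    ///         iovec {
    ///             iov_base: data.as_mut_ptr().cast(),
    ///             iov_len: data.len(),
    ///         },
    ///         true, // from_dev: the device writes this buffer
    ///     )
    /// })?;
    /// ```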
    pub fn add_request<F>(&mut self, prepare: F) -> Result<u16, Error>
    where
        F: FnOnce(&mut R, &mut dyn FnMut(iovec, bool) -> Result<(), Error>) -> Result<(), Error>,
    {
        let chain_id = match self.format.avail_start_chain() {
            None => {
                return Err(Error::new(ErrorKind::Other, "Not enough free descriptors"));
            }
            Some(idx) => idx,
        };

        let req_ptr = unsafe { &mut *self.req.offset(chain_id as isize) };
        let mut last_desc_idx: Option<u16> = None;

        let res = prepare(req_ptr, &mut |iovec: iovec, from_dev: bool| {
            // Set NEXT for all descriptors; it is unset again for the last one when the chain
            // is published below
            let mut flags = VirtqueueDescriptorFlags::NEXT;
            if from_dev {
                flags.insert(VirtqueueDescriptorFlags::WRITE);
            }
            let Iova(iova) = self
                .iova_translator
                .translate_addr(iovec.iov_base as usize, iovec.iov_len)?;
            last_desc_idx = Some(self.format.avail_add_desc_chain(
                iova,
                iovec.iov_len as u32,
                flags,
            )?);
            Ok(())
        });

        if let Err(e) = res {
            self.format.avail_rewind_chain(chain_id);
            return Err(e);
        }

        // `prepare` must add at least one descriptor, so `last_desc_idx` is always set here
        self.format.avail_publish(chain_id, last_desc_idx.unwrap());
        Ok(chain_id)
    }

    /// Returns an iterator over all completed requests.
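    ///
    /// A sketch of draining completions after a device notification:
    ///
    /// ```ignore
    /// for completion in vq.completions() {
    ///     println!("chain {} completed: {:?}", completion.id, completion.req);
    /// }
    /// ```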
    pub fn completions(&mut self) -> VirtqueueIter<'_, 'a, R> {
        VirtqueueIter { virtqueue: self }
    }

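    /// Returns `true` if the device needs to be notified about newly available requests.
    ///
    /// A sketch of the submission path (`notify_device` stands in for a hypothetical
    /// transport-specific notification mechanism):
    ///
    /// ```ignore
    /// let chain_id = vq.add_request(prepare)?;
    /// if vq.avail_notif_needed() {
    ///     notify_device();
    /// }
    /// ```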
    pub fn avail_notif_needed(&mut self) -> bool {
        self.format.avail_notif_needed()
    }

    /// Enables or disables used notifications from the device.
    pub fn set_used_notif_enabled(&mut self, enabled: bool) {
        self.format.set_used_notif_enabled(enabled)
    }
}

/// An iterator that returns all completed requests.
pub struct VirtqueueIter<'a, 'queue, R: Copy> {
    virtqueue: &'a mut Virtqueue<'queue, R>,
}

impl<R: Copy> VirtqueueIter<'_, '_, R> {
    /// Returns `true` if there is at least one more completion to be returned.
    pub fn has_next(&self) -> bool {
        self.virtqueue.format.used_has_next()
    }
}

impl<'a, 'queue, R: Copy> Iterator for VirtqueueIter<'a, 'queue, R> {
    type Item = VirtqueueCompletion<R>;

    fn next(&mut self) -> Option<Self::Item> {
        let id = self.virtqueue.format.used_next()?;

        let req = unsafe { *self.virtqueue.req.offset(id as isize) };
        Some(VirtqueueCompletion { id, req })
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        self.virtqueue.format.used_size_hint()
    }
}