virtio_vsock/
packet.rs

1// Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2//
3// SPDX-License-Identifier: Apache-2.0 OR BSD-3-Clause
4
5//! Vsock packet abstraction.
6//!
7//! This module provides the following abstraction for parsing a vsock packet, and working with it:
8//!
9//! - [`VsockPacket`](struct.VsockPacket.html) which handles the parsing of the vsock packet from
10//! either a TX descriptor chain via
11//! [`VsockPacket::from_tx_virtq_chain`](struct.VsockPacket.html#method.from_tx_virtq_chain), or an
12//! RX descriptor chain via
13//! [`VsockPacket::from_rx_virtq_chain`](struct.VsockPacket.html#method.from_rx_virtq_chain).
14//! The virtio vsock packet is defined in the standard as having a header of type `virtio_vsock_hdr`
15//! and an optional `data` array of bytes. The methods mentioned above assume that both packet
16//! elements are on the same descriptor, or each of the packet elements occupies exactly one
17//! descriptor. For the usual drivers, this assumption stands,
18//! but in the future we might make the implementation more generic by removing any constraint
19//! regarding the number of descriptors that correspond to the header/data. The buffers associated
20//! to the TX virtio queue are device-readable, and the ones associated to the RX virtio queue are
21//! device-writable.
//!
//! The `VsockPacket` abstraction is using vm-memory's `VolatileSlice` for representing the header
//! and the data. `VolatileSlice` is a safe wrapper over a raw pointer, which also handles the dirty
//! page tracking behind the scenes. A limitation of the current implementation is that it does not
//! cover the scenario where the header or data buffer doesn't fit in a single `VolatileSlice`
//! because the guest memory regions of the buffer are contiguous in the guest physical address
//! space, but not in the host virtual one as well. If this becomes a use case, we can extend this
//! solution to use an array of `VolatileSlice`s for the header and data.
//! The `VsockPacket` abstraction is also storing a `virtio_vsock_hdr` instance (which is defined
//! here as `PacketHeader`). This is needed so that we always access the same data that was read the
//! first time from the descriptor chain. This way we avoid potential time-of-check time-of-use
//! problems that may occur when later reading a header field from the underlying memory itself
//! (i.e. from the header's `VolatileSlice` object).
35use std::fmt::{self, Display};
36use std::ops::Deref;
37
38use virtio_queue::DescriptorChain;
39use vm_memory::bitmap::{BitmapSlice, WithBitmapSlice};
40use vm_memory::{
41    Address, ByteValued, Bytes, GuestMemory, GuestMemoryError, GuestMemoryRegion, Le16, Le32, Le64,
42    VolatileMemoryError, VolatileSlice,
43};
44
45/// Vsock packet parsing errors.
46#[derive(Debug)]
47pub enum Error {
48    /// Too few descriptors in a descriptor chain.
49    DescriptorChainTooShort,
50    /// Descriptor that was too short to use.
51    DescriptorLengthTooSmall,
52    /// Descriptor that was too long to use.
53    DescriptorLengthTooLong,
54    /// The slice for creating a header has an invalid length.
55    InvalidHeaderInputSize(usize),
56    /// The `len` header field value exceeds the maximum allowed data size.
57    InvalidHeaderLen(u32),
58    /// Invalid guest memory access.
59    InvalidMemoryAccess(GuestMemoryError),
60    /// Invalid volatile memory access.
61    InvalidVolatileAccess(VolatileMemoryError),
62    /// Read only descriptor that protocol says to write to.
63    UnexpectedReadOnlyDescriptor,
64    /// Write only descriptor that protocol says to read from.
65    UnexpectedWriteOnlyDescriptor,
66}
67
68impl std::error::Error for Error {}
69
70impl Display for Error {
71    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
72        match self {
73            Error::DescriptorChainTooShort => {
74                write!(f, "There are not enough descriptors in the chain.")
75            }
76            Error::DescriptorLengthTooSmall => write!(
77                f,
78                "The descriptor is pointing to a buffer that has a smaller length than expected."
79            ),
80            Error::DescriptorLengthTooLong => write!(
81                f,
82                "The descriptor is pointing to a buffer that has a longer length than expected."
83            ),
84            Error::InvalidHeaderInputSize(size) => {
85                write!(f, "Invalid header input size: {}", size)
86            }
87            Error::InvalidHeaderLen(size) => {
88                write!(f, "Invalid header `len` field value: {}", size)
89            }
90            Error::InvalidMemoryAccess(error) => {
91                write!(f, "Invalid guest memory access: {}", error)
92            }
93            Error::InvalidVolatileAccess(error) => {
94                write!(f, "Invalid volatile memory access: {}", error)
95            }
96            Error::UnexpectedReadOnlyDescriptor => {
97                write!(f, "Unexpected read-only descriptor.")
98            }
99            Error::UnexpectedWriteOnlyDescriptor => {
100                write!(f, "Unexpected write-only descriptor.")
101            }
102        }
103    }
104}
105
106#[repr(C, packed)]
107#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
108/// The vsock packet header structure.
109pub struct PacketHeader {
110    src_cid: Le64,
111    dst_cid: Le64,
112    src_port: Le32,
113    dst_port: Le32,
114    len: Le32,
115    type_: Le16,
116    op: Le16,
117    flags: Le32,
118    buf_alloc: Le32,
119    fwd_cnt: Le32,
120}
121
122// SAFETY: This is safe because `PacketHeader` contains only wrappers over POD types
123// and all accesses through safe `vm-memory` API will validate any garbage that could
124// be included in there.
125unsafe impl ByteValued for PacketHeader {}
126//
127// This structure will occupy the buffer pointed to by the head of the descriptor chain. Below are
128// the offsets for each field, as well as the packed structure size.
129// Note that these offsets are only used privately by the `VsockPacket` struct, the public interface
130// consisting of getter and setter methods, for each struct field, that will also handle the correct
131// endianness.
132
133/// The size of the header structure (when packed).
134pub const PKT_HEADER_SIZE: usize = std::mem::size_of::<PacketHeader>();
135
// Byte offsets of the header fields inside the packed header. Each offset is the offset of the
// previous field plus the width in bytes of that previous field.
const SRC_CID_OFFSET: usize = 0;
const DST_CID_OFFSET: usize = SRC_CID_OFFSET + 8;
const SRC_PORT_OFFSET: usize = DST_CID_OFFSET + 8;
const DST_PORT_OFFSET: usize = SRC_PORT_OFFSET + 4;
const LEN_OFFSET: usize = DST_PORT_OFFSET + 4;
const TYPE_OFFSET: usize = LEN_OFFSET + 4;
const OP_OFFSET: usize = TYPE_OFFSET + 2;
const FLAGS_OFFSET: usize = OP_OFFSET + 2;
const BUF_ALLOC_OFFSET: usize = FLAGS_OFFSET + 4;
const FWD_CNT_OFFSET: usize = BUF_ALLOC_OFFSET + 4;
147
148/// Dedicated [`Result`](https://doc.rust-lang.org/std/result/) type.
149pub type Result<T> = std::result::Result<T, Error>;
150
151/// The vsock packet, implemented as a wrapper over a virtio descriptor chain:
152/// - the chain head, holding the packet header;
153/// - an optional data/buffer descriptor, only present for data packets (for VSOCK_OP_RW requests).
154#[derive(Debug)]
155pub struct VsockPacket<'a, B: BitmapSlice> {
156    // When writing to the header slice, we are using the `write` method of `VolatileSlice`s Bytes
157    // implementation. Because that can only return an error if we pass an invalid offset, we can
158    // safely use `unwraps` in the setters below. If we switch to a type different than
159    // `VolatileSlice`, this assumption can no longer hold. We also must always make sure the
160    // `VsockPacket` API is creating headers with PKT_HEADER_SIZE size.
161    header_slice: VolatileSlice<'a, B>,
162    header: PacketHeader,
163    data_slice: Option<VolatileSlice<'a, B>>,
164}
165
// Helper for the header setters. It stores `$val` in the `$name` field of the packet's
// `PacketHeader`, then writes the little-endian bytes of `$val` into the packet's `header_slice`
// at `$off`. `$off` has to be a valid offset in the `header_slice`, otherwise the macro panics.
macro_rules! set_header_field {
    ($pkt:ident, $name:ident, $off:ident, $val:ident) => {{
        $pkt.header.$name = $val.into();
        let le_bytes = $val.to_le_bytes();
        // The write can only fail for an invalid `$off`, which is what makes the unwrap safe.
        $pkt.header_slice.write(&le_bytes, $off).unwrap();
    }};
}
179
180impl<'a, B: BitmapSlice> VsockPacket<'a, B> {
181    /// Return a reference to the `header_slice` of the packet.
182    pub fn header_slice(&self) -> &VolatileSlice<'a, B> {
183        &self.header_slice
184    }
185
186    /// Return a reference to the `data_slice` of the packet.
187    pub fn data_slice(&self) -> Option<&VolatileSlice<'a, B>> {
188        self.data_slice.as_ref()
189    }
190
191    /// Write to the packet header from an input of raw bytes.
192    ///
193    /// # Example
194    ///
195    /// ```rust
196    /// # use virtio_bindings::bindings::virtio_ring::VRING_DESC_F_WRITE;
197    /// # use virtio_queue::mock::MockSplitQueue;
198    /// # use virtio_queue::{Descriptor, Queue, QueueT};
199    /// use virtio_vsock::packet::{VsockPacket, PKT_HEADER_SIZE};
200    /// # use vm_memory::{Bytes, GuestAddress, GuestAddressSpace, GuestMemoryMmap};
201    ///
202    /// const MAX_PKT_BUF_SIZE: u32 = 64 * 1024;
203    ///
204    /// # fn create_queue_with_chain(m: &GuestMemoryMmap) -> Queue {
205    /// #     let vq = MockSplitQueue::new(m, 16);
206    /// #     let mut q = vq.create_queue().unwrap();
207    /// #
208    /// #     let v = vec![
209    /// #         Descriptor::new(0x5_0000, 0x100, VRING_DESC_F_WRITE as u16, 0),
210    /// #         Descriptor::new(0x8_0000, 0x100, VRING_DESC_F_WRITE as u16, 0),
211    /// #     ];
212    /// #     let mut chain = vq.build_desc_chain(&v);
213    /// #     q
214    /// # }
215    /// let mem = GuestMemoryMmap::<()>::from_ranges(&[(GuestAddress(0), 0x10_0000)]).unwrap();
216    /// // Create a queue and populate it with a descriptor chain.
217    /// let mut queue = create_queue_with_chain(&mem);
218    ///
219    /// while let Some(mut head) = queue.pop_descriptor_chain(&mem) {
220    ///     let mut pkt = VsockPacket::from_rx_virtq_chain(&mem, &mut head, MAX_PKT_BUF_SIZE).unwrap();
221    ///     pkt.set_header_from_raw(&[0u8; PKT_HEADER_SIZE]).unwrap();
222    /// }
223    /// ```
224    pub fn set_header_from_raw(&mut self, bytes: &[u8]) -> Result<()> {
225        if bytes.len() != PKT_HEADER_SIZE {
226            return Err(Error::InvalidHeaderInputSize(bytes.len()));
227        }
228        self.header_slice
229            .write(bytes, 0)
230            .map_err(Error::InvalidVolatileAccess)?;
231        let header = self
232            .header_slice()
233            .read_obj::<PacketHeader>(0)
234            .map_err(Error::InvalidVolatileAccess)?;
235        self.header = header;
236        Ok(())
237    }
238
239    /// Return the `src_cid` of the header.
240    pub fn src_cid(&self) -> u64 {
241        self.header.src_cid.into()
242    }
243
244    /// Set the `src_cid` of the header.
245    pub fn set_src_cid(&mut self, cid: u64) -> &mut Self {
246        set_header_field!(self, src_cid, SRC_CID_OFFSET, cid);
247        self
248    }
249
250    /// Return the `dst_cid` of the header.
251    pub fn dst_cid(&self) -> u64 {
252        self.header.dst_cid.into()
253    }
254
255    /// Set the `dst_cid` of the header.
256    pub fn set_dst_cid(&mut self, cid: u64) -> &mut Self {
257        set_header_field!(self, dst_cid, DST_CID_OFFSET, cid);
258        self
259    }
260
261    /// Return the `src_port` of the header.
262    pub fn src_port(&self) -> u32 {
263        self.header.src_port.into()
264    }
265
266    /// Set the `src_port` of the header.
267    pub fn set_src_port(&mut self, port: u32) -> &mut Self {
268        set_header_field!(self, src_port, SRC_PORT_OFFSET, port);
269        self
270    }
271
272    /// Return the `dst_port` of the header.
273    pub fn dst_port(&self) -> u32 {
274        self.header.dst_port.into()
275    }
276
277    /// Set the `dst_port` of the header.
278    pub fn set_dst_port(&mut self, port: u32) -> &mut Self {
279        set_header_field!(self, dst_port, DST_PORT_OFFSET, port);
280        self
281    }
282
283    /// Return the `len` of the header.
284    pub fn len(&self) -> u32 {
285        self.header.len.into()
286    }
287
288    /// Returns whether the `len` field of the header is 0 or not.
289    pub fn is_empty(&self) -> bool {
290        self.len() == 0
291    }
292
293    /// Set the `len` of the header.
294    pub fn set_len(&mut self, len: u32) -> &mut Self {
295        set_header_field!(self, len, LEN_OFFSET, len);
296        self
297    }
298
299    /// Return the `type` of the header.
300    pub fn type_(&self) -> u16 {
301        self.header.type_.into()
302    }
303
304    /// Set the `type` of the header.
305    pub fn set_type(&mut self, type_: u16) -> &mut Self {
306        set_header_field!(self, type_, TYPE_OFFSET, type_);
307        self
308    }
309
310    /// Return the `op` of the header.
311    pub fn op(&self) -> u16 {
312        self.header.op.into()
313    }
314
315    /// Set the `op` of the header.
316    pub fn set_op(&mut self, op: u16) -> &mut Self {
317        set_header_field!(self, op, OP_OFFSET, op);
318        self
319    }
320
321    /// Return the `flags` of the header.
322    pub fn flags(&self) -> u32 {
323        self.header.flags.into()
324    }
325
326    /// Set the `flags` of the header.
327    pub fn set_flags(&mut self, flags: u32) -> &mut Self {
328        set_header_field!(self, flags, FLAGS_OFFSET, flags);
329        self
330    }
331
332    /// Set a specific flag of the header.
333    pub fn set_flag(&mut self, flag: u32) -> &mut Self {
334        self.set_flags(self.flags() | flag);
335        self
336    }
337
338    /// Return the `buf_alloc` of the header.
339    pub fn buf_alloc(&self) -> u32 {
340        self.header.buf_alloc.into()
341    }
342
343    /// Set the `buf_alloc` of the header.
344    pub fn set_buf_alloc(&mut self, buf_alloc: u32) -> &mut Self {
345        set_header_field!(self, buf_alloc, BUF_ALLOC_OFFSET, buf_alloc);
346        self
347    }
348
349    /// Return the `fwd_cnt` of the header.
350    pub fn fwd_cnt(&self) -> u32 {
351        self.header.fwd_cnt.into()
352    }
353
354    /// Set the `fwd_cnt` of the header.
355    pub fn set_fwd_cnt(&mut self, fwd_cnt: u32) -> &mut Self {
356        set_header_field!(self, fwd_cnt, FWD_CNT_OFFSET, fwd_cnt);
357        self
358    }
359
360    /// Create the packet wrapper from a TX chain.
361    ///
362    /// The chain head is expected to hold a valid packet header. A following packet data
363    /// descriptor can optionally end the chain.
364    ///
365    /// # Arguments
366    /// * `mem` - the `GuestMemory` object that can be used to access the queue buffers.
367    /// * `desc_chain` - the descriptor chain corresponding to a packet.
368    /// * `max_data_size` - the maximum size allowed for the packet payload, that was negotiated
369    ///                     between the device and the driver. Tracking issue for defining this
370    ///                     feature in virtio-spec
371    ///                     [here](https://github.com/oasis-tcs/virtio-spec/issues/140).
372    ///
373    /// # Example
374    ///
375    /// ```rust
376    /// # use virtio_queue::mock::MockSplitQueue;
377    /// # use virtio_queue::{Descriptor, Queue, QueueT};
378    /// use virtio_vsock::packet::{VsockPacket, PKT_HEADER_SIZE};
379    /// # use vm_memory::{Bytes, GuestAddress, GuestAddressSpace, GuestMemoryMmap};
380    ///
381    /// const MAX_PKT_BUF_SIZE: u32 = 64 * 1024;
382    /// const OP_RW: u16 = 5;
383    ///
384    /// # fn create_queue_with_chain(m: &GuestMemoryMmap) -> Queue {
385    /// #     let vq = MockSplitQueue::new(m, 16);
386    /// #     let mut q = vq.create_queue().unwrap();
387    /// #
388    /// #     let v = vec![
389    /// #         Descriptor::new(0x5_0000, 0x100, 0, 0),
390    /// #         Descriptor::new(0x8_0000, 0x100, 0, 0),
391    /// #     ];
392    /// #     let mut chain = vq.build_desc_chain(&v);
393    /// #     q
394    /// # }
395    /// let mem = GuestMemoryMmap::<()>::from_ranges(&[(GuestAddress(0), 0x10000)]).unwrap();
396    /// // Create a queue and populate it with a descriptor chain.
397    /// let mut queue = create_queue_with_chain(&mem);
398    ///
399    /// while let Some(mut head) = queue.pop_descriptor_chain(&mem) {
400    ///     let pkt = match VsockPacket::from_tx_virtq_chain(&mem, &mut head, MAX_PKT_BUF_SIZE) {
401    ///         Ok(pkt) => pkt,
402    ///         Err(_e) => {
403    ///             // Do some error handling.
404    ///             queue.add_used(&mem, head.head_index(), 0);
405    ///             continue;
406    ///         }
407    ///     };
408    ///     // Here we would send the packet to the backend. Depending on the operation type, a
409    ///     // different type of action will be done.
410    ///
411    ///     // For example, if it's a RW packet, we will forward the packet payload to the backend.
412    ///     if pkt.op() == OP_RW {
413    ///         // Send the packet payload to the backend.
414    ///     }
415    ///     queue.add_used(&mem, head.head_index(), 0);
416    /// }
417    /// ```
418    pub fn from_tx_virtq_chain<M, T>(
419        mem: &'a M,
420        desc_chain: &mut DescriptorChain<T>,
421        max_data_size: u32,
422    ) -> Result<Self>
423    where
424        M: GuestMemory,
425        <<M as GuestMemory>::R as GuestMemoryRegion>::B: WithBitmapSlice<'a, S = B>,
426        T: Deref,
427        T::Target: GuestMemory,
428    {
429        let chain_head = desc_chain.next().ok_or(Error::DescriptorChainTooShort)?;
430        // All TX buffers must be device-readable.
431        if chain_head.is_write_only() {
432            return Err(Error::UnexpectedWriteOnlyDescriptor);
433        }
434
435        // The packet header should fit inside the buffer corresponding to the head descriptor.
436        if (chain_head.len() as usize) < PKT_HEADER_SIZE {
437            return Err(Error::DescriptorLengthTooSmall);
438        }
439
440        let header_slice = mem
441            .get_slice(chain_head.addr(), PKT_HEADER_SIZE)
442            .map_err(Error::InvalidMemoryAccess)?;
443
444        let header = mem
445            .read_obj(chain_head.addr())
446            .map_err(Error::InvalidMemoryAccess)?;
447
448        let mut pkt = Self {
449            header_slice,
450            header,
451            data_slice: None,
452        };
453
454        // If the `len` field of the header is zero, then the packet doesn't have a `data` element.
455        if pkt.is_empty() {
456            return Ok(pkt);
457        }
458
459        // Reject packets that exceed the maximum allowed value for payload.
460        if pkt.len() > max_data_size {
461            return Err(Error::InvalidHeaderLen(pkt.len()));
462        }
463
464        // Starting from Linux 6.2 the virtio-vsock driver can use a single descriptor for both
465        // header and data.
466        let data_slice =
467            if !chain_head.has_next() && chain_head.len() - PKT_HEADER_SIZE as u32 >= pkt.len() {
468                mem.get_slice(
469                    chain_head
470                        .addr()
471                        .checked_add(PKT_HEADER_SIZE as u64)
472                        .ok_or(Error::DescriptorLengthTooSmall)?,
473                    pkt.len() as usize,
474                )
475                .map_err(Error::InvalidMemoryAccess)?
476            } else {
477                if !chain_head.has_next() {
478                    return Err(Error::DescriptorChainTooShort);
479                }
480
481                let data_desc = desc_chain.next().ok_or(Error::DescriptorChainTooShort)?;
482
483                if data_desc.is_write_only() {
484                    return Err(Error::UnexpectedWriteOnlyDescriptor);
485                }
486
487                // The data buffer should be large enough to fit the size of the data, as described by
488                // the header descriptor.
489                if data_desc.len() < pkt.len() {
490                    return Err(Error::DescriptorLengthTooSmall);
491                }
492
493                mem.get_slice(data_desc.addr(), pkt.len() as usize)
494                    .map_err(Error::InvalidMemoryAccess)?
495            };
496
497        pkt.data_slice = Some(data_slice);
498        Ok(pkt)
499    }
500
501    /// Create the packet wrapper from an RX chain.
502    ///
503    /// There must be two descriptors in the chain, both writable: a header descriptor and a data
504    /// descriptor.
505    ///
506    /// # Arguments
507    /// * `mem` - the `GuestMemory` object that can be used to access the queue buffers.
508    /// * `desc_chain` - the descriptor chain corresponding to a packet.
509    /// * `max_data_size` - the maximum size allowed for the packet payload, that was negotiated
510    ///                     between the device and the driver. Tracking issue for defining this
511    ///                     feature in virtio-spec
512    ///                     [here](https://github.com/oasis-tcs/virtio-spec/issues/140).
513    ///
514    /// # Example
515    ///
516    /// ```rust
517    /// # use virtio_bindings::bindings::virtio_ring::VRING_DESC_F_WRITE;
518    /// # use virtio_queue::mock::MockSplitQueue;
519    /// # use virtio_queue::{Descriptor, Queue, QueueT};
520    /// use virtio_vsock::packet::{VsockPacket, PKT_HEADER_SIZE};
521    /// # use vm_memory::{Bytes, GuestAddress, GuestAddressSpace, GuestMemoryMmap};
522    ///
523    /// # const MAX_PKT_BUF_SIZE: u32 = 64 * 1024;
524    /// # const SRC_CID: u64 = 1;
525    /// # const DST_CID: u64 = 2;
526    /// # const SRC_PORT: u32 = 3;
527    /// # const DST_PORT: u32 = 4;
528    /// # const LEN: u32 = 16;
529    /// # const TYPE_STREAM: u16 = 1;
530    /// # const OP_RW: u16 = 5;
531    /// # const FLAGS: u32 = 7;
532    /// # const FLAG: u32 = 8;
533    /// # const BUF_ALLOC: u32 = 256;
534    /// # const FWD_CNT: u32 = 9;
535    ///
536    /// # fn create_queue_with_chain(m: &GuestMemoryMmap) -> Queue {
537    /// #     let vq = MockSplitQueue::new(m, 16);
538    /// #     let mut q = vq.create_queue().unwrap();
539    /// #
540    /// #     let v = vec![
541    /// #         Descriptor::new(0x5_0000, 0x100, VRING_DESC_F_WRITE as u16, 0),
542    /// #         Descriptor::new(0x8_0000, 0x100, VRING_DESC_F_WRITE as u16, 0),
543    /// #     ];
544    /// #     let mut chain = vq.build_desc_chain(&v);
545    /// #    q
546    /// # }
547    /// let mem = GuestMemoryMmap::<()>::from_ranges(&[(GuestAddress(0), 0x10000)]).unwrap();
548    /// // Create a queue and populate it with a descriptor chain.
549    /// let mut queue = create_queue_with_chain(&mem);
550    ///
551    /// while let Some(mut head) = queue.pop_descriptor_chain(&mem) {
552    ///     let used_len = match VsockPacket::from_rx_virtq_chain(&mem, &mut head, MAX_PKT_BUF_SIZE) {
553    ///         Ok(mut pkt) => {
554    ///             // Make sure the header is zeroed out first.
555    ///             pkt.header_slice()
556    ///                 .write(&[0u8; PKT_HEADER_SIZE], 0)
557    ///                 .unwrap();
558    ///             // Write data to the packet, using the setters.
559    ///             pkt.set_src_cid(SRC_CID)
560    ///                 .set_dst_cid(DST_CID)
561    ///                 .set_src_port(SRC_PORT)
562    ///                 .set_dst_port(DST_PORT)
563    ///                 .set_type(TYPE_STREAM)
564    ///                 .set_buf_alloc(BUF_ALLOC)
565    ///                 .set_fwd_cnt(FWD_CNT);
566    ///             // In this example, we are sending a RW packet.
567    ///             pkt.data_slice()
568    ///                 .unwrap()
569    ///                 .write_slice(&[1u8; LEN as usize], 0);
570    ///             pkt.set_op(OP_RW).set_len(LEN);
571    ///             pkt.header_slice().len() as u32 + LEN
572    ///         }
573    ///         Err(_e) => {
574    ///             // Do some error handling.
575    ///             0
576    ///         }
577    ///     };
578    ///     queue.add_used(&mem, head.head_index(), used_len);
579    /// }
580    /// ```
581    pub fn from_rx_virtq_chain<M, T>(
582        mem: &'a M,
583        desc_chain: &mut DescriptorChain<T>,
584        max_data_size: u32,
585    ) -> Result<Self>
586    where
587        M: GuestMemory,
588        <<M as GuestMemory>::R as GuestMemoryRegion>::B: WithBitmapSlice<'a, S = B>,
589        T: Deref,
590        T::Target: GuestMemory,
591    {
592        let chain_head = desc_chain.next().ok_or(Error::DescriptorChainTooShort)?;
593        // All RX buffers must be device-writable.
594        if !chain_head.is_write_only() {
595            return Err(Error::UnexpectedReadOnlyDescriptor);
596        }
597
598        // The packet header should fit inside the head descriptor.
599        if (chain_head.len() as usize) < PKT_HEADER_SIZE {
600            return Err(Error::DescriptorLengthTooSmall);
601        }
602
603        let header_slice = mem
604            .get_slice(chain_head.addr(), PKT_HEADER_SIZE)
605            .map_err(Error::InvalidMemoryAccess)?;
606
607        // Starting from Linux 6.2 the virtio-vsock driver can use a single descriptor for both
608        // header and data.
609        let data_slice = if !chain_head.has_next() && chain_head.len() as usize > PKT_HEADER_SIZE {
610            mem.get_slice(
611                chain_head
612                    .addr()
613                    .checked_add(PKT_HEADER_SIZE as u64)
614                    .ok_or(Error::DescriptorLengthTooSmall)?,
615                chain_head.len() as usize - PKT_HEADER_SIZE,
616            )
617            .map_err(Error::InvalidMemoryAccess)?
618        } else {
619            if !chain_head.has_next() {
620                return Err(Error::DescriptorChainTooShort);
621            }
622
623            let data_desc = desc_chain.next().ok_or(Error::DescriptorChainTooShort)?;
624
625            if !data_desc.is_write_only() {
626                return Err(Error::UnexpectedReadOnlyDescriptor);
627            }
628
629            if data_desc.len() > max_data_size {
630                return Err(Error::DescriptorLengthTooLong);
631            }
632
633            mem.get_slice(data_desc.addr(), data_desc.len() as usize)
634                .map_err(Error::InvalidMemoryAccess)?
635        };
636
637        Ok(Self {
638            header_slice,
639            header: Default::default(),
640            data_slice: Some(data_slice),
641        })
642    }
643}
644
645impl<'a> VsockPacket<'a, ()> {
646    /// Create a packet based on one pointer for the header, and an optional one for data.
647    ///
648    /// # Safety
649    ///
650    /// To use this safely, the caller must guarantee that the memory pointed to by the `hdr` and
651    /// `data` slices is available for the duration of the lifetime of the new `VolatileSlice`. The
652    /// caller must also guarantee that all other users of the given chunk of memory are using
653    /// volatile accesses.
654    ///
655    /// # Example
656    ///
657    /// ```rust
658    /// use virtio_vsock::packet::{VsockPacket, PKT_HEADER_SIZE};
659    ///
660    /// const LEN: usize = 16;
661    ///
662    /// let mut pkt_raw = [0u8; PKT_HEADER_SIZE + LEN];
663    /// let (hdr_raw, data_raw) = pkt_raw.split_at_mut(PKT_HEADER_SIZE);
664    /// // Safe because `hdr_raw` and `data_raw` live for as long as the scope of the current
665    /// // example.
666    /// let packet = unsafe { VsockPacket::new(hdr_raw, Some(data_raw)).unwrap() };
667    /// ```
668    pub unsafe fn new(header: &mut [u8], data: Option<&mut [u8]>) -> Result<VsockPacket<'a, ()>> {
669        if header.len() != PKT_HEADER_SIZE {
670            return Err(Error::InvalidHeaderInputSize(header.len()));
671        }
672        Ok(VsockPacket {
673            header_slice: VolatileSlice::new(header.as_mut_ptr(), PKT_HEADER_SIZE),
674            header: Default::default(),
675            data_slice: data.map(|data| VolatileSlice::new(data.as_mut_ptr(), data.len())),
676        })
677    }
678}
679
680#[cfg(test)]
681mod tests {
682    use super::*;
683
684    use vm_memory::{GuestAddress, GuestMemoryMmap};
685
686    use virtio_bindings::bindings::virtio_ring::VRING_DESC_F_WRITE;
687    use virtio_queue::mock::MockSplitQueue;
688    use virtio_queue::Descriptor;
689
690    impl PartialEq for Error {
691        fn eq(&self, other: &Self) -> bool {
692            use self::Error::*;
693            match (self, other) {
694                (DescriptorChainTooShort, DescriptorChainTooShort) => true,
695                (DescriptorLengthTooSmall, DescriptorLengthTooSmall) => true,
696                (DescriptorLengthTooLong, DescriptorLengthTooLong) => true,
697                (InvalidHeaderInputSize(size), InvalidHeaderInputSize(other_size)) => {
698                    size == other_size
699                }
700                (InvalidHeaderLen(size), InvalidHeaderLen(other_size)) => size == other_size,
701                (InvalidMemoryAccess(ref e), InvalidMemoryAccess(ref other_e)) => {
702                    format!("{}", e).eq(&format!("{}", other_e))
703                }
704                (InvalidVolatileAccess(ref e), InvalidVolatileAccess(ref other_e)) => {
705                    format!("{}", e).eq(&format!("{}", other_e))
706                }
707                (UnexpectedReadOnlyDescriptor, UnexpectedReadOnlyDescriptor) => true,
708                (UnexpectedWriteOnlyDescriptor, UnexpectedWriteOnlyDescriptor) => true,
709                _ => false,
710            }
711        }
712    }
713
714    // Random values to be used by the tests for the header fields.
715    const SRC_CID: u64 = 1;
716    const DST_CID: u64 = 2;
717    const SRC_PORT: u32 = 3;
718    const DST_PORT: u32 = 4;
719    const LEN: u32 = 16;
720    const TYPE: u16 = 5;
721    const OP: u16 = 6;
722    const FLAGS: u32 = 7;
723    const FLAG: u32 = 8;
724    const BUF_ALLOC: u32 = 256;
725    const FWD_CNT: u32 = 9;
726
727    const MAX_PKT_BUF_SIZE: u32 = 64 * 1024;
728
729    #[test]
    // Feeds `VsockPacket::from_rx_virtq_chain` a series of malformed descriptor chains, checking
    // the error reported for each one, and then two well-formed chains: header and data on
    // separate descriptors, and header and data sharing a single descriptor.
730    fn test_from_rx_virtq_chain() {
731        let mem: GuestMemoryMmap =
732            GuestMemoryMmap::from_ranges(&[(GuestAddress(0), 0x1000_0000)]).unwrap();
733
734        // The `build_desc_chain` function will populate the `NEXT` related flags and field.
735        let v = vec![
736            // A device-readable packet header descriptor should be invalid.
737            Descriptor::new(0x10_0000, 0x100, 0, 0),
738            Descriptor::new(0x20_0000, 0x100, VRING_DESC_F_WRITE as u16, 0),
739        ];
740        let queue = MockSplitQueue::new(&mem, 16);
741        let mut chain = queue.build_desc_chain(&v).unwrap();
742        assert_eq!(
743            VsockPacket::from_rx_virtq_chain(&mem, &mut chain, MAX_PKT_BUF_SIZE).unwrap_err(),
744            Error::UnexpectedReadOnlyDescriptor
745        );
746
747        let v = vec![
748            // A header length < PKT_HEADER_SIZE is invalid.
749            Descriptor::new(
750                0x10_0000,
751                PKT_HEADER_SIZE as u32 - 1,
752                VRING_DESC_F_WRITE as u16,
753                0,
754            ),
755            Descriptor::new(0x20_0000, 0x100, VRING_DESC_F_WRITE as u16, 0),
756        ];
757        let mut chain = queue.build_desc_chain(&v).unwrap();
758        assert_eq!(
759            VsockPacket::from_rx_virtq_chain(&mem, &mut chain, MAX_PKT_BUF_SIZE).unwrap_err(),
760            Error::DescriptorLengthTooSmall
761        );
762
763        let v = vec![
764            Descriptor::new(
765                0x10_0000,
766                PKT_HEADER_SIZE as u32,
767                VRING_DESC_F_WRITE as u16,
768                0,
769            ),
            // A data length > MAX_PKT_BUF_SIZE is invalid.
770            Descriptor::new(
771                0x20_0000,
772                MAX_PKT_BUF_SIZE + 1,
773                VRING_DESC_F_WRITE as u16,
774                0,
775            ),
776        ];
777        let mut chain = queue.build_desc_chain(&v).unwrap();
778        assert_eq!(
779            VsockPacket::from_rx_virtq_chain(&mem, &mut chain, MAX_PKT_BUF_SIZE).unwrap_err(),
780            Error::DescriptorLengthTooLong
781        );
782
783        let v = vec![
784            // The data descriptor should always be present on the RX path.
785            Descriptor::new(
786                0x10_0000,
787                PKT_HEADER_SIZE as u32,
788                VRING_DESC_F_WRITE as u16,
789                0,
790            ),
791        ];
792        let mut chain = queue.build_desc_chain(&v).unwrap();
793        assert_eq!(
794            VsockPacket::from_rx_virtq_chain(&mem, &mut chain, MAX_PKT_BUF_SIZE).unwrap_err(),
795            Error::DescriptorChainTooShort
796        );
797
        // Same chain as in the first case: a device-readable header descriptor is rejected.
798        let v = vec![
799            Descriptor::new(0x10_0000, 0x100, 0, 0),
800            Descriptor::new(0x20_0000, 0x100, VRING_DESC_F_WRITE as u16, 0),
801        ];
802        let mut chain = queue.build_desc_chain(&v).unwrap();
803        assert_eq!(
804            VsockPacket::from_rx_virtq_chain(&mem, &mut chain, MAX_PKT_BUF_SIZE).unwrap_err(),
805            Error::UnexpectedReadOnlyDescriptor
806        );
807
        // Switch to a guest memory of only 0x10_0004 bytes, so that the buffers used by the
        // following cases can end up partially or completely outside of it.
808        let mem: GuestMemoryMmap =
809            GuestMemoryMmap::from_ranges(&[(GuestAddress(0), 0x10_0004)]).unwrap();
810
811        let v = vec![
812            // The header doesn't fit entirely in the memory bounds.
813            Descriptor::new(0x10_0000, 0x100, VRING_DESC_F_WRITE as u16, 0),
814            Descriptor::new(0x20_0000, 0x100, VRING_DESC_F_WRITE as u16, 0),
815        ];
        // Rebuild the mock queue on top of the new guest memory.
816        let queue = MockSplitQueue::new(&mem, 16);
817        let mut chain = queue.build_desc_chain(&v).unwrap();
818        assert_eq!(
819            VsockPacket::from_rx_virtq_chain(&mem, &mut chain, MAX_PKT_BUF_SIZE).unwrap_err(),
820            Error::InvalidMemoryAccess(GuestMemoryError::InvalidBackendAddress)
821        );
822
823        let v = vec![
824            // The header is outside the memory bounds.
825            Descriptor::new(0x20_0000, 0x100, VRING_DESC_F_WRITE as u16, 0),
826            Descriptor::new(0x30_0000, 0x100, VRING_DESC_F_WRITE as u16, 0),
827        ];
828        let mut chain = queue.build_desc_chain(&v).unwrap();
829        assert_eq!(
830            VsockPacket::from_rx_virtq_chain(&mem, &mut chain, MAX_PKT_BUF_SIZE).unwrap_err(),
831            Error::InvalidMemoryAccess(GuestMemoryError::InvalidGuestAddress(GuestAddress(
832                0x20_0000
833            )))
834        );
835
836        let v = vec![
837            Descriptor::new(0x5_0000, 0x100, VRING_DESC_F_WRITE as u16, 0),
838            // A device-readable packet data descriptor should be invalid.
839            Descriptor::new(0x8_0000, 0x100, 0, 0),
840        ];
841        let mut chain = queue.build_desc_chain(&v).unwrap();
842        assert_eq!(
843            VsockPacket::from_rx_virtq_chain(&mem, &mut chain, MAX_PKT_BUF_SIZE).unwrap_err(),
844            Error::UnexpectedReadOnlyDescriptor
845        );
846        let v = vec![
847            Descriptor::new(0x5_0000, 0x100, VRING_DESC_F_WRITE as u16, 0),
848            // The data array doesn't fit entirely in the memory bounds.
849            Descriptor::new(0x10_0000, 0x100, VRING_DESC_F_WRITE as u16, 0),
850        ];
851        let mut chain = queue.build_desc_chain(&v).unwrap();
852        assert_eq!(
853            VsockPacket::from_rx_virtq_chain(&mem, &mut chain, MAX_PKT_BUF_SIZE).unwrap_err(),
854            Error::InvalidMemoryAccess(GuestMemoryError::InvalidBackendAddress)
855        );
856
857        let v = vec![
858            Descriptor::new(0x5_0000, 0x100, VRING_DESC_F_WRITE as u16, 0),
859            // The data array is outside the memory bounds.
860            Descriptor::new(0x20_0000, 0x100, VRING_DESC_F_WRITE as u16, 0),
861        ];
862        let mut chain = queue.build_desc_chain(&v).unwrap();
863        assert_eq!(
864            VsockPacket::from_rx_virtq_chain(&mem, &mut chain, MAX_PKT_BUF_SIZE).unwrap_err(),
865            Error::InvalidMemoryAccess(GuestMemoryError::InvalidGuestAddress(GuestAddress(
866                0x20_0000
867            )))
868        );
869
870        // Let's also test a valid descriptor chain.
871        let v = vec![
872            Descriptor::new(0x5_0000, 0x100, VRING_DESC_F_WRITE as u16, 0),
873            Descriptor::new(0x8_0000, 0x100, VRING_DESC_F_WRITE as u16, 0),
874        ];
875        let mut chain = queue.build_desc_chain(&v).unwrap();
876
877        let packet = VsockPacket::from_rx_virtq_chain(&mem, &mut chain, MAX_PKT_BUF_SIZE).unwrap();
        // The packet starts out with a default header, and its header slice maps the first
        // `PKT_HEADER_SIZE` bytes of the first descriptor.
878        assert_eq!(packet.header, PacketHeader::default());
879        let header = packet.header_slice();
880        assert_eq!(
881            header.ptr_guard().as_ptr(),
882            mem.get_host_address(GuestAddress(0x5_0000)).unwrap()
883        );
884        assert_eq!(header.len(), PKT_HEADER_SIZE);
885
        // The data slice covers the whole second descriptor.
886        let data = packet.data_slice().unwrap();
887        assert_eq!(
888            data.ptr_guard().as_ptr(),
889            mem.get_host_address(GuestAddress(0x8_0000)).unwrap()
890        );
891        assert_eq!(data.len(), 0x100);
892
893        // If we try to get a vsock packet again, it fails because we already consumed all the
894        // descriptors from the chain.
895        assert_eq!(
896            VsockPacket::from_rx_virtq_chain(&mem, &mut chain, MAX_PKT_BUF_SIZE).unwrap_err(),
897            Error::DescriptorChainTooShort
898        );
899
900        // Let's also test a valid descriptor chain, with both header and data on a single
901        // descriptor.
902        let v = vec![Descriptor::new(
903            0x5_0000,
904            PKT_HEADER_SIZE as u32 + 0x100,
905            VRING_DESC_F_WRITE as u16,
906            0,
907        )];
908        let mut chain = queue.build_desc_chain(&v).unwrap();
909
910        let packet = VsockPacket::from_rx_virtq_chain(&mem, &mut chain, MAX_PKT_BUF_SIZE).unwrap();
911        assert_eq!(packet.header, PacketHeader::default());
912        let header = packet.header_slice();
913        assert_eq!(
914            header.ptr_guard().as_ptr(),
915            mem.get_host_address(GuestAddress(0x5_0000)).unwrap()
916        );
917        assert_eq!(header.len(), PKT_HEADER_SIZE);
918
        // The data slice starts right after the header, within the same descriptor, and covers
        // the remaining 0x100 bytes.
919        let data = packet.data_slice().unwrap();
920        assert_eq!(
921            data.ptr_guard().as_ptr(),
922            mem.get_host_address(GuestAddress(0x5_0000 + PKT_HEADER_SIZE as u64))
923                .unwrap()
924        );
925        assert_eq!(data.len(), 0x100);
926    }
927
928    #[test]
    // Feeds `VsockPacket::from_tx_virtq_chain` malformed descriptor chains and headers, checking
    // the error reported for each one, as well as well-formed chains: a header-only packet,
    // header and data on separate descriptors, and header and data sharing a single descriptor.
    // The headers are written to guest memory before the chain is parsed; the order of the cases
    // matters because later cases reuse the header left in memory by earlier ones.
929    fn test_from_tx_virtq_chain() {
930        let mem: GuestMemoryMmap =
931            GuestMemoryMmap::from_ranges(&[(GuestAddress(0), 0x1000_0000)]).unwrap();
932
933        // The `build_desc_chain` function will populate the `NEXT` related flags and field.
934        let v = vec![
935            // A device-writable packet header descriptor should be invalid.
936            Descriptor::new(0x10_0000, 0x100, VRING_DESC_F_WRITE as u16, 0),
937            Descriptor::new(0x20_0000, 0x100, 0, 0),
938        ];
939        let queue = MockSplitQueue::new(&mem, 16);
940        let mut chain = queue.build_desc_chain(&v).unwrap();
941        assert_eq!(
942            VsockPacket::from_tx_virtq_chain(&mem, &mut chain, MAX_PKT_BUF_SIZE).unwrap_err(),
943            Error::UnexpectedWriteOnlyDescriptor
944        );
945
946        let v = vec![
947            // A header length < PKT_HEADER_SIZE is invalid.
948            Descriptor::new(0x10_0000, PKT_HEADER_SIZE as u32 - 1, 0, 0),
949            Descriptor::new(0x20_0000, 0x100, 0, 0),
950        ];
951        let mut chain = queue.build_desc_chain(&v).unwrap();
952        assert_eq!(
953            VsockPacket::from_tx_virtq_chain(&mem, &mut chain, MAX_PKT_BUF_SIZE).unwrap_err(),
954            Error::DescriptorLengthTooSmall
955        );
956
957        // On the TX path, it is allowed to not have a data descriptor.
958        let v = vec![Descriptor::new(0x10_0000, PKT_HEADER_SIZE as u32, 0, 0)];
959        let mut chain = queue.build_desc_chain(&v).unwrap();
960
        // Place a header with `len` set to 0 at the address of the header descriptor.
961        let header = PacketHeader {
962            src_cid: SRC_CID.into(),
963            dst_cid: DST_CID.into(),
964            src_port: SRC_PORT.into(),
965            dst_port: DST_PORT.into(),
966            len: 0.into(),
967            type_: 0.into(),
968            op: 0.into(),
969            flags: 0.into(),
970            buf_alloc: 0.into(),
971            fwd_cnt: 0.into(),
972        };
973        mem.write_obj(header, GuestAddress(0x10_0000)).unwrap();
974
975        let packet = VsockPacket::from_tx_virtq_chain(&mem, &mut chain, MAX_PKT_BUF_SIZE).unwrap();
        // The parsed header matches what was written to guest memory.
976        assert_eq!(packet.header, header);
977        let header_slice = packet.header_slice();
978        assert_eq!(
979            header_slice.ptr_guard().as_ptr(),
980            mem.get_host_address(GuestAddress(0x10_0000)).unwrap()
981        );
982        assert_eq!(header_slice.len(), PKT_HEADER_SIZE);
        // A header-only packet has no data slice.
983        assert!(packet.data_slice().is_none());
984
        // Switch to a guest memory of only 0x10_0004 bytes, so that the buffers used by the
        // following cases can end up partially or completely outside of it.
985        let mem: GuestMemoryMmap =
986            GuestMemoryMmap::from_ranges(&[(GuestAddress(0), 0x10_0004)]).unwrap();
987
988        let v = vec![
989            // The header doesn't fit entirely in the memory bounds.
990            Descriptor::new(0x10_0000, 0x100, 0, 0),
991            Descriptor::new(0x20_0000, 0x100, 0, 0),
992        ];
        // Rebuild the mock queue on top of the new guest memory.
993        let queue = MockSplitQueue::new(&mem, 16);
994        let mut chain = queue.build_desc_chain(&v).unwrap();
995        assert_eq!(
996            VsockPacket::from_tx_virtq_chain(&mem, &mut chain, MAX_PKT_BUF_SIZE).unwrap_err(),
997            Error::InvalidMemoryAccess(GuestMemoryError::InvalidBackendAddress)
998        );
999
1000        let v = vec![
1001            // The header is outside the memory bounds.
1002            Descriptor::new(0x20_0000, 0x100, 0, 0),
1003            Descriptor::new(0x30_0000, 0x100, 0, 0),
1004        ];
1005        let mut chain = queue.build_desc_chain(&v).unwrap();
1006        assert_eq!(
1007            VsockPacket::from_tx_virtq_chain(&mem, &mut chain, MAX_PKT_BUF_SIZE).unwrap_err(),
1008            Error::InvalidMemoryAccess(GuestMemoryError::InvalidGuestAddress(GuestAddress(
1009                0x20_0000
1010            )))
1011        );
1012
1013        // Write some non-zero value to the `len` field of the header, which means there is also
1014        // a data descriptor in the chain, first with a value that exceeds the maximum allowed one.
1015        let header = PacketHeader {
1016            src_cid: SRC_CID.into(),
1017            dst_cid: DST_CID.into(),
1018            src_port: SRC_PORT.into(),
1019            dst_port: DST_PORT.into(),
1020            len: (MAX_PKT_BUF_SIZE + 1).into(),
1021            type_: 0.into(),
1022            op: 0.into(),
1023            flags: 0.into(),
1024            buf_alloc: 0.into(),
1025            fwd_cnt: 0.into(),
1026        };
1027        mem.write_obj(header, GuestAddress(0x5_0000)).unwrap();
1028        let v = vec![
1029            Descriptor::new(0x5_0000, 0x100, 0, 0),
1030            Descriptor::new(0x8_0000, 0x100, 0, 0),
1031        ];
1032        let mut chain = queue.build_desc_chain(&v).unwrap();
1033        assert_eq!(
1034            VsockPacket::from_tx_virtq_chain(&mem, &mut chain, MAX_PKT_BUF_SIZE).unwrap_err(),
1035            Error::InvalidHeaderLen(MAX_PKT_BUF_SIZE + 1)
1036        );
1037
1038        // Write some non-zero, valid value to the `len` field of the header.
        // This header stays at 0x5_0000 and is the one parsed by all the remaining cases.
1039        let header = PacketHeader {
1040            src_cid: SRC_CID.into(),
1041            dst_cid: DST_CID.into(),
1042            src_port: SRC_PORT.into(),
1043            dst_port: DST_PORT.into(),
1044            len: LEN.into(),
1045            type_: 0.into(),
1046            op: 0.into(),
1047            flags: 0.into(),
1048            buf_alloc: 0.into(),
1049            fwd_cnt: 0.into(),
1050        };
1051        mem.write_obj(header, GuestAddress(0x5_0000)).unwrap();
1052        let v = vec![
1053            // The data descriptor is missing.
1054            Descriptor::new(0x5_0000, PKT_HEADER_SIZE as u32, 0, 0),
1055        ];
1056        let mut chain = queue.build_desc_chain(&v).unwrap();
1057        assert_eq!(
1058            VsockPacket::from_tx_virtq_chain(&mem, &mut chain, MAX_PKT_BUF_SIZE).unwrap_err(),
1059            Error::DescriptorChainTooShort
1060        );
1061
1062        let v = vec![
1063            Descriptor::new(0x5_0000, 0x100, 0, 0),
1064            // The data array doesn't fit entirely in the memory bounds.
1065            Descriptor::new(0x10_0000, 0x100, 0, 0),
1066        ];
1067        let mut chain = queue.build_desc_chain(&v).unwrap();
1068        assert_eq!(
1069            VsockPacket::from_tx_virtq_chain(&mem, &mut chain, MAX_PKT_BUF_SIZE).unwrap_err(),
1070            Error::InvalidMemoryAccess(GuestMemoryError::InvalidBackendAddress)
1071        );
1072
1073        let v = vec![
1074            Descriptor::new(0x5_0000, 0x100, 0, 0),
1075            // The data array is outside the memory bounds.
1076            Descriptor::new(0x20_0000, 0x100, 0, 0),
1077        ];
1078        let mut chain = queue.build_desc_chain(&v).unwrap();
1079        assert_eq!(
1080            VsockPacket::from_tx_virtq_chain(&mem, &mut chain, MAX_PKT_BUF_SIZE).unwrap_err(),
1081            Error::InvalidMemoryAccess(GuestMemoryError::InvalidGuestAddress(GuestAddress(
1082                0x20_0000
1083            )))
1084        );
1085
1086        let v = vec![
1087            Descriptor::new(0x5_0000, 0x100, 0, 0),
1088            // A device-writable packet data descriptor should be invalid.
1089            Descriptor::new(0x8_0000, 0x100, VRING_DESC_F_WRITE as u16, 0),
1090        ];
1091        let mut chain = queue.build_desc_chain(&v).unwrap();
1092        assert_eq!(
1093            VsockPacket::from_tx_virtq_chain(&mem, &mut chain, MAX_PKT_BUF_SIZE).unwrap_err(),
1094            Error::UnexpectedWriteOnlyDescriptor
1095        );
1096
1097        let v = vec![
1098            Descriptor::new(0x5_0000, 0x100, 0, 0),
1099            // A data length < the length of data as described by the header.
1100            Descriptor::new(0x8_0000, LEN - 1, 0, 0),
1101        ];
1102        let mut chain = queue.build_desc_chain(&v).unwrap();
1103        assert_eq!(
1104            VsockPacket::from_tx_virtq_chain(&mem, &mut chain, MAX_PKT_BUF_SIZE).unwrap_err(),
1105            Error::DescriptorLengthTooSmall
1106        );
1107
1108        // Let's also test a valid descriptor chain, with both header and data.
1109        let v = vec![
1110            Descriptor::new(0x5_0000, 0x100, 0, 0),
1111            Descriptor::new(0x8_0000, 0x100, 0, 0),
1112        ];
1113        let mut chain = queue.build_desc_chain(&v).unwrap();
1114
1115        let packet = VsockPacket::from_tx_virtq_chain(&mem, &mut chain, MAX_PKT_BUF_SIZE).unwrap();
1116        assert_eq!(packet.header, header);
1117        let header_slice = packet.header_slice();
1118        assert_eq!(
1119            header_slice.ptr_guard().as_ptr(),
1120            mem.get_host_address(GuestAddress(0x5_0000)).unwrap()
1121        );
1122        assert_eq!(header_slice.len(), PKT_HEADER_SIZE);
1123        // The `len` field of the header was set to 16.
1124        assert_eq!(packet.len(), LEN);
1125
        // The data slice starts at the second descriptor and is sized by the header's `len`
        // field, not by the (larger) descriptor length.
1126        let data = packet.data_slice().unwrap();
1127        assert_eq!(
1128            data.ptr_guard().as_ptr(),
1129            mem.get_host_address(GuestAddress(0x8_0000)).unwrap()
1130        );
1131        assert_eq!(data.len(), LEN as usize);
1132
1133        // If we try to get a vsock packet again, it fails because we already consumed all the
1134        // descriptors from the chain.
1135        assert_eq!(
1136            VsockPacket::from_tx_virtq_chain(&mem, &mut chain, MAX_PKT_BUF_SIZE).unwrap_err(),
1137            Error::DescriptorChainTooShort
1138        );
1139
1140        // Let's also test a valid descriptor chain, with both header and data on a single
1141        // descriptor.
1142        let v = vec![Descriptor::new(
1143            0x5_0000,
1144            PKT_HEADER_SIZE as u32 + 0x100,
1145            0,
1146            0,
1147        )];
1148        let mut chain = queue.build_desc_chain(&v).unwrap();
1149
1150        let packet = VsockPacket::from_tx_virtq_chain(&mem, &mut chain, MAX_PKT_BUF_SIZE).unwrap();
1151        assert_eq!(packet.header, header);
1152        let header_slice = packet.header_slice();
1153        assert_eq!(
1154            header_slice.ptr_guard().as_ptr(),
1155            mem.get_host_address(GuestAddress(0x5_0000)).unwrap()
1156        );
1157        assert_eq!(header_slice.len(), PKT_HEADER_SIZE);
1158        // The `len` field of the header was set to 16.
1159        assert_eq!(packet.len(), LEN);
1160
        // The data slice starts right after the header, within the same descriptor, and is
        // again sized by the header's `len` field.
1161        let data = packet.data_slice().unwrap();
1162        assert_eq!(
1163            data.ptr_guard().as_ptr(),
1164            mem.get_host_address(GuestAddress(0x5_0000 + PKT_HEADER_SIZE as u64))
1165                .unwrap()
1166        );
1167        assert_eq!(data.len(), LEN as usize);
1168    }
1169
1170    #[test]
    // Sets every header field of an RX packet through the chained setters, then checks the
    // values in three ways: through the getters, through the `packet.header` struct, and through
    // raw reads of the header slice at each field's offset.
1171    fn test_header_set_get() {
1172        let mem: GuestMemoryMmap =
1173            GuestMemoryMmap::from_ranges(&[(GuestAddress(0), 0x30_0000)]).unwrap();
1174        // The `build_desc_chain` function will populate the `NEXT` related flags and field.
1175        let v = vec![
1176            Descriptor::new(0x10_0000, 0x100, VRING_DESC_F_WRITE as u16, 0),
1177            Descriptor::new(0x20_0000, 0x100, VRING_DESC_F_WRITE as u16, 0),
1178        ];
1179        let queue = MockSplitQueue::new(&mem, 16);
1180        let mut chain = queue.build_desc_chain(&v).unwrap();
1181
1182        let mut packet =
1183            VsockPacket::from_rx_virtq_chain(&mem, &mut chain, MAX_PKT_BUF_SIZE).unwrap();
        // `set_flag(FLAG)` is called after `set_flags(FLAGS)`; the assertions below expect the
        // resulting flags to be `FLAGS | FLAG`.
1184        packet
1185            .set_src_cid(SRC_CID)
1186            .set_dst_cid(DST_CID)
1187            .set_src_port(SRC_PORT)
1188            .set_dst_port(DST_PORT)
1189            .set_len(LEN)
1190            .set_type(TYPE)
1191            .set_op(OP)
1192            .set_flags(FLAGS)
1193            .set_flag(FLAG)
1194            .set_buf_alloc(BUF_ALLOC)
1195            .set_fwd_cnt(FWD_CNT);
1196
        // The getters return what was set.
1197        assert_eq!(packet.flags(), FLAGS | FLAG);
1198        assert_eq!(packet.op(), OP);
1199        assert_eq!(packet.type_(), TYPE);
1200        assert_eq!(packet.dst_cid(), DST_CID);
1201        assert_eq!(packet.dst_port(), DST_PORT);
1202        assert_eq!(packet.src_cid(), SRC_CID);
1203        assert_eq!(packet.src_port(), SRC_PORT);
1204        assert_eq!(packet.fwd_cnt(), FWD_CNT);
1205        assert_eq!(packet.len(), LEN);
1206        assert_eq!(packet.buf_alloc(), BUF_ALLOC);
1207
1208        let expected_header = PacketHeader {
1209            src_cid: SRC_CID.into(),
1210            dst_cid: DST_CID.into(),
1211            src_port: SRC_PORT.into(),
1212            dst_port: DST_PORT.into(),
1213            len: LEN.into(),
1214            type_: TYPE.into(),
1215            op: OP.into(),
1216            flags: (FLAGS | FLAG).into(),
1217            buf_alloc: BUF_ALLOC.into(),
1218            fwd_cnt: FWD_CNT.into(),
1219        };
1220
        // The header struct kept by the packet matches as well.
1221        assert_eq!(packet.header, expected_header);
        // Finally, check the bytes behind the header slice: each field is read at its offset
        // and decoded as a little-endian integer.
1222        assert_eq!(
1223            u64::from_le(
1224                packet
1225                    .header_slice()
1226                    .read_obj::<u64>(SRC_CID_OFFSET)
1227                    .unwrap()
1228            ),
1229            SRC_CID
1230        );
1231        assert_eq!(
1232            u64::from_le(
1233                packet
1234                    .header_slice()
1235                    .read_obj::<u64>(DST_CID_OFFSET)
1236                    .unwrap()
1237            ),
1238            DST_CID
1239        );
1240        assert_eq!(
1241            u32::from_le(
1242                packet
1243                    .header_slice()
1244                    .read_obj::<u32>(SRC_PORT_OFFSET)
1245                    .unwrap()
1246            ),
1247            SRC_PORT
1248        );
1249        assert_eq!(
1250            u32::from_le(
1251                packet
1252                    .header_slice()
1253                    .read_obj::<u32>(DST_PORT_OFFSET)
1254                    .unwrap()
1255            ),
1256            DST_PORT,
1257        );
1258        assert_eq!(
1259            u32::from_le(packet.header_slice().read_obj::<u32>(LEN_OFFSET).unwrap()),
1260            LEN
1261        );
1262        assert_eq!(
1263            u16::from_le(packet.header_slice().read_obj::<u16>(TYPE_OFFSET).unwrap()),
1264            TYPE
1265        );
1266        assert_eq!(
1267            u16::from_le(packet.header_slice().read_obj::<u16>(OP_OFFSET).unwrap()),
1268            OP
1269        );
1270        assert_eq!(
1271            u32::from_le(packet.header_slice().read_obj::<u32>(FLAGS_OFFSET).unwrap()),
1272            FLAGS | FLAG
1273        );
1274        assert_eq!(
1275            u32::from_le(
1276                packet
1277                    .header_slice()
1278                    .read_obj::<u32>(BUF_ALLOC_OFFSET)
1279                    .unwrap()
1280            ),
1281            BUF_ALLOC
1282        );
1283        assert_eq!(
1284            u32::from_le(
1285                packet
1286                    .header_slice()
1287                    .read_obj::<u32>(FWD_CNT_OFFSET)
1288                    .unwrap()
1289            ),
1290            FWD_CNT
1291        );
1292    }
1293
1294    #[test]
    // Checks that `set_header_from_raw` updates both the `packet.header` struct and the bytes
    // behind the header slice when given the raw bytes of a `PacketHeader`, and that it rejects
    // an input that is one byte shorter than `PKT_HEADER_SIZE`.
1295    fn test_set_header_from_raw() {
1296        let mem: GuestMemoryMmap =
1297            GuestMemoryMmap::from_ranges(&[(GuestAddress(0), 0x30_0000)]).unwrap();
1298        // The `build_desc_chain` function will populate the `NEXT` related flags and field.
1299        let v = vec![
1300            Descriptor::new(0x10_0000, 0x100, VRING_DESC_F_WRITE as u16, 0),
1301            Descriptor::new(0x20_0000, 0x100, VRING_DESC_F_WRITE as u16, 0),
1302        ];
1303        let queue = MockSplitQueue::new(&mem, 16);
1304        let mut chain = queue.build_desc_chain(&v).unwrap();
1305
1306        let mut packet =
1307            VsockPacket::from_rx_virtq_chain(&mem, &mut chain, MAX_PKT_BUF_SIZE).unwrap();
1308
1309        let header = PacketHeader {
1310            src_cid: SRC_CID.into(),
1311            dst_cid: DST_CID.into(),
1312            src_port: SRC_PORT.into(),
1313            dst_port: DST_PORT.into(),
1314            len: LEN.into(),
1315            type_: TYPE.into(),
1316            op: OP.into(),
1317            flags: (FLAGS | FLAG).into(),
1318            buf_alloc: BUF_ALLOC.into(),
1319            fwd_cnt: FWD_CNT.into(),
1320        };
1321
1322        // SAFETY: created from an existing packet header.
        // The slice is a byte view over `header`, `size_of::<PacketHeader>()` bytes long.
1323        let slice = unsafe {
1324            std::slice::from_raw_parts(
1325                (&header as *const PacketHeader) as *const u8,
1326                std::mem::size_of::<PacketHeader>(),
1327            )
1328        };
        // The packet starts out with a default header; after the call, both the header struct
        // and the object read back from the header slice equal `header`.
1329        assert_eq!(packet.header, PacketHeader::default());
1330        packet.set_header_from_raw(slice).unwrap();
1331        assert_eq!(packet.header, header);
1332        let header_from_slice: PacketHeader = packet.header_slice().read_obj(0).unwrap();
1333        assert_eq!(header_from_slice, header);
1334
        // An input of `PKT_HEADER_SIZE - 1` bytes is rejected, and the error carries its size.
1335        let invalid_slice = [0; PKT_HEADER_SIZE - 1];
1336        assert_eq!(
1337            packet.set_header_from_raw(&invalid_slice).unwrap_err(),
1338            Error::InvalidHeaderInputSize(PKT_HEADER_SIZE - 1)
1339        );
1340    }
1341
1342    #[test]
    // Checks `VsockPacket::new` on plain byte buffers (no descriptor chain involved): with a
    // header and a data buffer, with a header buffer only, and with a header buffer that is
    // one byte shorter than `PKT_HEADER_SIZE`.
1343    fn test_packet_new() {
        // One zeroed buffer, split into a `PKT_HEADER_SIZE` bytes header part and a `LEN` bytes
        // data part.
1344        let mut pkt_raw = [0u8; PKT_HEADER_SIZE + LEN as usize];
1345        let (hdr_raw, data_raw) = pkt_raw.split_at_mut(PKT_HEADER_SIZE);
1346        // SAFETY: safe because `hdr_raw` and `data_raw` live for as long as
1347        // the scope of the current test.
1348        let packet = unsafe { VsockPacket::new(hdr_raw, Some(data_raw)).unwrap() };
        // The packet's slices point at the buffers that were passed in, and the header is the
        // default one.
1349        assert_eq!(
1350            packet.header_slice.ptr_guard().as_ptr(),
1351            hdr_raw.as_mut_ptr(),
1352        );
1353        assert_eq!(packet.header_slice.len(), PKT_HEADER_SIZE);
1354        assert_eq!(packet.header, PacketHeader::default());
1355        assert_eq!(
1356            packet.data_slice.unwrap().ptr_guard().as_ptr(),
1357            data_raw.as_mut_ptr(),
1358        );
1359        assert_eq!(packet.data_slice.unwrap().len(), LEN as usize);
1360
1361        // SAFETY: Safe because `hdr_raw` and `data_raw` live as long as the
1362        // scope of the current test.
1363        let packet = unsafe { VsockPacket::new(hdr_raw, None).unwrap() };
        // Without a data buffer, the packet has no data slice.
1364        assert_eq!(
1365            packet.header_slice.ptr_guard().as_ptr(),
1366            hdr_raw.as_mut_ptr(),
1367        );
1368        assert_eq!(packet.header, PacketHeader::default());
1369        assert!(packet.data_slice.is_none());
1370
        // A header buffer of `PKT_HEADER_SIZE - 1` bytes is rejected, and the error carries its
        // size.
1371        let mut hdr_raw = [0u8; PKT_HEADER_SIZE - 1];
1372        assert_eq!(
1373            // SAFETY: Safe because `hdr_raw` lives for as long as the scope of the current test.
1374            unsafe { VsockPacket::new(&mut hdr_raw, None).unwrap_err() },
1375            Error::InvalidHeaderInputSize(PKT_HEADER_SIZE - 1)
1376        );
1377    }
1378
1379    #[test]
1380    #[should_panic]
    // Expects a panic when a header field is set through `set_header_field!` with an offset
    // that does not belong to any header field.
    // NOTE(review): `#[should_panic]` has no `expected` message, so any panic in this test
    // (including a failing `unwrap` during the setup) makes it pass.
1381    fn test_set_header_field_with_invalid_offset() {
        // NOTE(review): presumably past the end of the header (the virtio vsock header is 44
        // bytes long per the virtio specification) — confirm against `PKT_HEADER_SIZE`.
1382        const INVALID_OFFSET: usize = 50;
1383
        // Test-only setter, added to `VsockPacket` just for this test.
1384        impl<'a, B: BitmapSlice> VsockPacket<'a, B> {
1385            /// Set the `src_cid` of the header, but use an invalid offset for that.
1386            pub fn set_src_cid_invalid(&mut self, cid: u64) -> &mut Self {
1387                set_header_field!(self, src_cid, INVALID_OFFSET, cid);
1388                self
1389            }
1390        }
1391
1392        let mem: GuestMemoryMmap =
1393            GuestMemoryMmap::from_ranges(&[(GuestAddress(0), 0x30_0000)]).unwrap();
1394        // The `build_desc_chain` function will populate the `NEXT` related flags and field.
1395        let v = vec![
1396            Descriptor::new(0x10_0000, 0x100, VRING_DESC_F_WRITE as u16, 0),
1397            Descriptor::new(0x20_0000, 0x100, VRING_DESC_F_WRITE as u16, 0),
1398        ];
1399        let queue = MockSplitQueue::new(&mem, 16);
1400        let mut chain = queue.build_desc_chain(&v).unwrap();
1401
1402        let mut packet =
1403            VsockPacket::from_rx_virtq_chain(&mem, &mut chain, MAX_PKT_BUF_SIZE).unwrap();
        // This is the call that is expected to panic.
1404        packet.set_src_cid_invalid(SRC_CID);
1405    }
1406}