virtio_vsock/
packet.rs

1// Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2//
3// SPDX-License-Identifier: Apache-2.0 OR BSD-3-Clause
4
5//! Vsock packet abstraction.
6//!
7//! This module provides the following abstraction for parsing a vsock packet, and working with it:
8//!
9//! - [`VsockPacket`](struct.VsockPacket.html) which handles the parsing of the vsock packet from
10//! either a TX descriptor chain via
11//! [`VsockPacket::from_tx_virtq_chain`](struct.VsockPacket.html#method.from_tx_virtq_chain), or an
12//! RX descriptor chain via
13//! [`VsockPacket::from_rx_virtq_chain`](struct.VsockPacket.html#method.from_rx_virtq_chain).
14//! The virtio vsock packet is defined in the standard as having a header of type `virtio_vsock_hdr`
15//! and an optional `data` array of bytes. The methods mentioned above assume that both packet
16//! elements are on the same descriptor, or each of the packet elements occupies exactly one
17//! descriptor. For the usual drivers, this assumption stands,
18//! but in the future we might make the implementation more generic by removing any constraint
19//! regarding the number of descriptors that correspond to the header/data. The buffers associated
20//! to the TX virtio queue are device-readable, and the ones associated to the RX virtio queue are
21//! device-writable.
//!
//! The `VsockPacket` abstraction uses vm-memory's `VolatileSlice` to represent the header and the
//! data. `VolatileSlice` is a safe wrapper over a raw pointer, which also handles the dirty page
//! tracking behind the scenes. A limitation of the current implementation is that it does not
//! cover the scenario where the header or data buffer doesn't fit in a single `VolatileSlice`
//! because the guest memory regions of the buffer are contiguous in the guest physical address
//! space, but not in the host virtual address space. If this becomes a use case, we can extend
//! this solution to use an array of `VolatileSlice`s for the header and data.
//!
//! The `VsockPacket` abstraction also stores a copy of the `virtio_vsock_hdr` (defined here as
//! `PacketHeader`). This guarantees that we always access the same header data that was read from
//! the descriptor chain (or written through the setters), avoiding potential time-of-check to
//! time-of-use (TOCTOU) problems that could occur if a header field were read again later from the
//! underlying memory itself (i.e. from the header's `VolatileSlice` object).
35use std::fmt::{self, Display};
36use std::ops::Deref;
37
38use virtio_queue::DescriptorChain;
39use vm_memory::bitmap::{BitmapSlice, WithBitmapSlice};
40use vm_memory::{
41    Address, ByteValued, Bytes, GuestMemory, GuestMemoryError, GuestMemoryRegion, Le16, Le32, Le64,
42    VolatileMemoryError, VolatileSlice,
43};
44
45/// Vsock packet parsing errors.
46#[derive(Debug)]
47pub enum Error {
48    /// Too few descriptors in a descriptor chain.
49    DescriptorChainTooShort,
50    /// Descriptor that was too short to use.
51    DescriptorLengthTooSmall,
52    /// Descriptor that was too long to use.
53    DescriptorLengthTooLong,
54    /// The slice for creating a header has an invalid length.
55    InvalidHeaderInputSize(usize),
56    /// The `len` header field value exceeds the maximum allowed data size.
57    InvalidHeaderLen(u32),
58    /// Invalid guest memory access.
59    InvalidMemoryAccess(GuestMemoryError),
60    /// Invalid volatile memory access.
61    InvalidVolatileAccess(VolatileMemoryError),
62    /// Read only descriptor that protocol says to write to.
63    UnexpectedReadOnlyDescriptor,
64    /// Write only descriptor that protocol says to read from.
65    UnexpectedWriteOnlyDescriptor,
66}
67
68impl std::error::Error for Error {}
69
70impl Display for Error {
71    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
72        match self {
73            Error::DescriptorChainTooShort => {
74                write!(f, "There are not enough descriptors in the chain.")
75            }
76            Error::DescriptorLengthTooSmall => write!(
77                f,
78                "The descriptor is pointing to a buffer that has a smaller length than expected."
79            ),
80            Error::DescriptorLengthTooLong => write!(
81                f,
82                "The descriptor is pointing to a buffer that has a longer length than expected."
83            ),
84            Error::InvalidHeaderInputSize(size) => {
85                write!(f, "Invalid header input size: {size}")
86            }
87            Error::InvalidHeaderLen(size) => {
88                write!(f, "Invalid header `len` field value: {size}")
89            }
90            Error::InvalidMemoryAccess(error) => {
91                write!(f, "Invalid guest memory access: {error}")
92            }
93            Error::InvalidVolatileAccess(error) => {
94                write!(f, "Invalid volatile memory access: {error}")
95            }
96            Error::UnexpectedReadOnlyDescriptor => {
97                write!(f, "Unexpected read-only descriptor.")
98            }
99            Error::UnexpectedWriteOnlyDescriptor => {
100                write!(f, "Unexpected write-only descriptor.")
101            }
102        }
103    }
104}
105
106#[repr(C, packed)]
107#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
108/// The vsock packet header structure.
109pub struct PacketHeader {
110    src_cid: Le64,
111    dst_cid: Le64,
112    src_port: Le32,
113    dst_port: Le32,
114    len: Le32,
115    type_: Le16,
116    op: Le16,
117    flags: Le32,
118    buf_alloc: Le32,
119    fwd_cnt: Le32,
120}
121
122// SAFETY: This is safe because `PacketHeader` contains only wrappers over POD types
123// and all accesses through safe `vm-memory` API will validate any garbage that could
124// be included in there.
125unsafe impl ByteValued for PacketHeader {}
//
// This structure occupies the start of the buffer pointed to by the head of the descriptor chain.
// Below are the packed structure size and the offset of each field.
// Note that these offsets are only used privately by the `VsockPacket` struct; its public
// interface consists of a getter and a setter for each header field, which also take care of the
// correct endianness.
132
133/// The size of the header structure (when packed).
134pub const PKT_HEADER_SIZE: usize = std::mem::size_of::<PacketHeader>();
135
136// Offsets of the header fields.
137const SRC_CID_OFFSET: usize = 0;
138const DST_CID_OFFSET: usize = 8;
139const SRC_PORT_OFFSET: usize = 16;
140const DST_PORT_OFFSET: usize = 20;
141const LEN_OFFSET: usize = 24;
142const TYPE_OFFSET: usize = 28;
143const OP_OFFSET: usize = 30;
144const FLAGS_OFFSET: usize = 32;
145const BUF_ALLOC_OFFSET: usize = 36;
146const FWD_CNT_OFFSET: usize = 40;
147
148/// Dedicated [`Result`](https://doc.rust-lang.org/std/result/) type.
149pub type Result<T> = std::result::Result<T, Error>;
150
151/// The vsock packet, implemented as a wrapper over a virtio descriptor chain:
152/// - the chain head, holding the packet header;
153/// - an optional data/buffer descriptor, only present for data packets (for VSOCK_OP_RW requests).
154#[derive(Debug)]
155pub struct VsockPacket<'a, B: BitmapSlice> {
156    // When writing to the header slice, we are using the `write` method of `VolatileSlice`s Bytes
157    // implementation. Because that can only return an error if we pass an invalid offset, we can
158    // safely use `unwraps` in the setters below. If we switch to a type different than
159    // `VolatileSlice`, this assumption can no longer hold. We also must always make sure the
160    // `VsockPacket` API is creating headers with PKT_HEADER_SIZE size.
161    header_slice: VolatileSlice<'a, B>,
162    header: PacketHeader,
163    data_slice: Option<VolatileSlice<'a, B>>,
164}
165
// Sets a header field of a packet both in the cached `PacketHeader` and in the guest-memory
// `header_slice`, keeping the two copies in sync.
//
// `$offset` must be the byte offset of `$field` inside the packed header, and `$value` must be an
// integer of the field's width, because its little-endian bytes are written starting at
// `$offset`.
//
// `write_slice` is used (instead of `write`) so that a partial write is reported as an error
// rather than silently leaving the two copies of the header out of sync. It can only fail if
// `$offset` plus the field width exceeds the header slice, and every header slice created by this
// module is exactly `PKT_HEADER_SIZE` bytes long, so a failure is a bug in the offsets and panics.
macro_rules! set_header_field {
    ($packet:ident, $field:ident, $offset:ident, $value:ident) => {
        $packet.header.$field = $value.into();
        $packet
            .header_slice
            .write_slice(&$value.to_le_bytes(), $offset)
            .expect("header field must lie entirely within the PKT_HEADER_SIZE header slice");
    };
}
179
180impl<'a, B: BitmapSlice> VsockPacket<'a, B> {
181    /// Return a reference to the `header_slice` of the packet.
182    pub fn header_slice(&self) -> &VolatileSlice<'a, B> {
183        &self.header_slice
184    }
185
186    /// Return a reference to the `data_slice` of the packet.
187    pub fn data_slice(&self) -> Option<&VolatileSlice<'a, B>> {
188        self.data_slice.as_ref()
189    }
190
191    /// Write to the packet header from an input of raw bytes.
192    ///
193    /// # Example
194    ///
195    /// ```rust
196    /// # use virtio_bindings::bindings::virtio_ring::VRING_DESC_F_WRITE;
197    /// # use virtio_queue::mock::MockSplitQueue;
198    /// # use virtio_queue::{desc::{split::Descriptor as SplitDescriptor, RawDescriptor}, Queue, QueueT};
199    /// use virtio_vsock::packet::{VsockPacket, PKT_HEADER_SIZE};
200    /// # use vm_memory::{Bytes, GuestAddress, GuestAddressSpace, GuestMemoryMmap};
201    ///
202    /// const MAX_PKT_BUF_SIZE: u32 = 64 * 1024;
203    ///
204    /// # fn create_queue_with_chain(m: &GuestMemoryMmap) -> Queue {
205    /// #     let vq = MockSplitQueue::new(m, 16);
206    /// #     let mut q = vq.create_queue().unwrap();
207    /// #
208    /// #     let v = vec![
209    /// #         RawDescriptor::from(SplitDescriptor::new(0x5_0000, 0x100, VRING_DESC_F_WRITE as u16, 0)),
210    /// #         RawDescriptor::from(SplitDescriptor::new(0x8_0000, 0x100, VRING_DESC_F_WRITE as u16, 0)),
211    /// #     ];
212    /// #     let mut chain = vq.build_desc_chain(&v);
213    /// #     q
214    /// # }
215    /// let mem = GuestMemoryMmap::<()>::from_ranges(&[(GuestAddress(0), 0x10_0000)]).unwrap();
216    /// // Create a queue and populate it with a descriptor chain.
217    /// let mut queue = create_queue_with_chain(&mem);
218    ///
219    /// while let Some(mut head) = queue.pop_descriptor_chain(&mem) {
220    ///     let mut pkt = VsockPacket::from_rx_virtq_chain(&mem, &mut head, MAX_PKT_BUF_SIZE).unwrap();
221    ///     pkt.set_header_from_raw(&[0u8; PKT_HEADER_SIZE]).unwrap();
222    /// }
223    /// ```
224    pub fn set_header_from_raw(&mut self, bytes: &[u8]) -> Result<()> {
225        if bytes.len() != PKT_HEADER_SIZE {
226            return Err(Error::InvalidHeaderInputSize(bytes.len()));
227        }
228        self.header_slice
229            .write(bytes, 0)
230            .map_err(Error::InvalidVolatileAccess)?;
231        let header = self
232            .header_slice()
233            .read_obj::<PacketHeader>(0)
234            .map_err(Error::InvalidVolatileAccess)?;
235        self.header = header;
236        Ok(())
237    }
238
239    /// Return the `src_cid` of the header.
240    pub fn src_cid(&self) -> u64 {
241        self.header.src_cid.into()
242    }
243
244    /// Set the `src_cid` of the header.
245    pub fn set_src_cid(&mut self, cid: u64) -> &mut Self {
246        set_header_field!(self, src_cid, SRC_CID_OFFSET, cid);
247        self
248    }
249
250    /// Return the `dst_cid` of the header.
251    pub fn dst_cid(&self) -> u64 {
252        self.header.dst_cid.into()
253    }
254
255    /// Set the `dst_cid` of the header.
256    pub fn set_dst_cid(&mut self, cid: u64) -> &mut Self {
257        set_header_field!(self, dst_cid, DST_CID_OFFSET, cid);
258        self
259    }
260
261    /// Return the `src_port` of the header.
262    pub fn src_port(&self) -> u32 {
263        self.header.src_port.into()
264    }
265
266    /// Set the `src_port` of the header.
267    pub fn set_src_port(&mut self, port: u32) -> &mut Self {
268        set_header_field!(self, src_port, SRC_PORT_OFFSET, port);
269        self
270    }
271
272    /// Return the `dst_port` of the header.
273    pub fn dst_port(&self) -> u32 {
274        self.header.dst_port.into()
275    }
276
277    /// Set the `dst_port` of the header.
278    pub fn set_dst_port(&mut self, port: u32) -> &mut Self {
279        set_header_field!(self, dst_port, DST_PORT_OFFSET, port);
280        self
281    }
282
283    /// Return the `len` of the header.
284    pub fn len(&self) -> u32 {
285        self.header.len.into()
286    }
287
288    /// Returns whether the `len` field of the header is 0 or not.
289    pub fn is_empty(&self) -> bool {
290        self.len() == 0
291    }
292
293    /// Set the `len` of the header.
294    pub fn set_len(&mut self, len: u32) -> &mut Self {
295        set_header_field!(self, len, LEN_OFFSET, len);
296        self
297    }
298
299    /// Return the `type` of the header.
300    pub fn type_(&self) -> u16 {
301        self.header.type_.into()
302    }
303
304    /// Set the `type` of the header.
305    pub fn set_type(&mut self, type_: u16) -> &mut Self {
306        set_header_field!(self, type_, TYPE_OFFSET, type_);
307        self
308    }
309
310    /// Return the `op` of the header.
311    pub fn op(&self) -> u16 {
312        self.header.op.into()
313    }
314
315    /// Set the `op` of the header.
316    pub fn set_op(&mut self, op: u16) -> &mut Self {
317        set_header_field!(self, op, OP_OFFSET, op);
318        self
319    }
320
321    /// Return the `flags` of the header.
322    pub fn flags(&self) -> u32 {
323        self.header.flags.into()
324    }
325
326    /// Set the `flags` of the header.
327    pub fn set_flags(&mut self, flags: u32) -> &mut Self {
328        set_header_field!(self, flags, FLAGS_OFFSET, flags);
329        self
330    }
331
332    /// Set a specific flag of the header.
333    pub fn set_flag(&mut self, flag: u32) -> &mut Self {
334        self.set_flags(self.flags() | flag);
335        self
336    }
337
338    /// Return the `buf_alloc` of the header.
339    pub fn buf_alloc(&self) -> u32 {
340        self.header.buf_alloc.into()
341    }
342
343    /// Set the `buf_alloc` of the header.
344    pub fn set_buf_alloc(&mut self, buf_alloc: u32) -> &mut Self {
345        set_header_field!(self, buf_alloc, BUF_ALLOC_OFFSET, buf_alloc);
346        self
347    }
348
349    /// Return the `fwd_cnt` of the header.
350    pub fn fwd_cnt(&self) -> u32 {
351        self.header.fwd_cnt.into()
352    }
353
354    /// Set the `fwd_cnt` of the header.
355    pub fn set_fwd_cnt(&mut self, fwd_cnt: u32) -> &mut Self {
356        set_header_field!(self, fwd_cnt, FWD_CNT_OFFSET, fwd_cnt);
357        self
358    }
359
360    /// Create the packet wrapper from a TX chain.
361    ///
362    /// The chain head is expected to hold a valid packet header. A following packet data
363    /// descriptor can optionally end the chain.
364    ///
365    /// # Arguments
366    ///
367    /// * `mem` - the `GuestMemory` object that can be used to access the queue buffers.
368    /// * `desc_chain` - the descriptor chain corresponding to a packet.
369    /// * `max_data_size` - the maximum size allowed for the packet payload, that was negotiated between the device and the driver. Tracking issue for defining this feature in virtio-spec [here](https://github.com/oasis-tcs/virtio-spec/issues/140).
370    ///
371    /// # Example
372    ///
373    /// ```rust
374    /// # use virtio_queue::mock::MockSplitQueue;
375    /// # use virtio_queue::{desc::{split::Descriptor as SplitDescriptor, RawDescriptor}, Queue, QueueT};
376    /// use virtio_vsock::packet::{VsockPacket, PKT_HEADER_SIZE};
377    /// # use vm_memory::{Bytes, GuestAddress, GuestAddressSpace, GuestMemoryMmap};
378    ///
379    /// const MAX_PKT_BUF_SIZE: u32 = 64 * 1024;
380    /// const OP_RW: u16 = 5;
381    ///
382    /// # fn create_queue_with_chain(m: &GuestMemoryMmap) -> Queue {
383    /// #     let vq = MockSplitQueue::new(m, 16);
384    /// #     let mut q = vq.create_queue().unwrap();
385    /// #
386    /// #     let v = vec![
387    /// #         RawDescriptor::from(SplitDescriptor::new(0x5_0000, 0x100, 0, 0)),
388    /// #         RawDescriptor::from(SplitDescriptor::new(0x8_0000, 0x100, 0, 0)),
389    /// #     ];
390    /// #     let mut chain = vq.build_desc_chain(&v);
391    /// #     q
392    /// # }
393    /// let mem = GuestMemoryMmap::<()>::from_ranges(&[(GuestAddress(0), 0x10000)]).unwrap();
394    /// // Create a queue and populate it with a descriptor chain.
395    /// let mut queue = create_queue_with_chain(&mem);
396    ///
397    /// while let Some(mut head) = queue.pop_descriptor_chain(&mem) {
398    ///     let pkt = match VsockPacket::from_tx_virtq_chain(&mem, &mut head, MAX_PKT_BUF_SIZE) {
399    ///         Ok(pkt) => pkt,
400    ///         Err(_e) => {
401    ///             // Do some error handling.
402    ///             queue.add_used(&mem, head.head_index(), 0);
403    ///             continue;
404    ///         }
405    ///     };
406    ///     // Here we would send the packet to the backend. Depending on the operation type, a
407    ///     // different type of action will be done.
408    ///
409    ///     // For example, if it's a RW packet, we will forward the packet payload to the backend.
410    ///     if pkt.op() == OP_RW {
411    ///         // Send the packet payload to the backend.
412    ///     }
413    ///     queue.add_used(&mem, head.head_index(), 0);
414    /// }
415    /// ```
416    pub fn from_tx_virtq_chain<M, T>(
417        mem: &'a M,
418        desc_chain: &mut DescriptorChain<T>,
419        max_data_size: u32,
420    ) -> Result<Self>
421    where
422        M: GuestMemory,
423        <<M as GuestMemory>::R as GuestMemoryRegion>::B: WithBitmapSlice<'a, S = B>,
424        T: Deref,
425        T::Target: GuestMemory,
426    {
427        let chain_head = desc_chain.next().ok_or(Error::DescriptorChainTooShort)?;
428        // All TX buffers must be device-readable.
429        if chain_head.is_write_only() {
430            return Err(Error::UnexpectedWriteOnlyDescriptor);
431        }
432
433        // The packet header should fit inside the buffer corresponding to the head descriptor.
434        if (chain_head.len() as usize) < PKT_HEADER_SIZE {
435            return Err(Error::DescriptorLengthTooSmall);
436        }
437
438        let header_slice = mem
439            .get_slice(chain_head.addr(), PKT_HEADER_SIZE)
440            .map_err(Error::InvalidMemoryAccess)?;
441
442        let header = mem
443            .read_obj(chain_head.addr())
444            .map_err(Error::InvalidMemoryAccess)?;
445
446        let mut pkt = Self {
447            header_slice,
448            header,
449            data_slice: None,
450        };
451
452        // If the `len` field of the header is zero, then the packet doesn't have a `data` element.
453        if pkt.is_empty() {
454            return Ok(pkt);
455        }
456
457        // Reject packets that exceed the maximum allowed value for payload.
458        if pkt.len() > max_data_size {
459            return Err(Error::InvalidHeaderLen(pkt.len()));
460        }
461
462        // Starting from Linux 6.2 the virtio-vsock driver can use a single descriptor for both
463        // header and data.
464        let data_slice =
465            if !chain_head.has_next() && chain_head.len() - PKT_HEADER_SIZE as u32 >= pkt.len() {
466                mem.get_slice(
467                    chain_head
468                        .addr()
469                        .checked_add(PKT_HEADER_SIZE as u64)
470                        .ok_or(Error::DescriptorLengthTooSmall)?,
471                    pkt.len() as usize,
472                )
473                .map_err(Error::InvalidMemoryAccess)?
474            } else {
475                if !chain_head.has_next() {
476                    return Err(Error::DescriptorChainTooShort);
477                }
478
479                let data_desc = desc_chain.next().ok_or(Error::DescriptorChainTooShort)?;
480
481                if data_desc.is_write_only() {
482                    return Err(Error::UnexpectedWriteOnlyDescriptor);
483                }
484
485                // The data buffer should be large enough to fit the size of the data, as described by
486                // the header descriptor.
487                if data_desc.len() < pkt.len() {
488                    return Err(Error::DescriptorLengthTooSmall);
489                }
490
491                mem.get_slice(data_desc.addr(), pkt.len() as usize)
492                    .map_err(Error::InvalidMemoryAccess)?
493            };
494
495        pkt.data_slice = Some(data_slice);
496        Ok(pkt)
497    }
498
499    /// Create the packet wrapper from an RX chain.
500    ///
501    /// There must be two descriptors in the chain, both writable: a header descriptor and a data
502    /// descriptor.
503    ///
504    /// # Arguments
505    ///
506    /// * `mem` - the `GuestMemory` object that can be used to access the queue buffers.
507    /// * `desc_chain` - the descriptor chain corresponding to a packet.
508    /// * `max_data_size` - the maximum size allowed for the packet payload, that was negotiated between the device and the driver. Tracking issue for defining this feature in virtio-spec [here](https://github.com/oasis-tcs/virtio-spec/issues/140).
509    ///
510    /// # Example
511    ///
512    /// ```rust
513    /// # use virtio_bindings::bindings::virtio_ring::VRING_DESC_F_WRITE;
514    /// # use virtio_queue::mock::MockSplitQueue;
515    /// # use virtio_queue::{desc::{split::Descriptor as SplitDescriptor, RawDescriptor}, Queue, QueueT};
516    /// use virtio_vsock::packet::{VsockPacket, PKT_HEADER_SIZE};
517    /// # use vm_memory::{Bytes, GuestAddress, GuestAddressSpace, GuestMemoryMmap};
518    ///
519    /// # const MAX_PKT_BUF_SIZE: u32 = 64 * 1024;
520    /// # const SRC_CID: u64 = 1;
521    /// # const DST_CID: u64 = 2;
522    /// # const SRC_PORT: u32 = 3;
523    /// # const DST_PORT: u32 = 4;
524    /// # const LEN: u32 = 16;
525    /// # const TYPE_STREAM: u16 = 1;
526    /// # const OP_RW: u16 = 5;
527    /// # const FLAGS: u32 = 7;
528    /// # const FLAG: u32 = 8;
529    /// # const BUF_ALLOC: u32 = 256;
530    /// # const FWD_CNT: u32 = 9;
531    ///
532    /// # fn create_queue_with_chain(m: &GuestMemoryMmap) -> Queue {
533    /// #     let vq = MockSplitQueue::new(m, 16);
534    /// #     let mut q = vq.create_queue().unwrap();
535    /// #
536    /// #     let v = vec![
537    /// #         RawDescriptor::from(SplitDescriptor::new(0x5_0000, 0x100, VRING_DESC_F_WRITE as u16, 0)),
538    /// #         RawDescriptor::from(SplitDescriptor::new(0x8_0000, 0x100, VRING_DESC_F_WRITE as u16, 0)),
539    /// #     ];
540    /// #     let mut chain = vq.build_desc_chain(&v);
541    /// #    q
542    /// # }
543    /// let mem = GuestMemoryMmap::<()>::from_ranges(&[(GuestAddress(0), 0x10000)]).unwrap();
544    /// // Create a queue and populate it with a descriptor chain.
545    /// let mut queue = create_queue_with_chain(&mem);
546    ///
547    /// while let Some(mut head) = queue.pop_descriptor_chain(&mem) {
548    ///     let used_len = match VsockPacket::from_rx_virtq_chain(&mem, &mut head, MAX_PKT_BUF_SIZE) {
549    ///         Ok(mut pkt) => {
550    ///             // Make sure the header is zeroed out first.
551    ///             pkt.header_slice()
552    ///                 .write(&[0u8; PKT_HEADER_SIZE], 0)
553    ///                 .unwrap();
554    ///             // Write data to the packet, using the setters.
555    ///             pkt.set_src_cid(SRC_CID)
556    ///                 .set_dst_cid(DST_CID)
557    ///                 .set_src_port(SRC_PORT)
558    ///                 .set_dst_port(DST_PORT)
559    ///                 .set_type(TYPE_STREAM)
560    ///                 .set_buf_alloc(BUF_ALLOC)
561    ///                 .set_fwd_cnt(FWD_CNT);
562    ///             // In this example, we are sending a RW packet.
563    ///             pkt.data_slice()
564    ///                 .unwrap()
565    ///                 .write_slice(&[1u8; LEN as usize], 0);
566    ///             pkt.set_op(OP_RW).set_len(LEN);
567    ///             pkt.header_slice().len() as u32 + LEN
568    ///         }
569    ///         Err(_e) => {
570    ///             // Do some error handling.
571    ///             0
572    ///         }
573    ///     };
574    ///     queue.add_used(&mem, head.head_index(), used_len);
575    /// }
576    /// ```
577    pub fn from_rx_virtq_chain<M, T>(
578        mem: &'a M,
579        desc_chain: &mut DescriptorChain<T>,
580        max_data_size: u32,
581    ) -> Result<Self>
582    where
583        M: GuestMemory,
584        <<M as GuestMemory>::R as GuestMemoryRegion>::B: WithBitmapSlice<'a, S = B>,
585        T: Deref,
586        T::Target: GuestMemory,
587    {
588        let chain_head = desc_chain.next().ok_or(Error::DescriptorChainTooShort)?;
589        // All RX buffers must be device-writable.
590        if !chain_head.is_write_only() {
591            return Err(Error::UnexpectedReadOnlyDescriptor);
592        }
593
594        // The packet header should fit inside the head descriptor.
595        if (chain_head.len() as usize) < PKT_HEADER_SIZE {
596            return Err(Error::DescriptorLengthTooSmall);
597        }
598
599        let header_slice = mem
600            .get_slice(chain_head.addr(), PKT_HEADER_SIZE)
601            .map_err(Error::InvalidMemoryAccess)?;
602
603        // Starting from Linux 6.2 the virtio-vsock driver can use a single descriptor for both
604        // header and data.
605        let data_slice = if !chain_head.has_next() && chain_head.len() as usize > PKT_HEADER_SIZE {
606            mem.get_slice(
607                chain_head
608                    .addr()
609                    .checked_add(PKT_HEADER_SIZE as u64)
610                    .ok_or(Error::DescriptorLengthTooSmall)?,
611                chain_head.len() as usize - PKT_HEADER_SIZE,
612            )
613            .map_err(Error::InvalidMemoryAccess)?
614        } else {
615            if !chain_head.has_next() {
616                return Err(Error::DescriptorChainTooShort);
617            }
618
619            let data_desc = desc_chain.next().ok_or(Error::DescriptorChainTooShort)?;
620
621            if !data_desc.is_write_only() {
622                return Err(Error::UnexpectedReadOnlyDescriptor);
623            }
624
625            if data_desc.len() > max_data_size {
626                return Err(Error::DescriptorLengthTooLong);
627            }
628
629            mem.get_slice(data_desc.addr(), data_desc.len() as usize)
630                .map_err(Error::InvalidMemoryAccess)?
631        };
632
633        Ok(Self {
634            header_slice,
635            header: Default::default(),
636            data_slice: Some(data_slice),
637        })
638    }
639}
640
641impl<'a> VsockPacket<'a, ()> {
642    /// Create a packet based on one pointer for the header, and an optional one for data.
643    ///
644    /// # Safety
645    ///
646    /// To use this safely, the caller must guarantee that the memory pointed to by the `hdr` and
647    /// `data` slices is available for the duration of the lifetime of the new `VolatileSlice`. The
648    /// caller must also guarantee that all other users of the given chunk of memory are using
649    /// volatile accesses.
650    ///
651    /// # Example
652    ///
653    /// ```rust
654    /// use virtio_vsock::packet::{VsockPacket, PKT_HEADER_SIZE};
655    ///
656    /// const LEN: usize = 16;
657    ///
658    /// let mut pkt_raw = [0u8; PKT_HEADER_SIZE + LEN];
659    /// let (hdr_raw, data_raw) = pkt_raw.split_at_mut(PKT_HEADER_SIZE);
660    /// // Safe because `hdr_raw` and `data_raw` live for as long as the scope of the current
661    /// // example.
662    /// let packet = unsafe { VsockPacket::new(hdr_raw, Some(data_raw)).unwrap() };
663    /// ```
664    pub unsafe fn new(header: &mut [u8], data: Option<&mut [u8]>) -> Result<VsockPacket<'a, ()>> {
665        if header.len() != PKT_HEADER_SIZE {
666            return Err(Error::InvalidHeaderInputSize(header.len()));
667        }
668        Ok(VsockPacket {
669            header_slice: VolatileSlice::new(header.as_mut_ptr(), PKT_HEADER_SIZE),
670            header: Default::default(),
671            data_slice: data.map(|data| VolatileSlice::new(data.as_mut_ptr(), data.len())),
672        })
673    }
674}
675
676#[cfg(test)]
677mod tests {
678    use super::*;
679
680    use vm_memory::{GuestAddress, GuestMemoryMmap};
681
682    use virtio_bindings::bindings::virtio_ring::VRING_DESC_F_WRITE;
683    use virtio_queue::desc::{split::Descriptor as SplitDescriptor, RawDescriptor};
684    use virtio_queue::mock::MockSplitQueue;
685
686    impl PartialEq for Error {
687        fn eq(&self, other: &Self) -> bool {
688            use self::Error::*;
689            match (self, other) {
690                (DescriptorChainTooShort, DescriptorChainTooShort) => true,
691                (DescriptorLengthTooSmall, DescriptorLengthTooSmall) => true,
692                (DescriptorLengthTooLong, DescriptorLengthTooLong) => true,
693                (InvalidHeaderInputSize(size), InvalidHeaderInputSize(other_size)) => {
694                    size == other_size
695                }
696                (InvalidHeaderLen(size), InvalidHeaderLen(other_size)) => size == other_size,
697                (InvalidMemoryAccess(ref e), InvalidMemoryAccess(ref other_e)) => {
698                    format!("{e}").eq(&format!("{other_e}"))
699                }
700                (InvalidVolatileAccess(ref e), InvalidVolatileAccess(ref other_e)) => {
701                    format!("{e}").eq(&format!("{other_e}"))
702                }
703                (UnexpectedReadOnlyDescriptor, UnexpectedReadOnlyDescriptor) => true,
704                (UnexpectedWriteOnlyDescriptor, UnexpectedWriteOnlyDescriptor) => true,
705                _ => false,
706            }
707        }
708    }
709
710    // Random values to be used by the tests for the header fields.
711    const SRC_CID: u64 = 1;
712    const DST_CID: u64 = 2;
713    const SRC_PORT: u32 = 3;
714    const DST_PORT: u32 = 4;
715    const LEN: u32 = 16;
716    const TYPE: u16 = 5;
717    const OP: u16 = 6;
718    const FLAGS: u32 = 7;
719    const FLAG: u32 = 8;
720    const BUF_ALLOC: u32 = 256;
721    const FWD_CNT: u32 = 9;
722
723    const MAX_PKT_BUF_SIZE: u32 = 64 * 1024;
724
725    #[test]
726    fn test_from_rx_virtq_chain() {
727        let mem: GuestMemoryMmap =
728            GuestMemoryMmap::from_ranges(&[(GuestAddress(0), 0x1000_0000)]).unwrap();
729
730        // The `build_desc_chain` function will populate the `NEXT` related flags and field.
731        let v = vec![
732            // A device-readable packet header descriptor should be invalid.
733            RawDescriptor::from(SplitDescriptor::new(0x10_0000, 0x100, 0, 0)),
734            RawDescriptor::from(SplitDescriptor::new(
735                0x20_0000,
736                0x100,
737                VRING_DESC_F_WRITE as u16,
738                0,
739            )),
740        ];
741        let queue = MockSplitQueue::new(&mem, 16);
742        let mut chain = queue.build_desc_chain(&v).unwrap();
743        assert_eq!(
744            VsockPacket::from_rx_virtq_chain(&mem, &mut chain, MAX_PKT_BUF_SIZE).unwrap_err(),
745            Error::UnexpectedReadOnlyDescriptor
746        );
747
748        let v = vec![
749            // A header length < PKT_HEADER_SIZE is invalid.
750            RawDescriptor::from(SplitDescriptor::new(
751                0x10_0000,
752                PKT_HEADER_SIZE as u32 - 1,
753                VRING_DESC_F_WRITE as u16,
754                0,
755            )),
756            RawDescriptor::from(SplitDescriptor::new(
757                0x20_0000,
758                0x100,
759                VRING_DESC_F_WRITE as u16,
760                0,
761            )),
762        ];
763        let mut chain = queue.build_desc_chain(&v).unwrap();
764        assert_eq!(
765            VsockPacket::from_rx_virtq_chain(&mem, &mut chain, MAX_PKT_BUF_SIZE).unwrap_err(),
766            Error::DescriptorLengthTooSmall
767        );
768
769        let v = vec![
770            RawDescriptor::from(SplitDescriptor::new(
771                0x10_0000,
772                PKT_HEADER_SIZE as u32,
773                VRING_DESC_F_WRITE as u16,
774                0,
775            )),
776            RawDescriptor::from(SplitDescriptor::new(
777                0x20_0000,
778                MAX_PKT_BUF_SIZE + 1,
779                VRING_DESC_F_WRITE as u16,
780                0,
781            )),
782        ];
783        let mut chain = queue.build_desc_chain(&v).unwrap();
784        assert_eq!(
785            VsockPacket::from_rx_virtq_chain(&mem, &mut chain, MAX_PKT_BUF_SIZE).unwrap_err(),
786            Error::DescriptorLengthTooLong
787        );
788
789        let v = vec![
790            // The data descriptor should always be present on the RX path.
791            RawDescriptor::from(SplitDescriptor::new(
792                0x10_0000,
793                PKT_HEADER_SIZE as u32,
794                VRING_DESC_F_WRITE as u16,
795                0,
796            )),
797        ];
798        let mut chain = queue.build_desc_chain(&v).unwrap();
799        assert_eq!(
800            VsockPacket::from_rx_virtq_chain(&mem, &mut chain, MAX_PKT_BUF_SIZE).unwrap_err(),
801            Error::DescriptorChainTooShort
802        );
803
804        let v = vec![
805            RawDescriptor::from(SplitDescriptor::new(0x10_0000, 0x100, 0, 0)),
806            RawDescriptor::from(SplitDescriptor::new(
807                0x20_0000,
808                0x100,
809                VRING_DESC_F_WRITE as u16,
810                0,
811            )),
812        ];
813        let mut chain = queue.build_desc_chain(&v).unwrap();
814        assert_eq!(
815            VsockPacket::from_rx_virtq_chain(&mem, &mut chain, MAX_PKT_BUF_SIZE).unwrap_err(),
816            Error::UnexpectedReadOnlyDescriptor
817        );
818
819        let mem: GuestMemoryMmap =
820            GuestMemoryMmap::from_ranges(&[(GuestAddress(0), 0x10_0004)]).unwrap();
821
822        let v = vec![
823            // The header doesn't fit entirely in the memory bounds.
824            RawDescriptor::from(SplitDescriptor::new(
825                0x10_0000,
826                0x100,
827                VRING_DESC_F_WRITE as u16,
828                0,
829            )),
830            RawDescriptor::from(SplitDescriptor::new(
831                0x20_0000,
832                0x100,
833                VRING_DESC_F_WRITE as u16,
834                0,
835            )),
836        ];
837        let queue = MockSplitQueue::new(&mem, 16);
838        let mut chain = queue.build_desc_chain(&v).unwrap();
839        assert_eq!(
840            VsockPacket::from_rx_virtq_chain(&mem, &mut chain, MAX_PKT_BUF_SIZE).unwrap_err(),
841            Error::InvalidMemoryAccess(GuestMemoryError::InvalidBackendAddress)
842        );
843
844        let v = vec![
845            // The header is outside the memory bounds.
846            RawDescriptor::from(SplitDescriptor::new(
847                0x20_0000,
848                0x100,
849                VRING_DESC_F_WRITE as u16,
850                0,
851            )),
852            RawDescriptor::from(SplitDescriptor::new(
853                0x30_0000,
854                0x100,
855                VRING_DESC_F_WRITE as u16,
856                0,
857            )),
858        ];
859        let mut chain = queue.build_desc_chain(&v).unwrap();
860        assert_eq!(
861            VsockPacket::from_rx_virtq_chain(&mem, &mut chain, MAX_PKT_BUF_SIZE).unwrap_err(),
862            Error::InvalidMemoryAccess(GuestMemoryError::InvalidGuestAddress(GuestAddress(
863                0x20_0000
864            )))
865        );
866
867        let v = vec![
868            RawDescriptor::from(SplitDescriptor::new(
869                0x5_0000,
870                0x100,
871                VRING_DESC_F_WRITE as u16,
872                0,
873            )),
874            // A device-readable packet data descriptor should be invalid.
875            RawDescriptor::from(SplitDescriptor::new(0x8_0000, 0x100, 0, 0)),
876        ];
877        let mut chain = queue.build_desc_chain(&v).unwrap();
878        assert_eq!(
879            VsockPacket::from_rx_virtq_chain(&mem, &mut chain, MAX_PKT_BUF_SIZE).unwrap_err(),
880            Error::UnexpectedReadOnlyDescriptor
881        );
882        let v = vec![
883            RawDescriptor::from(SplitDescriptor::new(
884                0x5_0000,
885                0x100,
886                VRING_DESC_F_WRITE as u16,
887                0,
888            )),
889            // The data array doesn't fit entirely in the memory bounds.
890            RawDescriptor::from(SplitDescriptor::new(
891                0x10_0000,
892                0x100,
893                VRING_DESC_F_WRITE as u16,
894                0,
895            )),
896        ];
897        let mut chain = queue.build_desc_chain(&v).unwrap();
898        assert_eq!(
899            VsockPacket::from_rx_virtq_chain(&mem, &mut chain, MAX_PKT_BUF_SIZE).unwrap_err(),
900            Error::InvalidMemoryAccess(GuestMemoryError::InvalidBackendAddress)
901        );
902
903        let v = vec![
904            RawDescriptor::from(SplitDescriptor::new(
905                0x5_0000,
906                0x100,
907                VRING_DESC_F_WRITE as u16,
908                0,
909            )),
910            // The data array is outside the memory bounds.
911            RawDescriptor::from(SplitDescriptor::new(
912                0x20_0000,
913                0x100,
914                VRING_DESC_F_WRITE as u16,
915                0,
916            )),
917        ];
918        let mut chain = queue.build_desc_chain(&v).unwrap();
919        assert_eq!(
920            VsockPacket::from_rx_virtq_chain(&mem, &mut chain, MAX_PKT_BUF_SIZE).unwrap_err(),
921            Error::InvalidMemoryAccess(GuestMemoryError::InvalidGuestAddress(GuestAddress(
922                0x20_0000
923            )))
924        );
925
926        // Let's also test a valid descriptor chain.
927        let v = vec![
928            RawDescriptor::from(SplitDescriptor::new(
929                0x5_0000,
930                0x100,
931                VRING_DESC_F_WRITE as u16,
932                0,
933            )),
934            RawDescriptor::from(SplitDescriptor::new(
935                0x8_0000,
936                0x100,
937                VRING_DESC_F_WRITE as u16,
938                0,
939            )),
940        ];
941        let mut chain = queue.build_desc_chain(&v).unwrap();
942
943        let packet = VsockPacket::from_rx_virtq_chain(&mem, &mut chain, MAX_PKT_BUF_SIZE).unwrap();
944        assert_eq!(packet.header, PacketHeader::default());
945        let header = packet.header_slice();
946        assert_eq!(
947            header.ptr_guard().as_ptr(),
948            mem.get_host_address(GuestAddress(0x5_0000)).unwrap()
949        );
950        assert_eq!(header.len(), PKT_HEADER_SIZE);
951
952        let data = packet.data_slice().unwrap();
953        assert_eq!(
954            data.ptr_guard().as_ptr(),
955            mem.get_host_address(GuestAddress(0x8_0000)).unwrap()
956        );
957        assert_eq!(data.len(), 0x100);
958
959        // If we try to get a vsock packet again, it fails because we already consumed all the
960        // descriptors from the chain.
961        assert_eq!(
962            VsockPacket::from_rx_virtq_chain(&mem, &mut chain, MAX_PKT_BUF_SIZE).unwrap_err(),
963            Error::DescriptorChainTooShort
964        );
965
966        // Let's also test a valid descriptor chain, with both header and data on a single
967        // descriptor.
968        let v = vec![RawDescriptor::from(SplitDescriptor::new(
969            0x5_0000,
970            PKT_HEADER_SIZE as u32 + 0x100,
971            VRING_DESC_F_WRITE as u16,
972            0,
973        ))];
974        let mut chain = queue.build_desc_chain(&v).unwrap();
975
976        let packet = VsockPacket::from_rx_virtq_chain(&mem, &mut chain, MAX_PKT_BUF_SIZE).unwrap();
977        assert_eq!(packet.header, PacketHeader::default());
978        let header = packet.header_slice();
979        assert_eq!(
980            header.ptr_guard().as_ptr(),
981            mem.get_host_address(GuestAddress(0x5_0000)).unwrap()
982        );
983        assert_eq!(header.len(), PKT_HEADER_SIZE);
984
985        let data = packet.data_slice().unwrap();
986        assert_eq!(
987            data.ptr_guard().as_ptr(),
988            mem.get_host_address(GuestAddress(0x5_0000 + PKT_HEADER_SIZE as u64))
989                .unwrap()
990        );
991        assert_eq!(data.len(), 0x100);
992    }
993
994    #[test]
995    fn test_from_tx_virtq_chain() {
996        let mem: GuestMemoryMmap =
997            GuestMemoryMmap::from_ranges(&[(GuestAddress(0), 0x1000_0000)]).unwrap();
998
999        // The `build_desc_chain` function will populate the `NEXT` related flags and field.
1000        let v = vec![
1001            // A device-writable packet header descriptor should be invalid.
1002            RawDescriptor::from(SplitDescriptor::new(
1003                0x10_0000,
1004                0x100,
1005                VRING_DESC_F_WRITE as u16,
1006                0,
1007            )),
1008            RawDescriptor::from(SplitDescriptor::new(0x20_0000, 0x100, 0, 0)),
1009        ];
1010        let queue = MockSplitQueue::new(&mem, 16);
1011        let mut chain = queue.build_desc_chain(&v).unwrap();
1012        assert_eq!(
1013            VsockPacket::from_tx_virtq_chain(&mem, &mut chain, MAX_PKT_BUF_SIZE).unwrap_err(),
1014            Error::UnexpectedWriteOnlyDescriptor
1015        );
1016
1017        let v = vec![
1018            // A header length < PKT_HEADER_SIZE is invalid.
1019            RawDescriptor::from(SplitDescriptor::new(
1020                0x10_0000,
1021                PKT_HEADER_SIZE as u32 - 1,
1022                0,
1023                0,
1024            )),
1025            RawDescriptor::from(SplitDescriptor::new(0x20_0000, 0x100, 0, 0)),
1026        ];
1027        let mut chain = queue.build_desc_chain(&v).unwrap();
1028        assert_eq!(
1029            VsockPacket::from_tx_virtq_chain(&mem, &mut chain, MAX_PKT_BUF_SIZE).unwrap_err(),
1030            Error::DescriptorLengthTooSmall
1031        );
1032
1033        // On the TX path, it is allowed to not have a data descriptor.
1034        let v = vec![RawDescriptor::from(SplitDescriptor::new(
1035            0x10_0000,
1036            PKT_HEADER_SIZE as u32,
1037            0,
1038            0,
1039        ))];
1040        let mut chain = queue.build_desc_chain(&v).unwrap();
1041
1042        let header = PacketHeader {
1043            src_cid: SRC_CID.into(),
1044            dst_cid: DST_CID.into(),
1045            src_port: SRC_PORT.into(),
1046            dst_port: DST_PORT.into(),
1047            len: 0.into(),
1048            type_: 0.into(),
1049            op: 0.into(),
1050            flags: 0.into(),
1051            buf_alloc: 0.into(),
1052            fwd_cnt: 0.into(),
1053        };
1054        mem.write_obj(header, GuestAddress(0x10_0000)).unwrap();
1055
1056        let packet = VsockPacket::from_tx_virtq_chain(&mem, &mut chain, MAX_PKT_BUF_SIZE).unwrap();
1057        assert_eq!(packet.header, header);
1058        let header_slice = packet.header_slice();
1059        assert_eq!(
1060            header_slice.ptr_guard().as_ptr(),
1061            mem.get_host_address(GuestAddress(0x10_0000)).unwrap()
1062        );
1063        assert_eq!(header_slice.len(), PKT_HEADER_SIZE);
1064        assert!(packet.data_slice().is_none());
1065
1066        let mem: GuestMemoryMmap =
1067            GuestMemoryMmap::from_ranges(&[(GuestAddress(0), 0x10_0004)]).unwrap();
1068
1069        let v = vec![
1070            // The header doesn't fit entirely in the memory bounds.
1071            RawDescriptor::from(SplitDescriptor::new(0x10_0000, 0x100, 0, 0)),
1072            RawDescriptor::from(SplitDescriptor::new(0x20_0000, 0x100, 0, 0)),
1073        ];
1074        let queue = MockSplitQueue::new(&mem, 16);
1075        let mut chain = queue.build_desc_chain(&v).unwrap();
1076        assert_eq!(
1077            VsockPacket::from_tx_virtq_chain(&mem, &mut chain, MAX_PKT_BUF_SIZE).unwrap_err(),
1078            Error::InvalidMemoryAccess(GuestMemoryError::InvalidBackendAddress)
1079        );
1080
1081        let v = vec![
1082            // The header is outside the memory bounds.
1083            RawDescriptor::from(SplitDescriptor::new(0x20_0000, 0x100, 0, 0)),
1084            RawDescriptor::from(SplitDescriptor::new(0x30_0000, 0x100, 0, 0)),
1085        ];
1086        let mut chain = queue.build_desc_chain(&v).unwrap();
1087        assert_eq!(
1088            VsockPacket::from_tx_virtq_chain(&mem, &mut chain, MAX_PKT_BUF_SIZE).unwrap_err(),
1089            Error::InvalidMemoryAccess(GuestMemoryError::InvalidGuestAddress(GuestAddress(
1090                0x20_0000
1091            )))
1092        );
1093
1094        // Write some non-zero value to the `len` field of the header, which means there is also
1095        // a data descriptor in the chain, first with a value that exceeds the maximum allowed one.
1096        let header = PacketHeader {
1097            src_cid: SRC_CID.into(),
1098            dst_cid: DST_CID.into(),
1099            src_port: SRC_PORT.into(),
1100            dst_port: DST_PORT.into(),
1101            len: (MAX_PKT_BUF_SIZE + 1).into(),
1102            type_: 0.into(),
1103            op: 0.into(),
1104            flags: 0.into(),
1105            buf_alloc: 0.into(),
1106            fwd_cnt: 0.into(),
1107        };
1108        mem.write_obj(header, GuestAddress(0x5_0000)).unwrap();
1109        let v = vec![
1110            RawDescriptor::from(SplitDescriptor::new(0x5_0000, 0x100, 0, 0)),
1111            RawDescriptor::from(SplitDescriptor::new(0x8_0000, 0x100, 0, 0)),
1112        ];
1113        let mut chain = queue.build_desc_chain(&v).unwrap();
1114        assert_eq!(
1115            VsockPacket::from_tx_virtq_chain(&mem, &mut chain, MAX_PKT_BUF_SIZE).unwrap_err(),
1116            Error::InvalidHeaderLen(MAX_PKT_BUF_SIZE + 1)
1117        );
1118
1119        // Write some non-zero, valid value to the `len` field of the header.
1120        let header = PacketHeader {
1121            src_cid: SRC_CID.into(),
1122            dst_cid: DST_CID.into(),
1123            src_port: SRC_PORT.into(),
1124            dst_port: DST_PORT.into(),
1125            len: LEN.into(),
1126            type_: 0.into(),
1127            op: 0.into(),
1128            flags: 0.into(),
1129            buf_alloc: 0.into(),
1130            fwd_cnt: 0.into(),
1131        };
1132        mem.write_obj(header, GuestAddress(0x5_0000)).unwrap();
1133        let v = vec![
1134            // The data descriptor is missing.
1135            RawDescriptor::from(SplitDescriptor::new(0x5_0000, PKT_HEADER_SIZE as u32, 0, 0)),
1136        ];
1137        let mut chain = queue.build_desc_chain(&v).unwrap();
1138        assert_eq!(
1139            VsockPacket::from_tx_virtq_chain(&mem, &mut chain, MAX_PKT_BUF_SIZE).unwrap_err(),
1140            Error::DescriptorChainTooShort
1141        );
1142
1143        let v = vec![
1144            RawDescriptor::from(SplitDescriptor::new(0x5_0000, 0x100, 0, 0)),
1145            // The data array doesn't fit entirely in the memory bounds.
1146            RawDescriptor::from(SplitDescriptor::new(0x10_0000, 0x100, 0, 0)),
1147        ];
1148        let mut chain = queue.build_desc_chain(&v).unwrap();
1149        assert_eq!(
1150            VsockPacket::from_tx_virtq_chain(&mem, &mut chain, MAX_PKT_BUF_SIZE).unwrap_err(),
1151            Error::InvalidMemoryAccess(GuestMemoryError::InvalidBackendAddress)
1152        );
1153
1154        let v = vec![
1155            RawDescriptor::from(SplitDescriptor::new(0x5_0000, 0x100, 0, 0)),
1156            // The data array is outside the memory bounds.
1157            RawDescriptor::from(SplitDescriptor::new(0x20_0000, 0x100, 0, 0)),
1158        ];
1159        let mut chain = queue.build_desc_chain(&v).unwrap();
1160        assert_eq!(
1161            VsockPacket::from_tx_virtq_chain(&mem, &mut chain, MAX_PKT_BUF_SIZE).unwrap_err(),
1162            Error::InvalidMemoryAccess(GuestMemoryError::InvalidGuestAddress(GuestAddress(
1163                0x20_0000
1164            )))
1165        );
1166
1167        let v = vec![
1168            RawDescriptor::from(SplitDescriptor::new(0x5_0000, 0x100, 0, 0)),
1169            // A device-writable packet data descriptor should be invalid.
1170            RawDescriptor::from(SplitDescriptor::new(
1171                0x8_0000,
1172                0x100,
1173                VRING_DESC_F_WRITE as u16,
1174                0,
1175            )),
1176        ];
1177        let mut chain = queue.build_desc_chain(&v).unwrap();
1178        assert_eq!(
1179            VsockPacket::from_tx_virtq_chain(&mem, &mut chain, MAX_PKT_BUF_SIZE).unwrap_err(),
1180            Error::UnexpectedWriteOnlyDescriptor
1181        );
1182
1183        let v = vec![
1184            RawDescriptor::from(SplitDescriptor::new(0x5_0000, 0x100, 0, 0)),
1185            // A data length < the length of data as described by the header.
1186            RawDescriptor::from(SplitDescriptor::new(0x8_0000, LEN - 1, 0, 0)),
1187        ];
1188        let mut chain = queue.build_desc_chain(&v).unwrap();
1189        assert_eq!(
1190            VsockPacket::from_tx_virtq_chain(&mem, &mut chain, MAX_PKT_BUF_SIZE).unwrap_err(),
1191            Error::DescriptorLengthTooSmall
1192        );
1193
1194        // Let's also test a valid descriptor chain, with both header and data.
1195        let v = vec![
1196            RawDescriptor::from(SplitDescriptor::new(0x5_0000, 0x100, 0, 0)),
1197            RawDescriptor::from(SplitDescriptor::new(0x8_0000, 0x100, 0, 0)),
1198        ];
1199        let mut chain = queue.build_desc_chain(&v).unwrap();
1200
1201        let packet = VsockPacket::from_tx_virtq_chain(&mem, &mut chain, MAX_PKT_BUF_SIZE).unwrap();
1202        assert_eq!(packet.header, header);
1203        let header_slice = packet.header_slice();
1204        assert_eq!(
1205            header_slice.ptr_guard().as_ptr(),
1206            mem.get_host_address(GuestAddress(0x5_0000)).unwrap()
1207        );
1208        assert_eq!(header_slice.len(), PKT_HEADER_SIZE);
1209        // The `len` field of the header was set to 16.
1210        assert_eq!(packet.len(), LEN);
1211
1212        let data = packet.data_slice().unwrap();
1213        assert_eq!(
1214            data.ptr_guard().as_ptr(),
1215            mem.get_host_address(GuestAddress(0x8_0000)).unwrap()
1216        );
1217        assert_eq!(data.len(), LEN as usize);
1218
1219        // If we try to get a vsock packet again, it fails because we already consumed all the
1220        // descriptors from the chain.
1221        assert_eq!(
1222            VsockPacket::from_tx_virtq_chain(&mem, &mut chain, MAX_PKT_BUF_SIZE).unwrap_err(),
1223            Error::DescriptorChainTooShort
1224        );
1225
1226        // Let's also test a valid descriptor chain, with both header and data on a single
1227        // descriptor.
1228        let v = vec![RawDescriptor::from(SplitDescriptor::new(
1229            0x5_0000,
1230            PKT_HEADER_SIZE as u32 + 0x100,
1231            0,
1232            0,
1233        ))];
1234        let mut chain = queue.build_desc_chain(&v).unwrap();
1235
1236        let packet = VsockPacket::from_tx_virtq_chain(&mem, &mut chain, MAX_PKT_BUF_SIZE).unwrap();
1237        assert_eq!(packet.header, header);
1238        let header_slice = packet.header_slice();
1239        assert_eq!(
1240            header_slice.ptr_guard().as_ptr(),
1241            mem.get_host_address(GuestAddress(0x5_0000)).unwrap()
1242        );
1243        assert_eq!(header_slice.len(), PKT_HEADER_SIZE);
1244        // The `len` field of the header was set to 16.
1245        assert_eq!(packet.len(), LEN);
1246
1247        let data = packet.data_slice().unwrap();
1248        assert_eq!(
1249            data.ptr_guard().as_ptr(),
1250            mem.get_host_address(GuestAddress(0x5_0000 + PKT_HEADER_SIZE as u64))
1251                .unwrap()
1252        );
1253        assert_eq!(data.len(), LEN as usize);
1254    }
1255
1256    #[test]
1257    fn test_header_set_get() {
1258        let mem: GuestMemoryMmap =
1259            GuestMemoryMmap::from_ranges(&[(GuestAddress(0), 0x30_0000)]).unwrap();
1260        // The `build_desc_chain` function will populate the `NEXT` related flags and field.
1261        let v = vec![
1262            RawDescriptor::from(SplitDescriptor::new(
1263                0x10_0000,
1264                0x100,
1265                VRING_DESC_F_WRITE as u16,
1266                0,
1267            )),
1268            RawDescriptor::from(SplitDescriptor::new(
1269                0x20_0000,
1270                0x100,
1271                VRING_DESC_F_WRITE as u16,
1272                0,
1273            )),
1274        ];
1275        let queue = MockSplitQueue::new(&mem, 16);
1276        let mut chain = queue.build_desc_chain(&v).unwrap();
1277
1278        let mut packet =
1279            VsockPacket::from_rx_virtq_chain(&mem, &mut chain, MAX_PKT_BUF_SIZE).unwrap();
1280        packet
1281            .set_src_cid(SRC_CID)
1282            .set_dst_cid(DST_CID)
1283            .set_src_port(SRC_PORT)
1284            .set_dst_port(DST_PORT)
1285            .set_len(LEN)
1286            .set_type(TYPE)
1287            .set_op(OP)
1288            .set_flags(FLAGS)
1289            .set_flag(FLAG)
1290            .set_buf_alloc(BUF_ALLOC)
1291            .set_fwd_cnt(FWD_CNT);
1292
1293        assert_eq!(packet.flags(), FLAGS | FLAG);
1294        assert_eq!(packet.op(), OP);
1295        assert_eq!(packet.type_(), TYPE);
1296        assert_eq!(packet.dst_cid(), DST_CID);
1297        assert_eq!(packet.dst_port(), DST_PORT);
1298        assert_eq!(packet.src_cid(), SRC_CID);
1299        assert_eq!(packet.src_port(), SRC_PORT);
1300        assert_eq!(packet.fwd_cnt(), FWD_CNT);
1301        assert_eq!(packet.len(), LEN);
1302        assert_eq!(packet.buf_alloc(), BUF_ALLOC);
1303
1304        let expected_header = PacketHeader {
1305            src_cid: SRC_CID.into(),
1306            dst_cid: DST_CID.into(),
1307            src_port: SRC_PORT.into(),
1308            dst_port: DST_PORT.into(),
1309            len: LEN.into(),
1310            type_: TYPE.into(),
1311            op: OP.into(),
1312            flags: (FLAGS | FLAG).into(),
1313            buf_alloc: BUF_ALLOC.into(),
1314            fwd_cnt: FWD_CNT.into(),
1315        };
1316
1317        assert_eq!(packet.header, expected_header);
1318        assert_eq!(
1319            u64::from_le(
1320                packet
1321                    .header_slice()
1322                    .read_obj::<u64>(SRC_CID_OFFSET)
1323                    .unwrap()
1324            ),
1325            SRC_CID
1326        );
1327        assert_eq!(
1328            u64::from_le(
1329                packet
1330                    .header_slice()
1331                    .read_obj::<u64>(DST_CID_OFFSET)
1332                    .unwrap()
1333            ),
1334            DST_CID
1335        );
1336        assert_eq!(
1337            u32::from_le(
1338                packet
1339                    .header_slice()
1340                    .read_obj::<u32>(SRC_PORT_OFFSET)
1341                    .unwrap()
1342            ),
1343            SRC_PORT
1344        );
1345        assert_eq!(
1346            u32::from_le(
1347                packet
1348                    .header_slice()
1349                    .read_obj::<u32>(DST_PORT_OFFSET)
1350                    .unwrap()
1351            ),
1352            DST_PORT,
1353        );
1354        assert_eq!(
1355            u32::from_le(packet.header_slice().read_obj::<u32>(LEN_OFFSET).unwrap()),
1356            LEN
1357        );
1358        assert_eq!(
1359            u16::from_le(packet.header_slice().read_obj::<u16>(TYPE_OFFSET).unwrap()),
1360            TYPE
1361        );
1362        assert_eq!(
1363            u16::from_le(packet.header_slice().read_obj::<u16>(OP_OFFSET).unwrap()),
1364            OP
1365        );
1366        assert_eq!(
1367            u32::from_le(packet.header_slice().read_obj::<u32>(FLAGS_OFFSET).unwrap()),
1368            FLAGS | FLAG
1369        );
1370        assert_eq!(
1371            u32::from_le(
1372                packet
1373                    .header_slice()
1374                    .read_obj::<u32>(BUF_ALLOC_OFFSET)
1375                    .unwrap()
1376            ),
1377            BUF_ALLOC
1378        );
1379        assert_eq!(
1380            u32::from_le(
1381                packet
1382                    .header_slice()
1383                    .read_obj::<u32>(FWD_CNT_OFFSET)
1384                    .unwrap()
1385            ),
1386            FWD_CNT
1387        );
1388    }
1389
1390    #[test]
1391    fn test_set_header_from_raw() {
1392        let mem: GuestMemoryMmap =
1393            GuestMemoryMmap::from_ranges(&[(GuestAddress(0), 0x30_0000)]).unwrap();
1394        // The `build_desc_chain` function will populate the `NEXT` related flags and field.
1395        let v = vec![
1396            RawDescriptor::from(SplitDescriptor::new(
1397                0x10_0000,
1398                0x100,
1399                VRING_DESC_F_WRITE as u16,
1400                0,
1401            )),
1402            RawDescriptor::from(SplitDescriptor::new(
1403                0x20_0000,
1404                0x100,
1405                VRING_DESC_F_WRITE as u16,
1406                0,
1407            )),
1408        ];
1409        let queue = MockSplitQueue::new(&mem, 16);
1410        let mut chain = queue.build_desc_chain(&v).unwrap();
1411
1412        let mut packet =
1413            VsockPacket::from_rx_virtq_chain(&mem, &mut chain, MAX_PKT_BUF_SIZE).unwrap();
1414
1415        let header = PacketHeader {
1416            src_cid: SRC_CID.into(),
1417            dst_cid: DST_CID.into(),
1418            src_port: SRC_PORT.into(),
1419            dst_port: DST_PORT.into(),
1420            len: LEN.into(),
1421            type_: TYPE.into(),
1422            op: OP.into(),
1423            flags: (FLAGS | FLAG).into(),
1424            buf_alloc: BUF_ALLOC.into(),
1425            fwd_cnt: FWD_CNT.into(),
1426        };
1427
1428        // SAFETY: created from an existing packet header.
1429        let slice = unsafe {
1430            std::slice::from_raw_parts(
1431                (&header as *const PacketHeader) as *const u8,
1432                std::mem::size_of::<PacketHeader>(),
1433            )
1434        };
1435        assert_eq!(packet.header, PacketHeader::default());
1436        packet.set_header_from_raw(slice).unwrap();
1437        assert_eq!(packet.header, header);
1438        let header_from_slice: PacketHeader = packet.header_slice().read_obj(0).unwrap();
1439        assert_eq!(header_from_slice, header);
1440
1441        let invalid_slice = [0; PKT_HEADER_SIZE - 1];
1442        assert_eq!(
1443            packet.set_header_from_raw(&invalid_slice).unwrap_err(),
1444            Error::InvalidHeaderInputSize(PKT_HEADER_SIZE - 1)
1445        );
1446    }
1447
1448    #[test]
1449    fn test_packet_new() {
1450        let mut pkt_raw = [0u8; PKT_HEADER_SIZE + LEN as usize];
1451        let (hdr_raw, data_raw) = pkt_raw.split_at_mut(PKT_HEADER_SIZE);
1452        // SAFETY: safe because ``hdr_raw` and `data_raw` live for as long as
1453        // the scope of the current test.
1454        let packet = unsafe { VsockPacket::new(hdr_raw, Some(data_raw)).unwrap() };
1455        assert_eq!(
1456            packet.header_slice.ptr_guard().as_ptr(),
1457            hdr_raw.as_mut_ptr(),
1458        );
1459        assert_eq!(packet.header_slice.len(), PKT_HEADER_SIZE);
1460        assert_eq!(packet.header, PacketHeader::default());
1461        assert_eq!(
1462            packet.data_slice.unwrap().ptr_guard().as_ptr(),
1463            data_raw.as_mut_ptr(),
1464        );
1465        assert_eq!(packet.data_slice.unwrap().len(), LEN as usize);
1466
1467        // SAFETY: Safe because ``hdr_raw` and `data_raw` live as long as the
1468        // scope of the current test.
1469        let packet = unsafe { VsockPacket::new(hdr_raw, None).unwrap() };
1470        assert_eq!(
1471            packet.header_slice.ptr_guard().as_ptr(),
1472            hdr_raw.as_mut_ptr(),
1473        );
1474        assert_eq!(packet.header, PacketHeader::default());
1475        assert!(packet.data_slice.is_none());
1476
1477        let mut hdr_raw = [0u8; PKT_HEADER_SIZE - 1];
1478        assert_eq!(
1479            // SAFETY: Safe because ``hdr_raw` lives for as long as the scope of the current test.
1480            unsafe { VsockPacket::new(&mut hdr_raw, None).unwrap_err() },
1481            Error::InvalidHeaderInputSize(PKT_HEADER_SIZE - 1)
1482        );
1483    }
1484
1485    #[test]
1486    #[should_panic]
1487    fn test_set_header_field_with_invalid_offset() {
1488        const INVALID_OFFSET: usize = 50;
1489
1490        let mem: GuestMemoryMmap =
1491            GuestMemoryMmap::from_ranges(&[(GuestAddress(0), 0x30_0000)]).unwrap();
1492        // The `build_desc_chain` function will populate the `NEXT` related flags and field.
1493        let v = vec![
1494            RawDescriptor::from(SplitDescriptor::new(
1495                0x10_0000,
1496                0x100,
1497                VRING_DESC_F_WRITE as u16,
1498                0,
1499            )),
1500            RawDescriptor::from(SplitDescriptor::new(
1501                0x20_0000,
1502                0x100,
1503                VRING_DESC_F_WRITE as u16,
1504                0,
1505            )),
1506        ];
1507        let queue = MockSplitQueue::new(&mem, 16);
1508        let mut chain = queue.build_desc_chain(&v).unwrap();
1509
1510        let mut packet =
1511            VsockPacket::from_rx_virtq_chain(&mem, &mut chain, MAX_PKT_BUF_SIZE).unwrap();
1512        // Set the `src_cid` of the header, but use an invalid offset for that.
1513        set_header_field!(packet, src_cid, INVALID_OFFSET, SRC_CID);
1514    }
1515}