s2n_quic_platform/
message.rs

// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
3
4use core::{alloc::Layout, ptr::NonNull};
5use s2n_quic_core::{inet::datagram, io::tx, path};
6
7#[cfg(s2n_quic_platform_cmsg)]
8pub mod cmsg;
9#[cfg(s2n_quic_platform_socket_mmsg)]
10pub mod mmsg;
11#[cfg(s2n_quic_platform_socket_msg)]
12pub mod msg;
13pub mod simple;
14
15pub mod default {
16    cfg_if::cfg_if! {
17        if #[cfg(s2n_quic_platform_socket_mmsg)] {
18            pub use super::mmsg::*;
19        } else if #[cfg(s2n_quic_platform_socket_msg)] {
20            pub use super::msg::*;
21        } else {
22            pub use super::simple::*;
23        }
24    }
25}
26
/// Tracks allocations of message ring buffer state
///
/// The memory is zero-initialized when created and released with the same
/// [`Layout`] when the `Storage` is dropped.
pub struct Storage {
    /// Start of the zero-initialized allocation
    ptr: NonNull<u8>,
    /// Layout the allocation was made with; it is required again to deallocate
    layout: Layout,
}

/// Safety: the ring buffer controls access to the underlying storage
unsafe impl Send for Storage {}
/// Safety: the ring buffer controls access to the underlying storage
unsafe impl Sync for Storage {}

impl Storage {
    /// Allocates zero-initialized storage described by `layout`
    ///
    /// # Panics
    ///
    /// Panics if `layout` has a size of zero, or if the allocator returns a null
    /// pointer.
    #[inline]
    pub fn new(layout: Layout) -> Self {
        // Calling the global allocator with a zero-sized layout is undefined
        // behavior, so reject it before reaching the `unsafe` call.
        assert_ne!(
            layout.size(),
            0,
            "message storage must have a non-zero size"
        );

        // SAFETY: `layout` has a non-zero size, as asserted above
        let ptr = unsafe { alloc::alloc::alloc_zeroed(layout) };
        let ptr = NonNull::new(ptr).expect("could not allocate message storage");
        Self { ptr, layout }
    }

    /// Returns a raw pointer to the start of the allocation
    #[inline]
    pub fn as_ptr(&self) -> *mut u8 {
        self.ptr.as_ptr()
    }

    /// Asserts that the pointer is in bounds of the allocation
    ///
    /// The check only runs when debug assertions are enabled. Only the address of
    /// `ptr` is compared: it may be anywhere from the start of the allocation up to
    /// and including one byte past its end. The size of `T` is not taken into account.
    #[inline]
    pub fn check_bounds<T: Sized>(&self, ptr: *mut T) {
        let start = self.as_ptr();
        // SAFETY: `start` was allocated with `self.layout`, so offsetting it by the
        // layout's size yields the one-past-the-end pointer of the same allocation
        let end = unsafe { start.add(self.layout.size()) };
        let addr = ptr as *mut u8;
        debug_assert!((start..=end).contains(&addr));
    }
}

impl Drop for Storage {
    fn drop(&mut self) {
        // SAFETY: `self.ptr` was returned by `alloc_zeroed` for `self.layout` in
        // `Storage::new` and is deallocated only here
        unsafe { alloc::alloc::dealloc(self.as_ptr(), self.layout) }
    }
}
75
76/// An abstract message that can be sent and received on a network
77pub trait Message: 'static + Copy {
78    type Handle: path::Handle;
79
80    const SUPPORTS_GSO: bool;
81    const SUPPORTS_ECN: bool;
82    const SUPPORTS_FLOW_LABELS: bool;
83
84    /// Allocates `entries` messages, each with `payload_len` bytes
85    fn alloc(entries: u32, payload_len: u32, offset: usize) -> Storage;
86
87    /// Returns the length of the payload
88    fn payload_len(&self) -> usize;
89
90    /// Sets the payload length for the message
91    ///
92    /// # Safety
93    /// This method should only set the payload less than or
94    /// equal to its initially allocated size.
95    unsafe fn set_payload_len(&mut self, payload_len: usize);
96
97    /// Validates that the `source` message can be replicated to `dest`.
98    ///
99    /// # Panics
100    ///
101    /// This panics when the messages cannot be replicated
102    fn validate_replication(source: &Self, dest: &Self);
103
104    /// Returns a mutable pointer for the message payload
105    fn payload_ptr_mut(&mut self) -> *mut u8;
106
107    /// Returns a mutable slice for the message payload
108    #[inline]
109    fn payload_mut(&mut self) -> &mut [u8] {
110        unsafe { core::slice::from_raw_parts_mut(self.payload_ptr_mut(), self.payload_len()) }
111    }
112
113    /// Sets the segment size for the message payload
114    fn set_segment_size(&mut self, _size: usize) {
115        panic!("cannot use GSO on the current platform");
116    }
117
118    /// Resets the message for future use
119    ///
120    /// # Safety
121    /// This method should only set the MTU to the original value
122    unsafe fn reset(&mut self, mtu: usize);
123
124    /// Reads the message as an RX packet
125    fn rx_read(
126        &mut self,
127        local_address: &path::LocalAddress,
128    ) -> Option<RxMessage<'_, Self::Handle>>;
129
130    /// Writes the message into the TX packet
131    fn tx_write<M: tx::Message<Handle = Self::Handle>>(
132        &mut self,
133        message: M,
134    ) -> Result<usize, tx::Error>;
135}
136
137pub struct RxMessage<'a, Handle: Copy> {
138    /// The received header for the message
139    pub header: datagram::Header<Handle>,
140    /// The number of segments inside the message
141    pub segment_size: usize,
142    /// The full payload of the message
143    pub payload: &'a mut [u8],
144}
145
146impl<Handle: Copy> RxMessage<'_, Handle> {
147    #[inline]
148    pub fn for_each<F: FnMut(datagram::Header<Handle>, &mut [u8])>(self, mut on_packet: F) {
149        // `chunks_mut` doesn't know what to do with zero-sized segments so return early
150        if self.segment_size == 0 {
151            return;
152        }
153
154        for segment in self.payload.chunks_mut(self.segment_size) {
155            on_packet(self.header, segment);
156        }
157    }
158}
159
#[cfg(test)]
mod tests {
    use super::*;
    use bolero::check;

    /// Checks that `RxMessage::for_each` yields segments that carry the original
    /// header and exactly cover the payload, with only the last one allowed to be short
    #[test]
    #[cfg_attr(kani, kani::proof, kani::unwind(17), kani::solver(minisat))]
    fn rx_message_test() {
        let path = bolero::produce::<path::RemoteAddress>();
        let ecn = bolero::produce();
        let segment_size = bolero::produce();
        // a 16-byte payload yields at most 16 segments, keeping the segment loop
        // within the `kani::unwind(17)` bound
        let max_payload_len = if cfg!(kani) { 16 } else { u16::MAX as usize };
        let payload_len = 0..=max_payload_len;

        check!()
            .with_generator((path, ecn, segment_size, payload_len))
            .cloned()
            .for_each(|(path, ecn, segment_size, payload_len)| {
                let mut payload = vec![0u8; payload_len];
                let rx_message = RxMessage {
                    header: datagram::Header { path, ecn },
                    segment_size,
                    payload: &mut payload,
                };

                let mut covered = 0;
                rx_message.for_each(|header, segment| {
                    assert_eq!(header.path, path);
                    assert_eq!(header.ecn, ecn);

                    let remaining = payload_len - covered;
                    assert!(!segment.is_empty());
                    assert!(segment.len() <= segment_size);
                    assert!(segment.len() <= remaining);
                    // only the final segment may be shorter than `segment_size`
                    assert!(segment.len() == segment_size || segment.len() == remaining);
                    covered += segment.len();
                });

                // a zero segment size yields nothing; otherwise the whole payload is covered
                let expected = if segment_size == 0 { 0 } else { payload_len };
                assert_eq!(covered, expected);
            })
    }
}