s2n_quic_platform/message/msg.rs

1// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use crate::{
5    features,
6    message::{cmsg, cmsg::Encoder, Message as MessageTrait},
7};
8use core::{
9    alloc::Layout,
10    mem::{size_of, size_of_val},
11};
12use libc::{iovec, msghdr, sockaddr_in, sockaddr_in6, AF_INET, AF_INET6};
13use s2n_quic_core::{
14    inet::{
15        datagram, IpV4Address, IpV6Address, SocketAddress, SocketAddressV4, SocketAddressV6,
16        Unspecified,
17    },
18    io::tx,
19    path::{self, Handle as _},
20};
21
22mod ext;
23mod handle;
24#[cfg(test)]
25mod tests;
26
27pub use ext::Ext;
28pub use handle::Handle;
29pub use libc::msghdr as Message;
30
31impl MessageTrait for msghdr {
32    type Handle = Handle;
33
34    const SUPPORTS_GSO: bool = features::gso::IS_SUPPORTED;
35    const SUPPORTS_ECN: bool = features::tos::IS_SUPPORTED;
36    const SUPPORTS_FLOW_LABELS: bool = true;
37
38    #[inline]
39    fn alloc(entries: u32, payload_len: u32, offset: usize) -> super::Storage {
40        unsafe { alloc(entries, payload_len, offset, |msg| msg) }
41    }
42
43    #[inline]
44    fn payload_len(&self) -> usize {
45        debug_assert!(!self.msg_iov.is_null());
46        let len = unsafe { (*self.msg_iov).iov_len as _ };
47        debug_assert!(len <= u16::MAX as usize);
48        len
49    }
50
51    #[inline]
52    unsafe fn set_payload_len(&mut self, payload_len: usize) {
53        debug_assert!(payload_len <= u16::MAX as usize);
54        debug_assert!(!self.msg_iov.is_null());
55        (*self.msg_iov).iov_len = payload_len;
56    }
57
58    #[inline]
59    fn set_segment_size(&mut self, size: usize) {
60        debug_assert!(size <= u16::MAX as usize);
61        self.cmsg_encoder().encode_gso(size as _).unwrap();
62    }
63
64    #[inline]
65    unsafe fn reset(&mut self, mtu: usize) {
66        // reset the payload
67        self.set_payload_len(mtu);
68
69        // reset the address
70        self.set_remote_address(&SocketAddress::IpV6(Default::default()));
71
72        #[inline]
73        unsafe fn check_cmsg(msghdr: &msghdr) {
74            if cfg!(debug_assertions) {
75                let ptr = msghdr.msg_control as *mut u8;
76                let cmsg = core::slice::from_raw_parts_mut(ptr, cmsg::MAX_LEN);
77                // make sure nothing was written to the control message if it was set to 0
78                #[cfg(not(kani))]
79                {
80                    assert!(cmsg.iter().all(|v| *v == 0), "msg_control was not cleared");
81                }
82
83                #[cfg(kani)]
84                {
85                    let index: usize = kani::any();
86                    kani::assume(index < cmsg.len());
87                    assert_eq!(cmsg[index], 0);
88                }
89            }
90        }
91
92        // make sure we didn't get any data written without setting the len
93        if self.msg_controllen == 0 {
94            check_cmsg(self);
95        }
96
97        // reset the control messages if it isn't set to the default value
98
99        // some platforms encode lengths as `u32` so we cast everything to be safe
100        #[allow(clippy::unnecessary_cast)]
101        let msg_controllen = self.msg_controllen as usize;
102
103        if msg_controllen != cmsg::MAX_LEN {
104            core::slice::from_raw_parts_mut(self.msg_control as *mut u8, msg_controllen).fill(0);
105        }
106
107        check_cmsg(self);
108
109        self.msg_controllen = cmsg::MAX_LEN as _;
110    }
111
112    #[inline]
113    fn payload_ptr_mut(&mut self) -> *mut u8 {
114        unsafe {
115            let iovec = &mut *self.msg_iov;
116            iovec.iov_base as *mut _
117        }
118    }
119
120    #[inline]
121    fn validate_replication(source: &Self, dest: &Self) {
122        assert_eq!(source.msg_name, dest.msg_name);
123        assert_eq!(source.msg_iov, dest.msg_iov);
124        assert_eq!(source.msg_control, dest.msg_control);
125    }
126
127    #[inline]
128    fn rx_read(
129        &mut self,
130        local_address: &path::LocalAddress,
131    ) -> Option<super::RxMessage<'_, Self::Handle>> {
132        if cfg!(test) {
133            assert_eq!(
134                self.msg_flags & libc::MSG_CTRUNC,
135                0,
136                "control message buffers should always have enough capacity"
137            );
138        }
139
140        let (mut header, cmsg) = self.header()?;
141
142        // only copy the port if we are told the IP address
143        if !header.path.local_address.ip().is_unspecified() {
144            header.path.local_address.set_port(local_address.port());
145        } else {
146            header.path.local_address = *local_address;
147        }
148
149        let payload = self.payload_mut();
150
151        let segment_size = if cmsg.segment_size == 0 {
152            payload.len()
153        } else {
154            cmsg.segment_size as _
155        };
156
157        let message = crate::message::RxMessage {
158            header,
159            segment_size,
160            payload,
161        };
162
163        Some(message)
164    }
165
166    #[inline]
167    fn tx_write<M: tx::Message<Handle = Self::Handle>>(
168        &mut self,
169        mut message: M,
170    ) -> Result<usize, tx::Error> {
171        let payload = self.payload_mut();
172
173        let max_len = payload.len();
174        let len = message.write_payload(tx::PayloadBuffer::new(payload), 0)?;
175
176        debug_assert_ne!(len, 0);
177        debug_assert!(len <= max_len);
178        let len = len.min(max_len);
179
180        debug_assert_eq!(
181            cmsg::MAX_LEN,
182            self.msg_controllen as _,
183            "message should be reset before writing"
184        );
185        self.msg_controllen = 0;
186
187        unsafe {
188            self.set_payload_len(len);
189        }
190
191        let handle = *message.path_handle();
192        handle.update_msg_hdr(self);
193        self.cmsg_encoder()
194            .encode_ecn(message.ecn(), &handle.remote_address.0)
195            .unwrap();
196
197        Ok(len)
198    }
199}
200
201/// Allocates a region of memory holding `entries` number of `T` messages, each with `payload_len`
202/// payloads.
203///
204/// # Safety
205///
206/// * `T` can be initialized with zero bytes and still be valid
207#[inline]
208pub(super) unsafe fn alloc<T: Copy + Sized, F: Fn(&mut T) -> &mut msghdr>(
209    entries: u32,
210    payload_len: u32,
211    offset: usize,
212    on_entry: F,
213) -> super::Storage {
214    // calculate the layout of the storage for the given configuration
215    let (layout, entry_offset, header_offset, payload_offset) =
216        layout::<T>(entries, payload_len, offset);
217
218    // allocate a single contiguous block of memory
219    let storage = super::Storage::new(layout);
220
221    {
222        let ptr = storage.as_ptr();
223
224        // calculate each of the pointers we need to set up a message
225        let mut entry_ptr = ptr.add(entry_offset) as *mut T;
226        let mut header_ptr = ptr.add(header_offset) as *mut Header;
227        let mut payload_ptr = ptr.add(payload_offset);
228
229        for _ in 0..entries {
230            // for each message update all of the pointers to the correct locations
231
232            let entry = on_entry(&mut *entry_ptr);
233            (*header_ptr).update(entry, payload_ptr, payload_len);
234
235            // increment the pointers for the next iteration
236            entry_ptr = entry_ptr.add(1);
237            header_ptr = header_ptr.add(1);
238            payload_ptr = payload_ptr.add(payload_len as _);
239
240            // make sure the pointers are within the bounds of the allocation
241            storage.check_bounds(entry_ptr);
242            storage.check_bounds(header_ptr);
243            storage.check_bounds(payload_ptr);
244        }
245
246        // replicate the primary messages into the secondary region
247        let primary = ptr.add(entry_offset) as *mut T;
248        let secondary = primary.add(entries as _);
249        storage.check_bounds(secondary.add(entries as _));
250        core::ptr::copy_nonoverlapping(primary, secondary, entries as _);
251    }
252
253    storage
254}
255
256/// Computes the following layout
257///
258/// ```ignore
259/// struct Storage {
260///    cursor: Cursor,
261///    headers: [Header; entries],
262///    payloads: [[u8; payload_len]; entries],
263///    entries: [T; entries * 2],
264/// }
265/// ```
266fn layout<T: Copy + Sized>(
267    entries: u32,
268    payload_len: u32,
269    offset: usize,
270) -> (Layout, usize, usize, usize) {
271    let cursor = Layout::array::<u8>(offset).unwrap();
272    let headers = Layout::array::<Header>(entries as _).unwrap();
273    let payloads = Layout::array::<u8>(entries as usize * payload_len as usize).unwrap();
274    // double the number of entries we allocate to support the primary/secondary regions
275    let entries = Layout::array::<T>((entries * 2) as usize).unwrap();
276    let (layout, entry_offset) = cursor.extend(entries).unwrap();
277    let (layout, header_offset) = layout.extend(headers).unwrap();
278    let (layout, payload_offset) = layout.extend(payloads).unwrap();
279    (layout, entry_offset, header_offset, payload_offset)
280}
281
282/// A structure for holding data pointed to in the [`libc::msghdr`] struct.
283struct Header {
284    pub iovec: iovec,
285    pub msg_name: sockaddr_in6,
286    pub cmsg: cmsg::Storage<{ cmsg::MAX_LEN }>,
287}
288
289impl Header {
290    /// sets all of the pointers of the provided `entry` to the correct locations
291    unsafe fn update(&mut self, entry: &mut msghdr, payload: *mut u8, payload_len: u32) {
292        let iovec = &mut self.iovec;
293
294        iovec.iov_base = payload as *mut _;
295        iovec.iov_len = payload_len as _;
296
297        let entry = &mut *entry;
298
299        entry.msg_name = &mut self.msg_name as *mut _ as *mut _;
300        entry.msg_namelen = size_of_val(&self.msg_name) as _;
301        entry.msg_iov = &mut self.iovec as *mut _;
302        entry.msg_iovlen = 1;
303        entry.msg_controllen = self.cmsg.len() as _;
304        entry.msg_control = self.cmsg.as_mut_ptr() as *mut _;
305
306        // make sure that the control pointer is well-aligned
307        debug_assert_eq!(
308            entry
309                .msg_control
310                .align_offset(core::mem::align_of::<cmsg::Storage<{ cmsg::MAX_LEN }>>()),
311            0
312        );
313    }
314}