s2n_quic_platform/message/cmsg/
decode.rs

1// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use super::Storage;
5use crate::features;
6use core::mem::{align_of, size_of};
7use libc::cmsghdr;
8use s2n_quic_core::{ensure, inet::AncillaryData};
9
10/// Decodes a value of type `T` from the given `cmsghdr`
11/// # Safety
12///
13/// `cmsghdr` must refer to a cmsg containing a payload of type `T`
14#[inline]
15pub unsafe fn value_from_bytes<T: Copy>(value: &[u8]) -> Option<T> {
16    use core::mem;
17
18    ensure!(value.len() == size_of::<T>(), None);
19
20    debug_assert!(mem::align_of::<T>() <= mem::align_of::<cmsghdr>());
21
22    let mut v = mem::zeroed::<T>();
23
24    core::ptr::copy_nonoverlapping(value.as_ptr(), &mut v as *mut T as *mut u8, size_of::<T>());
25
26    Some(v)
27}
28
29/// Decodes all recognized control messages in the given `iter` into `AncillaryData`
30#[inline]
31pub fn collect(iter: Iter) -> AncillaryData {
32    let mut data = AncillaryData::default();
33
34    for (cmsg, value) in iter {
35        unsafe {
36            // SAFETY: `Iter` ensures values are aligned
37            collect_item(&mut data, cmsg, value);
38        }
39    }
40
41    data
42}
43
44#[inline]
45unsafe fn collect_item(data: &mut AncillaryData, cmsg: &cmsghdr, value: &[u8]) {
46    macro_rules! decode_error {
47        ($error:expr) => {
48            #[cfg(all(test, feature = "tracing", not(any(kani, miri, fuzz))))]
49            tracing::debug!(
50                error = $error,
51                level = cmsg.cmsg_level,
52                r#type = cmsg.cmsg_type,
53                value = ?value,
54            );
55        }
56    }
57
58    match (cmsg.cmsg_level, cmsg.cmsg_type) {
59        (level, ty) if features::tos::is_match(level, ty) => {
60            if let Some(ecn) = features::tos::decode(value) {
61                data.ecn = ecn;
62            } else {
63                decode_error!("invalid TOS value");
64            }
65        }
66        (level, ty) if features::pktinfo_v4::is_match(level, ty) => {
67            if let Some((local_address, local_interface)) = features::pktinfo_v4::decode(value) {
68                // The port should be specified by a different layer that has that information
69                let port = 0;
70                let local_address = s2n_quic_core::inet::SocketAddressV4::new(local_address, port);
71                data.local_address = local_address.into();
72                data.local_interface = Some(local_interface);
73            } else {
74                decode_error!("invalid pktinfo_v4 value");
75            }
76        }
77        (level, ty) if features::pktinfo_v6::is_match(level, ty) => {
78            if let Some((local_address, local_interface)) = features::pktinfo_v6::decode(value) {
79                // The port should be specified by a different layer that has that information
80                let port = 0;
81                let local_address = s2n_quic_core::inet::SocketAddressV6::new(local_address, port);
82                data.local_address = local_address.into();
83                data.local_interface = Some(local_interface);
84            } else {
85                decode_error!("invalid pktinfo_v6 value");
86            }
87        }
88        (level, ty) if features::gso::is_match(level, ty) => {
89            // ignore GSO settings when reading
90        }
91        (level, ty) if features::gro::is_match(level, ty) => {
92            if let Some(segment_size) = value_from_bytes::<features::gro::Cmsg>(value) {
93                data.segment_size = segment_size as _;
94            } else {
95                decode_error!("invalid gro value");
96            }
97        }
98        _ => {
99            decode_error!("unexpected cmsghdr");
100        }
101    }
102}
103
/// Iterator over the control messages stored in a byte buffer.
///
/// Each step yields a `cmsghdr` reference together with the payload bytes that follow it.
pub struct Iter<'a> {
    /// Ties the raw `cursor` to the lifetime of the buffer it was created from
    contents: core::marker::PhantomData<&'a [u8]>,
    /// Start of the next control message header to decode
    cursor: *const u8,
    /// Number of bytes remaining in the buffer, counted from `cursor`
    len: usize,
}
109
110impl<'a> Iter<'a> {
111    /// Creates a new cmsg::Iter used for iterating over control message headers in the given
112    /// [`Storage`].
113    #[inline]
114    pub fn new<const L: usize>(contents: &'a Storage<L>) -> Iter<'a> {
115        let cursor = contents.as_ptr();
116        let len = contents.len();
117
118        Self {
119            cursor,
120            len,
121            contents: Default::default(),
122        }
123    }
124
125    /// Creates a new cmsg::Iter used for iterating over control message headers in the given slice
126    /// of bytes.
127    ///
128    /// # Safety
129    ///
130    /// * `contents` must be aligned to cmsghdr
131    #[inline]
132    pub unsafe fn from_bytes(contents: &'a [u8]) -> Self {
133        let cursor = contents.as_ptr();
134        let len = contents.len();
135
136        debug_assert_eq!(
137            cursor.align_offset(align_of::<cmsghdr>()),
138            0,
139            "contents must be aligned to cmsghdr"
140        );
141
142        Self {
143            cursor,
144            len,
145            contents: Default::default(),
146        }
147    }
148
149    /// Creates a new cmsg::Iter used for iterating over control message headers in the given
150    /// msghdr.
151    ///
152    /// # Safety
153    ///
154    /// * `contents` must be aligned to cmsghdr
155    /// * `msghdr` must point to a valid control buffer
156    #[inline]
157    pub unsafe fn from_msghdr(msghdr: &'a libc::msghdr) -> Self {
158        let cursor = msghdr.msg_control as *const u8;
159        let len = msghdr.msg_controllen as usize;
160
161        debug_assert_eq!(
162            cursor.align_offset(align_of::<cmsghdr>()),
163            0,
164            "contents must be aligned to cmsghdr"
165        );
166
167        Self {
168            cursor,
169            len,
170            contents: Default::default(),
171        }
172    }
173
174    #[inline]
175    pub fn collect(self) -> AncillaryData {
176        collect(self)
177    }
178}
179
180impl<'a> Iterator for Iter<'a> {
181    type Item = (&'a cmsghdr, &'a [u8]);
182
183    #[inline]
184    fn next(&mut self) -> Option<Self::Item> {
185        unsafe {
186            let cursor = self.cursor;
187
188            // make sure we can decode a cmsghdr
189            self.len.checked_sub(size_of::<cmsghdr>())?;
190            let cmsg = &*(cursor as *const cmsghdr);
191            let data_ptr = cursor.add(size_of::<cmsghdr>());
192
193            let cmsg_len = cmsg.cmsg_len as usize;
194
195            // make sure we have capacity to decode the provided cmsg_len
196            self.len.checked_sub(cmsg_len)?;
197
198            // the cmsg_len includes the header itself so it needs to be subtracted off
199            let data_len = cmsg_len.checked_sub(size_of::<cmsghdr>())?;
200            // construct a slice with the provided data len
201            let data = core::slice::from_raw_parts(data_ptr, data_len);
202
203            // empty messages are invalid
204            if data.is_empty() {
205                return None;
206            }
207
208            // calculate the next message and update the cursor/len
209            {
210                let space = libc::CMSG_SPACE(data_len as _) as usize;
211                debug_assert!(
212                    space >= data_len,
213                    "space ({space}) should be at least of size len ({data_len})"
214                );
215                self.len = self.len.saturating_sub(space);
216                self.cursor = cursor.add(space);
217            }
218
219            Some((cmsg, data))
220        }
221    }
222}