Skip to main content

s2n_quic_platform/message/cmsg/
storage.rs

1// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use super::{encode, size_of_cmsg};
5use core::{
6    mem::{align_of, size_of},
7    ops::{Deref, DerefMut},
8};
9use libc::cmsghdr;
10
/// Fixed-capacity byte buffer used as backing memory for encoded control messages.
///
/// The buffer is over-aligned so that a `cmsghdr` can be placed at its first byte.
#[derive(Clone, Debug)]
#[repr(align(8))] // must satisfy the alignment of `cmsghdr`
pub struct Storage<const L: usize>([u8; L]);
14
15impl<const L: usize> Storage<L> {
16    #[inline]
17    pub fn encoder(&mut self) -> Encoder<'_, L> {
18        Encoder {
19            storage: self,
20            cursor: 0,
21        }
22    }
23
24    #[inline]
25    pub fn iter(&self) -> super::decode::Iter<'_> {
26        super::decode::Iter::new(self)
27    }
28}
29
30impl<const L: usize> Default for Storage<L> {
31    #[inline]
32    fn default() -> Self {
33        Self([0; L])
34    }
35}
36
37impl<const L: usize> Deref for Storage<L> {
38    type Target = [u8];
39
40    #[inline]
41    fn deref(&self) -> &[u8] {
42        &self.0
43    }
44}
45
46impl<const L: usize> DerefMut for Storage<L> {
47    #[inline]
48    fn deref_mut(&mut self) -> &mut [u8] {
49        &mut self.0
50    }
51}
52
53pub struct Encoder<'a, const L: usize> {
54    storage: &'a mut Storage<L>,
55    cursor: usize,
56}
57
58impl<'a, const L: usize> Encoder<'a, L> {
59    #[inline]
60    pub fn new(storage: &'a mut Storage<L>) -> Self {
61        Self { storage, cursor: 0 }
62    }
63
64    #[inline]
65    pub fn len(&self) -> usize {
66        self.cursor
67    }
68
69    #[inline]
70    pub fn is_empty(&self) -> bool {
71        self.cursor == 0
72    }
73
74    #[inline]
75    pub fn seek(&mut self, len: usize) {
76        self.cursor += len;
77        debug_assert!(self.cursor <= L);
78    }
79
80    #[inline]
81    pub fn iter(&self) -> super::decode::Iter<'_> {
82        unsafe {
83            // SAFETY: bytes are aligned with Storage type
84            super::decode::Iter::from_bytes(self)
85        }
86    }
87}
88
89impl<const L: usize> Deref for Encoder<'_, L> {
90    type Target = [u8];
91
92    #[inline]
93    fn deref(&self) -> &[u8] {
94        &self.storage[..self.cursor]
95    }
96}
97
98impl<const L: usize> DerefMut for Encoder<'_, L> {
99    #[inline]
100    fn deref_mut(&mut self) -> &mut [u8] {
101        &mut self.storage[..self.cursor]
102    }
103}
104
105impl<const L: usize> super::Encoder for Encoder<'_, L> {
106    #[inline]
107    fn encode_cmsg<T: Copy>(
108        &mut self,
109        level: libc::c_int,
110        ty: libc::c_int,
111        value: T,
112    ) -> Result<usize, encode::Error> {
113        unsafe {
114            debug_assert!(
115                align_of::<T>() <= align_of::<cmsghdr>(),
116                "alignment of T should be less than or equal to cmsghdr"
117            );
118
119            // CMSG_SPACE() returns the number of bytes an ancillary element
120            // with payload of the passed data length occupies.
121            let element_len = size_of_cmsg::<T>();
122            debug_assert_ne!(element_len, 0);
123            debug_assert_eq!(libc::CMSG_SPACE(size_of::<T>() as _) as usize, element_len);
124
125            let new_cursor = self.cursor.checked_add(element_len).ok_or(encode::Error)?;
126
127            self.storage
128                .len()
129                .checked_sub(new_cursor)
130                .ok_or(encode::Error)?;
131
132            let cmsg_ptr = {
133                // Safety: the msg_control buffer should always be allocated to MAX_LEN
134                let msg_controllen = self.cursor;
135                let msg_control = self.storage.as_mut_ptr().add(msg_controllen as _);
136                msg_control as *mut cmsghdr
137            };
138
139            {
140                let cmsg = &mut *cmsg_ptr;
141
142                // interpret the start of cmsg as a cmsghdr
143                // Safety: the cmsg slice should already be zero-initialized and aligned
144
145                // Indicate the type of cmsg
146                cmsg.cmsg_level = level;
147                cmsg.cmsg_type = ty;
148
149                // CMSG_LEN() returns the value to store in the cmsg_len member
150                // of the cmsghdr structure, taking into account any necessary
151                // alignment.  It takes the data length as an argument.
152                cmsg.cmsg_len = libc::CMSG_LEN(size_of::<T>() as _) as _;
153            }
154
155            {
156                // Write the actual value in the data space of the cmsg
157                // Safety: we asserted we had enough space in the cmsg buffer above
158                // CMSG_DATA() returns a pointer to the data portion of a
159                // cmsghdr. The pointer returned cannot be assumed to be
160                // suitably aligned for accessing arbitrary payload data types.
161                // Applications should not cast it to a pointer type matching the
162                // payload, but should instead use memcpy(3) to copy data to or
163                // from a suitably declared object.
164                let data_ptr = cmsg_ptr.add(1);
165
166                debug_assert_eq!(data_ptr as *mut u8, libc::CMSG_DATA(cmsg_ptr) as *mut u8);
167
168                core::ptr::copy_nonoverlapping(
169                    &value as *const T as *const u8,
170                    data_ptr as *mut u8,
171                    size_of::<T>(),
172                );
173            }
174
175            // add the values as a usize to make sure we work cross-platform
176            self.cursor = new_cursor;
177            debug_assert!(
178                self.cursor <= self.storage.len(),
179                "msg should not exceed max allocated"
180            );
181
182            Ok(self.cursor)
183        }
184    }
185}