s2n_quic_platform/message/
msg.rs1use crate::{
5 features,
6 message::{cmsg, cmsg::Encoder, Message as MessageTrait},
7};
8use core::{
9 alloc::Layout,
10 mem::{size_of, size_of_val},
11};
12use libc::{iovec, msghdr, sockaddr_in, sockaddr_in6, AF_INET, AF_INET6};
13use s2n_quic_core::{
14 inet::{
15 datagram, IpV4Address, IpV6Address, SocketAddress, SocketAddressV4, SocketAddressV6,
16 Unspecified,
17 },
18 io::tx,
19 path::{self, Handle as _},
20};
21
22mod ext;
23mod handle;
24#[cfg(test)]
25mod tests;
26
27pub use ext::Ext;
28pub use handle::Handle;
29pub use libc::msghdr as Message;
30
31impl MessageTrait for msghdr {
32 type Handle = Handle;
33
34 const SUPPORTS_GSO: bool = features::gso::IS_SUPPORTED;
35 const SUPPORTS_ECN: bool = features::tos::IS_SUPPORTED;
36 const SUPPORTS_FLOW_LABELS: bool = true;
37
38 #[inline]
39 fn alloc(entries: u32, payload_len: u32, offset: usize) -> super::Storage {
40 unsafe { alloc(entries, payload_len, offset, |msg| msg) }
41 }
42
43 #[inline]
44 fn payload_len(&self) -> usize {
45 debug_assert!(!self.msg_iov.is_null());
46 let len = unsafe { (*self.msg_iov).iov_len as _ };
47 debug_assert!(len <= u16::MAX as usize);
48 len
49 }
50
51 #[inline]
52 unsafe fn set_payload_len(&mut self, payload_len: usize) {
53 debug_assert!(payload_len <= u16::MAX as usize);
54 debug_assert!(!self.msg_iov.is_null());
55 (*self.msg_iov).iov_len = payload_len;
56 }
57
58 #[inline]
59 fn set_segment_size(&mut self, size: usize) {
60 debug_assert!(size <= u16::MAX as usize);
61 self.cmsg_encoder().encode_gso(size as _).unwrap();
62 }
63
64 #[inline]
65 unsafe fn reset(&mut self, mtu: usize) {
66 self.set_payload_len(mtu);
68
69 self.set_remote_address(&SocketAddress::IpV6(Default::default()));
71
72 #[inline]
73 unsafe fn check_cmsg(msghdr: &msghdr) {
74 if cfg!(debug_assertions) {
75 let ptr = msghdr.msg_control as *mut u8;
76 let cmsg = core::slice::from_raw_parts_mut(ptr, cmsg::MAX_LEN);
77 #[cfg(not(kani))]
79 {
80 assert!(cmsg.iter().all(|v| *v == 0), "msg_control was not cleared");
81 }
82
83 #[cfg(kani)]
84 {
85 let index: usize = kani::any();
86 kani::assume(index < cmsg.len());
87 assert_eq!(cmsg[index], 0);
88 }
89 }
90 }
91
92 if self.msg_controllen == 0 {
94 check_cmsg(self);
95 }
96
97 #[allow(clippy::unnecessary_cast)]
101 let msg_controllen = self.msg_controllen as usize;
102
103 if msg_controllen != cmsg::MAX_LEN {
104 core::slice::from_raw_parts_mut(self.msg_control as *mut u8, msg_controllen).fill(0);
105 }
106
107 check_cmsg(self);
108
109 self.msg_controllen = cmsg::MAX_LEN as _;
110 }
111
112 #[inline]
113 fn payload_ptr_mut(&mut self) -> *mut u8 {
114 unsafe {
115 let iovec = &mut *self.msg_iov;
116 iovec.iov_base as *mut _
117 }
118 }
119
120 #[inline]
121 fn validate_replication(source: &Self, dest: &Self) {
122 assert_eq!(source.msg_name, dest.msg_name);
123 assert_eq!(source.msg_iov, dest.msg_iov);
124 assert_eq!(source.msg_control, dest.msg_control);
125 }
126
127 #[inline]
128 fn rx_read(
129 &mut self,
130 local_address: &path::LocalAddress,
131 ) -> Option<super::RxMessage<'_, Self::Handle>> {
132 if cfg!(test) {
133 assert_eq!(
134 self.msg_flags & libc::MSG_CTRUNC,
135 0,
136 "control message buffers should always have enough capacity"
137 );
138 }
139
140 let (mut header, cmsg) = self.header()?;
141
142 if !header.path.local_address.ip().is_unspecified() {
144 header.path.local_address.set_port(local_address.port());
145 } else {
146 header.path.local_address = *local_address;
147 }
148
149 let payload = self.payload_mut();
150
151 let segment_size = if cmsg.segment_size == 0 {
152 payload.len()
153 } else {
154 cmsg.segment_size as _
155 };
156
157 let message = crate::message::RxMessage {
158 header,
159 segment_size,
160 payload,
161 };
162
163 Some(message)
164 }
165
166 #[inline]
167 fn tx_write<M: tx::Message<Handle = Self::Handle>>(
168 &mut self,
169 mut message: M,
170 ) -> Result<usize, tx::Error> {
171 let payload = self.payload_mut();
172
173 let max_len = payload.len();
174 let len = message.write_payload(tx::PayloadBuffer::new(payload), 0)?;
175
176 debug_assert_ne!(len, 0);
177 debug_assert!(len <= max_len);
178 let len = len.min(max_len);
179
180 debug_assert_eq!(
181 cmsg::MAX_LEN,
182 self.msg_controllen as _,
183 "message should be reset before writing"
184 );
185 self.msg_controllen = 0;
186
187 unsafe {
188 self.set_payload_len(len);
189 }
190
191 let handle = *message.path_handle();
192 handle.update_msg_hdr(self);
193 self.cmsg_encoder()
194 .encode_ecn(message.ecn(), &handle.remote_address.0)
195 .unwrap();
196
197 Ok(len)
198 }
199}
200
201#[inline]
208pub(super) unsafe fn alloc<T: Copy + Sized, F: Fn(&mut T) -> &mut msghdr>(
209 entries: u32,
210 payload_len: u32,
211 offset: usize,
212 on_entry: F,
213) -> super::Storage {
214 let (layout, entry_offset, header_offset, payload_offset) =
216 layout::<T>(entries, payload_len, offset);
217
218 let storage = super::Storage::new(layout);
220
221 {
222 let ptr = storage.as_ptr();
223
224 let mut entry_ptr = ptr.add(entry_offset) as *mut T;
226 let mut header_ptr = ptr.add(header_offset) as *mut Header;
227 let mut payload_ptr = ptr.add(payload_offset);
228
229 for _ in 0..entries {
230 let entry = on_entry(&mut *entry_ptr);
233 (*header_ptr).update(entry, payload_ptr, payload_len);
234
235 entry_ptr = entry_ptr.add(1);
237 header_ptr = header_ptr.add(1);
238 payload_ptr = payload_ptr.add(payload_len as _);
239
240 storage.check_bounds(entry_ptr);
242 storage.check_bounds(header_ptr);
243 storage.check_bounds(payload_ptr);
244 }
245
246 let primary = ptr.add(entry_offset) as *mut T;
248 let secondary = primary.add(entries as _);
249 storage.check_bounds(secondary.add(entries as _));
250 core::ptr::copy_nonoverlapping(primary, secondary, entries as _);
251 }
252
253 storage
254}
255
256fn layout<T: Copy + Sized>(
267 entries: u32,
268 payload_len: u32,
269 offset: usize,
270) -> (Layout, usize, usize, usize) {
271 let cursor = Layout::array::<u8>(offset).unwrap();
272 let headers = Layout::array::<Header>(entries as _).unwrap();
273 let payloads = Layout::array::<u8>(entries as usize * payload_len as usize).unwrap();
274 let entries = Layout::array::<T>((entries * 2) as usize).unwrap();
276 let (layout, entry_offset) = cursor.extend(entries).unwrap();
277 let (layout, header_offset) = layout.extend(headers).unwrap();
278 let (layout, payload_offset) = layout.extend(payloads).unwrap();
279 (layout, entry_offset, header_offset, payload_offset)
280}
281
282struct Header {
284 pub iovec: iovec,
285 pub msg_name: sockaddr_in6,
286 pub cmsg: cmsg::Storage<{ cmsg::MAX_LEN }>,
287}
288
289impl Header {
290 unsafe fn update(&mut self, entry: &mut msghdr, payload: *mut u8, payload_len: u32) {
292 let iovec = &mut self.iovec;
293
294 iovec.iov_base = payload as *mut _;
295 iovec.iov_len = payload_len as _;
296
297 let entry = &mut *entry;
298
299 entry.msg_name = &mut self.msg_name as *mut _ as *mut _;
300 entry.msg_namelen = size_of_val(&self.msg_name) as _;
301 entry.msg_iov = &mut self.iovec as *mut _;
302 entry.msg_iovlen = 1;
303 entry.msg_controllen = self.cmsg.len() as _;
304 entry.msg_control = self.cmsg.as_mut_ptr() as *mut _;
305
306 debug_assert_eq!(
308 entry
309 .msg_control
310 .align_offset(core::mem::align_of::<cmsg::Storage<{ cmsg::MAX_LEN }>>()),
311 0
312 );
313 }
314}