s2n_quic_platform/message/msg/
ext.rs

1// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use super::*;
5
6pub trait Ext {
7    type Encoder<'a>: cmsg::Encoder
8    where
9        Self: 'a;
10
11    fn header(&self) -> Option<(datagram::Header<Handle>, datagram::AncillaryData)>;
12    fn cmsg_encoder(&mut self) -> Self::Encoder<'_>;
13    fn remote_address(&self) -> Option<SocketAddress>;
14    fn set_remote_address(&mut self, remote_address: &SocketAddress);
15}
16
17impl Ext for msghdr {
18    type Encoder<'a> = MsghdrEncoder<'a>;
19
20    #[inline]
21    fn header(&self) -> Option<(datagram::Header<Handle>, datagram::AncillaryData)> {
22        let addr = self.remote_address()?;
23        let mut path = Handle::from_remote_address(addr.into());
24
25        let ancillary_data = unsafe { cmsg::decode::Iter::from_msghdr(self) }.collect();
26        let ecn = ancillary_data.ecn;
27
28        path.with_ancillary_data(ancillary_data);
29
30        let header = datagram::Header { path, ecn };
31
32        Some((header, ancillary_data))
33    }
34
35    #[inline]
36    fn cmsg_encoder(&mut self) -> Self::Encoder<'_> {
37        MsghdrEncoder { msghdr: self }
38    }
39
40    #[inline]
41    fn remote_address(&self) -> Option<SocketAddress> {
42        debug_assert!(!self.msg_name.is_null());
43        match self.msg_namelen as usize {
44            size if size == size_of::<sockaddr_in>() => {
45                let sockaddr: &sockaddr_in = unsafe { &*(self.msg_name as *const _) };
46                let port = sockaddr.sin_port.to_be();
47                let addr: IpV4Address = sockaddr.sin_addr.s_addr.to_ne_bytes().into();
48                Some(SocketAddressV4::new(addr, port).into())
49            }
50            size if size == size_of::<sockaddr_in6>() => {
51                let sockaddr: &sockaddr_in6 = unsafe { &*(self.msg_name as *const _) };
52                let port = sockaddr.sin6_port.to_be();
53                let addr: IpV6Address = sockaddr.sin6_addr.s6_addr.into();
54                Some(SocketAddressV6::new(addr, port).into())
55            }
56            _ => None,
57        }
58    }
59
60    #[inline]
61    fn set_remote_address(&mut self, remote_address: &SocketAddress) {
62        debug_assert!(!self.msg_name.is_null());
63
64        match remote_address {
65            SocketAddress::IpV4(addr) => {
66                let sockaddr: &mut sockaddr_in = unsafe { &mut *(self.msg_name as *mut _) };
67                sockaddr.sin_family = AF_INET as _;
68                sockaddr.sin_port = addr.port().to_be();
69                sockaddr.sin_addr.s_addr = u32::from_ne_bytes((*addr.ip()).into());
70                self.msg_namelen = size_of::<sockaddr_in>() as _;
71            }
72            SocketAddress::IpV6(addr) => {
73                let sockaddr: &mut sockaddr_in6 = unsafe { &mut *(self.msg_name as *mut _) };
74                sockaddr.sin6_family = AF_INET6 as _;
75                sockaddr.sin6_port = addr.port().to_be();
76                sockaddr.sin6_addr.s6_addr = (*addr.ip()).into();
77                self.msg_namelen = size_of::<sockaddr_in6>() as _;
78            }
79        }
80    }
81}
82
/// [`Encoder`] that appends control messages to a borrowed `msghdr`.
///
/// Created by `Ext::cmsg_encoder` for `msghdr`.
pub struct MsghdrEncoder<'a> {
    /// Message whose `msg_control` buffer is written to and whose `msg_controllen`
    /// serves as the write cursor.
    msghdr: &'a mut msghdr,
}
86
87impl Encoder for MsghdrEncoder<'_> {
88    #[inline]
89    fn encode_cmsg<T: Copy>(
90        &mut self,
91        level: libc::c_int,
92        ty: libc::c_int,
93        value: T,
94    ) -> Result<usize, cmsg::encode::Error> {
95        let storage =
96            unsafe { &mut *(self.msghdr.msg_control as *mut cmsg::Storage<{ cmsg::MAX_LEN }>) };
97
98        let mut encoder = storage.encoder();
99        encoder.seek(self.msghdr.msg_controllen as _);
100
101        let msg_len = encoder.encode_cmsg(level, ty, value)?;
102
103        // update the cursor
104        self.msghdr.msg_controllen = encoder.len() as _;
105
106        Ok(msg_len)
107    }
108}