timestamped_socket/
socket.rs

1use std::{marker::PhantomData, net::SocketAddr, os::fd::AsRawFd};
2
3use tokio::io::{unix::AsyncFd, Interest};
4
5use crate::{
6    control_message::{control_message_space, ControlMessage, MessageQueue},
7    interface::InterfaceName,
8    networkaddress::{sealed::PrivateToken, MulticastJoinable, NetworkAddress},
9    raw_socket::RawSocket,
10};
11
12#[cfg(not(any(target_os = "linux", target_os = "freebsd", target_os = "macos")))]
13mod fallback;
14#[cfg(target_os = "freebsd")]
15mod freebsd;
16#[cfg(target_os = "linux")]
17mod linux;
18#[cfg(target_os = "macos")]
19mod macos;
20
21#[cfg(not(any(target_os = "linux", target_os = "freebsd", target_os = "macos")))]
22use self::fallback::*;
23#[cfg(target_os = "freebsd")]
24use self::freebsd::*;
25#[cfg(target_os = "linux")]
26pub use self::linux::*;
27#[cfg(target_os = "macos")]
28use self::macos::*;
29
/// A timestamp split into whole seconds and a nanosecond remainder.
///
/// The derived ordering compares `seconds` first and `nanos` second, and the
/// `Default` value has both fields set to zero.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Timestamp {
    /// Whole-second component.
    pub seconds: i64,
    /// Nanosecond component that accompanies `seconds`.
    pub nanos: u32,
}
35
36impl Timestamp {
37    #[cfg_attr(target_os = "macos", allow(unused))] // macos does not do nanoseconds
38    pub(crate) fn from_timespec(timespec: libc::timespec) -> Self {
39        Self {
40            seconds: timespec.tv_sec as _,
41            nanos: timespec.tv_nsec as _,
42        }
43    }
44
45    pub(crate) fn from_timeval(timeval: libc::timeval) -> Self {
46        Self {
47            seconds: timeval.tv_sec as _,
48            nanos: (1000 * timeval.tv_usec) as _,
49        }
50    }
51}
52
/// Timestamping modes accepted by the constructors in this module that take a
/// plain socket address.
///
/// Every variant has a same-named counterpart in the interface-level mode
/// enum; the default is [`GeneralTimestampMode::None`].
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub enum GeneralTimestampMode {
    /// Software timestamping for all traffic.
    SoftwareAll,
    /// Software timestamping for received traffic.
    SoftwareRecv,
    /// No timestamping.
    #[default]
    None,
}
60
/// Timestamping mode stored on a socket.
///
/// Extends the software-only modes with hardware variants; the default is
/// [`InterfaceTimestampMode::None`].
// NOTE(review): how the `HardwarePTP*` variants differ from the plain
// `Hardware*` ones is decided by platform code not visible in this file.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub enum InterfaceTimestampMode {
    /// Hardware timestamping for all traffic.
    HardwareAll,
    /// Hardware timestamping for received traffic.
    HardwareRecv,
    /// Hardware (PTP) timestamping for all traffic.
    HardwarePTPAll,
    /// Hardware (PTP) timestamping for received traffic.
    HardwarePTPRecv,
    /// Software timestamping for all traffic.
    SoftwareAll,
    /// Software timestamping for received traffic.
    SoftwareRecv,
    /// No timestamping.
    #[default]
    None,
}
72
73impl From<GeneralTimestampMode> for InterfaceTimestampMode {
74    fn from(value: GeneralTimestampMode) -> Self {
75        match value {
76            GeneralTimestampMode::SoftwareAll => InterfaceTimestampMode::SoftwareAll,
77            GeneralTimestampMode::SoftwareRecv => InterfaceTimestampMode::SoftwareRecv,
78            GeneralTimestampMode::None => InterfaceTimestampMode::None,
79        }
80    }
81}
82
83fn select_timestamp(
84    mode: InterfaceTimestampMode,
85    software: Option<Timestamp>,
86    hardware: Option<Timestamp>,
87) -> Option<Timestamp> {
88    use InterfaceTimestampMode::*;
89
90    match mode {
91        SoftwareAll | SoftwareRecv => software,
92        HardwareAll | HardwareRecv | HardwarePTPAll | HardwarePTPRecv => hardware,
93        None => Option::None,
94    }
95}
96
97#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
98pub struct RecvResult<A> {
99    pub bytes_read: usize,
100    pub remote_addr: A,
101    pub timestamp: Option<Timestamp>,
102}
103
104#[derive(Debug)]
105pub struct Socket<A, S> {
106    timestamp_mode: InterfaceTimestampMode,
107    socket: AsyncFd<RawSocket>,
108    #[cfg(target_os = "linux")]
109    send_counter: u32,
110    _addr: PhantomData<A>,
111    _state: PhantomData<S>,
112}
113
/// State marker for a socket that has not been connected to a peer; such
/// sockets send with an explicit destination address.
pub struct Open;
/// State marker for a socket that has been connected to a fixed peer; such
/// sockets send without specifying a destination.
pub struct Connected;
116
117impl<A: NetworkAddress, S> Socket<A, S> {
118    pub fn local_addr(&self) -> std::io::Result<A> {
119        let addr = self.socket.get_ref().getsockname()?;
120        A::from_sockaddr(addr, PrivateToken).ok_or_else(|| std::io::ErrorKind::Other.into())
121    }
122
123    pub fn peer_addr(&self) -> std::io::Result<A> {
124        let addr = self.socket.get_ref().getpeername()?;
125        A::from_sockaddr(addr, PrivateToken).ok_or_else(|| std::io::ErrorKind::Other.into())
126    }
127
128    pub async fn recv(&self, buf: &mut [u8]) -> std::io::Result<RecvResult<A>> {
129        self.socket
130            .async_io(Interest::READABLE, |socket| {
131                let mut control_buf = [0; control_message_space::<[libc::timespec; 3]>()];
132
133                // loops for when we receive an interrupt during the recv
134                let (bytes_read, control_messages, remote_address) =
135                    socket.receive_message(buf, &mut control_buf, MessageQueue::Normal)?;
136
137                let mut timestamp = None;
138
139                // Loops through the control messages, but we should only get a single message
140                // in practice
141                for msg in control_messages {
142                    match msg {
143                        ControlMessage::Timestamping { software, hardware } => {
144                            tracing::trace!("Timestamps: {:?} {:?}", software, hardware);
145                            timestamp = select_timestamp(self.timestamp_mode, software, hardware);
146                        }
147
148                        #[cfg(target_os = "linux")]
149                        ControlMessage::ReceiveError(error) => {
150                            tracing::warn!(
151                                "unexpected error control message on receive: {}",
152                                error.ee_errno
153                            );
154                        }
155
156                        ControlMessage::Other(msg) => {
157                            tracing::debug!(
158                                "unexpected control message on receive: {} {}",
159                                msg.cmsg_level,
160                                msg.cmsg_type,
161                            );
162                        }
163                    }
164                }
165
166                let remote_addr = A::from_sockaddr(remote_address, PrivateToken)
167                    .ok_or(std::io::ErrorKind::Other)?;
168
169                Ok(RecvResult {
170                    bytes_read,
171                    remote_addr,
172                    timestamp,
173                })
174            })
175            .await
176    }
177}
178
179impl<A: NetworkAddress> Socket<A, Open> {
180    pub async fn send_to(&mut self, buf: &[u8], addr: A) -> std::io::Result<Option<Timestamp>> {
181        let addr = addr.to_sockaddr(PrivateToken);
182
183        self.socket
184            .async_io(Interest::WRITABLE, |socket| socket.send_to(buf, addr))
185            .await?;
186
187        if matches!(
188            self.timestamp_mode,
189            InterfaceTimestampMode::HardwarePTPAll | InterfaceTimestampMode::SoftwareAll
190        ) {
191            #[cfg(target_os = "linux")]
192            {
193                let expected_counter = self.send_counter;
194                self.send_counter = self.send_counter.wrapping_add(1);
195                self.fetch_send_timestamp(expected_counter).await
196            }
197
198            #[cfg(not(target_os = "linux"))]
199            {
200                unreachable!("Should not be able to create send timestamping sockets on platforms other than linux")
201            }
202        } else {
203            Ok(None)
204        }
205    }
206
207    pub fn connect(self, addr: A) -> std::io::Result<Socket<A, Connected>> {
208        let addr = addr.to_sockaddr(PrivateToken);
209        self.socket.get_ref().connect(addr)?;
210        Ok(Socket {
211            timestamp_mode: self.timestamp_mode,
212            socket: self.socket,
213            #[cfg(target_os = "linux")]
214            send_counter: self.send_counter,
215            _addr: PhantomData,
216            _state: PhantomData,
217        })
218    }
219}
220
221impl<A: NetworkAddress> Socket<A, Connected> {
222    pub async fn send(&mut self, buf: &[u8]) -> std::io::Result<Option<Timestamp>> {
223        self.socket
224            .async_io(Interest::WRITABLE, |socket| socket.send(buf))
225            .await?;
226
227        if matches!(
228            self.timestamp_mode,
229            InterfaceTimestampMode::HardwarePTPAll | InterfaceTimestampMode::SoftwareAll
230        ) {
231            #[cfg(target_os = "linux")]
232            {
233                let expected_counter = self.send_counter;
234                self.send_counter = self.send_counter.wrapping_add(1);
235                self.fetch_send_timestamp(expected_counter).await
236            }
237
238            #[cfg(not(target_os = "linux"))]
239            {
240                unreachable!("Should not be able to create send timestamping sockets on platforms other than linux")
241            }
242        } else {
243            Ok(None)
244        }
245    }
246}
247
248impl<A: MulticastJoinable, S> Socket<A, S> {
249    pub fn join_multicast(&self, addr: A, interface: InterfaceName) -> std::io::Result<()> {
250        addr.join_multicast(self.socket.get_ref().as_raw_fd(), interface, PrivateToken)
251    }
252
253    pub fn leave_multicast(&self, addr: A, interface: InterfaceName) -> std::io::Result<()> {
254        addr.leave_multicast(self.socket.get_ref().as_raw_fd(), interface, PrivateToken)
255    }
256}
257
258pub fn open_ip(
259    addr: SocketAddr,
260    timestamping: GeneralTimestampMode,
261) -> std::io::Result<Socket<SocketAddr, Open>> {
262    // Setup the socket
263    let socket = match addr {
264        SocketAddr::V4(_) => RawSocket::open(libc::PF_INET, libc::SOCK_DGRAM, libc::IPPROTO_UDP),
265        SocketAddr::V6(_) => RawSocket::open(libc::PF_INET6, libc::SOCK_DGRAM, libc::IPPROTO_UDP),
266    }?;
267    socket.bind(addr.to_sockaddr(PrivateToken))?;
268    socket.set_nonblocking(true)?;
269    configure_timestamping(&socket, None, timestamping.into(), None)?;
270
271    Ok(Socket {
272        timestamp_mode: timestamping.into(),
273        socket: AsyncFd::new(socket)?,
274        #[cfg(target_os = "linux")]
275        send_counter: 0,
276        _addr: PhantomData,
277        _state: PhantomData,
278    })
279}
280
281pub fn connect_address(
282    addr: SocketAddr,
283    timestamping: GeneralTimestampMode,
284) -> std::io::Result<Socket<SocketAddr, Connected>> {
285    // Setup the socket
286    let socket = match addr {
287        SocketAddr::V4(_) => RawSocket::open(libc::PF_INET, libc::SOCK_DGRAM, libc::IPPROTO_UDP),
288        SocketAddr::V6(_) => RawSocket::open(libc::PF_INET6, libc::SOCK_DGRAM, libc::IPPROTO_UDP),
289    }?;
290    socket.connect(addr.to_sockaddr(PrivateToken))?;
291    socket.set_nonblocking(true)?;
292    configure_timestamping(&socket, None, timestamping.into(), None)?;
293
294    Ok(Socket {
295        timestamp_mode: timestamping.into(),
296        socket: AsyncFd::new(socket)?,
297        #[cfg(target_os = "linux")]
298        send_counter: 0,
299        _addr: PhantomData,
300        _state: PhantomData,
301    })
302}
303
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{IpAddr, Ipv4Addr};

    /// Round-trips a datagram between a bound socket and a connected socket
    /// on the loopback interface, in both directions.
    #[tokio::test]
    async fn test_open_ip() {
        // Bind to port 0 so the OS picks a free port; a fixed port makes the
        // test fail whenever something else on the machine already uses it.
        let mut a = open_ip(
            SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
            GeneralTimestampMode::None,
        )
        .unwrap();
        let a_addr = a.local_addr().unwrap();

        let mut b = connect_address(a_addr, GeneralTimestampMode::None).unwrap();

        // b -> a
        assert!(b.send(&[1, 2, 3]).await.is_ok());
        let mut buf = [0; 4];
        let recv_result = a.recv(&mut buf).await.unwrap();
        assert_eq!(recv_result.bytes_read, 3);
        assert_eq!(&buf[0..3], &[1, 2, 3]);

        // a -> b, replying to the address the first datagram came from
        assert!(a.send_to(&[4, 5, 6], recv_result.remote_addr).await.is_ok());
        let recv_result = b.recv(&mut buf).await.unwrap();
        assert_eq!(recv_result.bytes_read, 3);
        assert_eq!(&buf[0..3], &[4, 5, 6]);
    }
}