1use std::{marker::PhantomData, net::SocketAddr, os::fd::AsRawFd};
2
3use tokio::io::{unix::AsyncFd, Interest};
4
5use crate::{
6 control_message::{control_message_space, ControlMessage, MessageQueue},
7 interface::InterfaceName,
8 networkaddress::{sealed::PrivateToken, MulticastJoinable, NetworkAddress},
9 raw_socket::RawSocket,
10};
11
12#[cfg(not(any(target_os = "linux", target_os = "freebsd", target_os = "macos")))]
13mod fallback;
14#[cfg(target_os = "freebsd")]
15mod freebsd;
16#[cfg(target_os = "linux")]
17mod linux;
18#[cfg(target_os = "macos")]
19mod macos;
20
21#[cfg(not(any(target_os = "linux", target_os = "freebsd", target_os = "macos")))]
22use self::fallback::*;
23#[cfg(target_os = "freebsd")]
24use self::freebsd::*;
25#[cfg(target_os = "linux")]
26pub use self::linux::*;
27#[cfg(target_os = "macos")]
28use self::macos::*;
29
/// A point in time split into whole seconds and a sub-second nanosecond part.
///
/// The derived ordering compares `seconds` first and only then `nanos`.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Timestamp {
    /// Whole seconds.
    pub seconds: i64,
    /// Nanoseconds within the current second.
    pub nanos: u32,
}
35
36impl Timestamp {
37 #[cfg_attr(target_os = "macos", allow(unused))] pub(crate) fn from_timespec(timespec: libc::timespec) -> Self {
39 Self {
40 seconds: timespec.tv_sec as _,
41 nanos: timespec.tv_nsec as _,
42 }
43 }
44
45 pub(crate) fn from_timeval(timeval: libc::timeval) -> Self {
46 Self {
47 seconds: timeval.tv_sec as _,
48 nanos: (1000 * timeval.tv_usec) as _,
49 }
50 }
51}
52
/// Timestamping modes that are available without naming a specific network
/// interface. Converted into [`InterfaceTimestampMode`] through `From`.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub enum GeneralTimestampMode {
    /// Software timestamps for received packets and for sent packets.
    SoftwareAll,
    /// Software timestamps for received packets only.
    SoftwareRecv,
    /// No timestamping; this is the default.
    #[default]
    None,
}

/// Timestamping modes a socket can be configured with, including the hardware
/// modes.
///
/// `*All` modes also report a timestamp for sent packets, while `*Recv` modes
/// only timestamp received ones. The `PTP` hardware variants are presumably
/// restricted to PTP traffic; that distinction is made by the platform backend,
/// not in this file.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub enum InterfaceTimestampMode {
    /// Hardware timestamps on received and sent packets.
    HardwareAll,
    /// Hardware timestamps on received packets only.
    HardwareRecv,
    /// Hardware timestamps on received and sent packets (PTP variant).
    HardwarePTPAll,
    /// Hardware timestamps on received packets only (PTP variant).
    HardwarePTPRecv,
    /// Software timestamps on received and sent packets.
    SoftwareAll,
    /// Software timestamps on received packets only.
    SoftwareRecv,
    /// No timestamping; this is the default.
    #[default]
    None,
}

impl From<GeneralTimestampMode> for InterfaceTimestampMode {
    /// Maps each general mode onto the interface mode of the same name.
    fn from(value: GeneralTimestampMode) -> Self {
        use GeneralTimestampMode as General;

        match value {
            General::SoftwareAll => Self::SoftwareAll,
            General::SoftwareRecv => Self::SoftwareRecv,
            General::None => Self::None,
        }
    }
}
82
83fn select_timestamp(
84 mode: InterfaceTimestampMode,
85 software: Option<Timestamp>,
86 hardware: Option<Timestamp>,
87) -> Option<Timestamp> {
88 use InterfaceTimestampMode::*;
89
90 match mode {
91 SoftwareAll | SoftwareRecv => software,
92 HardwareAll | HardwareRecv | HardwarePTPAll | HardwarePTPRecv => hardware,
93 None => Option::None,
94 }
95}
96
97#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
98pub struct RecvResult<A> {
99 pub bytes_read: usize,
100 pub remote_addr: A,
101 pub timestamp: Option<Timestamp>,
102}
103
104#[derive(Debug)]
105pub struct Socket<A, S> {
106 timestamp_mode: InterfaceTimestampMode,
107 socket: AsyncFd<RawSocket>,
108 #[cfg(target_os = "linux")]
109 send_counter: u32,
110 _addr: PhantomData<A>,
111 _state: PhantomData<S>,
112}
113
/// Type-state marker for a socket that is not connected to a peer. Such a
/// socket sends with `send_to` and can be turned into a connected one.
///
/// `Debug` is derived so that `Socket<A, Open>` picks up the `Debug` impl
/// generated by `#[derive(Debug)]` on `Socket`, which requires `S: Debug`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Open;

/// Type-state marker for a socket connected to a single peer, which sends
/// with `send`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Connected;
116
117impl<A: NetworkAddress, S> Socket<A, S> {
118 pub fn local_addr(&self) -> std::io::Result<A> {
119 let addr = self.socket.get_ref().getsockname()?;
120 A::from_sockaddr(addr, PrivateToken).ok_or_else(|| std::io::ErrorKind::Other.into())
121 }
122
123 pub fn peer_addr(&self) -> std::io::Result<A> {
124 let addr = self.socket.get_ref().getpeername()?;
125 A::from_sockaddr(addr, PrivateToken).ok_or_else(|| std::io::ErrorKind::Other.into())
126 }
127
128 pub async fn recv(&self, buf: &mut [u8]) -> std::io::Result<RecvResult<A>> {
129 self.socket
130 .async_io(Interest::READABLE, |socket| {
131 let mut control_buf = [0; control_message_space::<[libc::timespec; 3]>()];
132
133 let (bytes_read, control_messages, remote_address) =
135 socket.receive_message(buf, &mut control_buf, MessageQueue::Normal)?;
136
137 let mut timestamp = None;
138
139 for msg in control_messages {
142 match msg {
143 ControlMessage::Timestamping { software, hardware } => {
144 tracing::trace!("Timestamps: {:?} {:?}", software, hardware);
145 timestamp = select_timestamp(self.timestamp_mode, software, hardware);
146 }
147
148 #[cfg(target_os = "linux")]
149 ControlMessage::ReceiveError(error) => {
150 tracing::warn!(
151 "unexpected error control message on receive: {}",
152 error.ee_errno
153 );
154 }
155
156 ControlMessage::Other(msg) => {
157 tracing::debug!(
158 "unexpected control message on receive: {} {}",
159 msg.cmsg_level,
160 msg.cmsg_type,
161 );
162 }
163 }
164 }
165
166 let remote_addr = A::from_sockaddr(remote_address, PrivateToken)
167 .ok_or(std::io::ErrorKind::Other)?;
168
169 Ok(RecvResult {
170 bytes_read,
171 remote_addr,
172 timestamp,
173 })
174 })
175 .await
176 }
177}
178
179impl<A: NetworkAddress> Socket<A, Open> {
180 pub async fn send_to(&mut self, buf: &[u8], addr: A) -> std::io::Result<Option<Timestamp>> {
181 let addr = addr.to_sockaddr(PrivateToken);
182
183 self.socket
184 .async_io(Interest::WRITABLE, |socket| socket.send_to(buf, addr))
185 .await?;
186
187 if matches!(
188 self.timestamp_mode,
189 InterfaceTimestampMode::HardwarePTPAll | InterfaceTimestampMode::SoftwareAll
190 ) {
191 #[cfg(target_os = "linux")]
192 {
193 let expected_counter = self.send_counter;
194 self.send_counter = self.send_counter.wrapping_add(1);
195 self.fetch_send_timestamp(expected_counter).await
196 }
197
198 #[cfg(not(target_os = "linux"))]
199 {
200 unreachable!("Should not be able to create send timestamping sockets on platforms other than linux")
201 }
202 } else {
203 Ok(None)
204 }
205 }
206
207 pub fn connect(self, addr: A) -> std::io::Result<Socket<A, Connected>> {
208 let addr = addr.to_sockaddr(PrivateToken);
209 self.socket.get_ref().connect(addr)?;
210 Ok(Socket {
211 timestamp_mode: self.timestamp_mode,
212 socket: self.socket,
213 #[cfg(target_os = "linux")]
214 send_counter: self.send_counter,
215 _addr: PhantomData,
216 _state: PhantomData,
217 })
218 }
219}
220
221impl<A: NetworkAddress> Socket<A, Connected> {
222 pub async fn send(&mut self, buf: &[u8]) -> std::io::Result<Option<Timestamp>> {
223 self.socket
224 .async_io(Interest::WRITABLE, |socket| socket.send(buf))
225 .await?;
226
227 if matches!(
228 self.timestamp_mode,
229 InterfaceTimestampMode::HardwarePTPAll | InterfaceTimestampMode::SoftwareAll
230 ) {
231 #[cfg(target_os = "linux")]
232 {
233 let expected_counter = self.send_counter;
234 self.send_counter = self.send_counter.wrapping_add(1);
235 self.fetch_send_timestamp(expected_counter).await
236 }
237
238 #[cfg(not(target_os = "linux"))]
239 {
240 unreachable!("Should not be able to create send timestamping sockets on platforms other than linux")
241 }
242 } else {
243 Ok(None)
244 }
245 }
246}
247
248impl<A: MulticastJoinable, S> Socket<A, S> {
249 pub fn join_multicast(&self, addr: A, interface: InterfaceName) -> std::io::Result<()> {
250 addr.join_multicast(self.socket.get_ref().as_raw_fd(), interface, PrivateToken)
251 }
252
253 pub fn leave_multicast(&self, addr: A, interface: InterfaceName) -> std::io::Result<()> {
254 addr.leave_multicast(self.socket.get_ref().as_raw_fd(), interface, PrivateToken)
255 }
256}
257
258pub fn open_ip(
259 addr: SocketAddr,
260 timestamping: GeneralTimestampMode,
261) -> std::io::Result<Socket<SocketAddr, Open>> {
262 let socket = match addr {
264 SocketAddr::V4(_) => RawSocket::open(libc::PF_INET, libc::SOCK_DGRAM, libc::IPPROTO_UDP),
265 SocketAddr::V6(_) => RawSocket::open(libc::PF_INET6, libc::SOCK_DGRAM, libc::IPPROTO_UDP),
266 }?;
267 socket.bind(addr.to_sockaddr(PrivateToken))?;
268 socket.set_nonblocking(true)?;
269 configure_timestamping(&socket, None, timestamping.into(), None)?;
270
271 Ok(Socket {
272 timestamp_mode: timestamping.into(),
273 socket: AsyncFd::new(socket)?,
274 #[cfg(target_os = "linux")]
275 send_counter: 0,
276 _addr: PhantomData,
277 _state: PhantomData,
278 })
279}
280
281pub fn connect_address(
282 addr: SocketAddr,
283 timestamping: GeneralTimestampMode,
284) -> std::io::Result<Socket<SocketAddr, Connected>> {
285 let socket = match addr {
287 SocketAddr::V4(_) => RawSocket::open(libc::PF_INET, libc::SOCK_DGRAM, libc::IPPROTO_UDP),
288 SocketAddr::V6(_) => RawSocket::open(libc::PF_INET6, libc::SOCK_DGRAM, libc::IPPROTO_UDP),
289 }?;
290 socket.connect(addr.to_sockaddr(PrivateToken))?;
291 socket.set_nonblocking(true)?;
292 configure_timestamping(&socket, None, timestamping.into(), None)?;
293
294 Ok(Socket {
295 timestamp_mode: timestamping.into(),
296 socket: AsyncFd::new(socket)?,
297 #[cfg(target_os = "linux")]
298 send_counter: 0,
299 _addr: PhantomData,
300 _state: PhantomData,
301 })
302}
303
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{IpAddr, Ipv4Addr};

    #[tokio::test]
    async fn test_open_ip() {
        // Bind to port 0 so the OS picks a free port. A hard-coded port makes the
        // test fail whenever that port is already taken (e.g. parallel runs).
        let mut a = open_ip(
            SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
            GeneralTimestampMode::None,
        )
        .unwrap();
        let a_addr = a.local_addr().unwrap();
        assert_ne!(a_addr.port(), 0);

        let mut b = connect_address(a_addr, GeneralTimestampMode::None).unwrap();
        assert!(b.send(&[1, 2, 3]).await.is_ok());

        let mut buf = [0; 4];
        let recv_result = a.recv(&mut buf).await.unwrap();
        assert_eq!(recv_result.bytes_read, 3);
        assert_eq!(&buf[0..3], &[1, 2, 3]);
        // The sender seen by `a` must be `b`'s own local address.
        assert_eq!(recv_result.remote_addr, b.local_addr().unwrap());

        assert!(a.send_to(&[4, 5, 6], recv_result.remote_addr).await.is_ok());
        let recv_result = b.recv(&mut buf).await.unwrap();
        assert_eq!(recv_result.bytes_read, 3);
        assert_eq!(&buf[0..3], &[4, 5, 6]);
    }
}