unix_udp_sock/
unix.rs

1#![allow(clippy::unnecessary_cast)]
2use std::{
3    io::{self, IoSliceMut},
4    mem::{self, MaybeUninit},
5    net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6},
6    os::{fd::AsFd, unix::io::AsRawFd},
7    sync::{
8        atomic::{AtomicU64, AtomicUsize, Ordering},
9        Arc,
10    },
11    task::{Context, Poll},
12    time::SystemTime,
13};
14
15use crate::cmsg::{AsPtr, EcnCodepoint, Source, Transmit};
16use futures_core::ready;
17use socket2::SockRef;
18use tokio::{
19    io::{Interest, ReadBuf},
20    net::ToSocketAddrs,
21};
22
23use super::{cmsg, log_sendmsg_error, RecvMeta, UdpState, IO_ERROR_LOG_INTERVAL};
24
25pub(crate) const BATCH_SIZE_CAP: usize = SYS_BATCH_SIZE_CAP;
26
27// This is not set to the maximum as larger batch sizes require larger stack
28// frames, which may be undesirable. On non-Linux/FreeBSD systems, this is
29// reduced to 1, as they don't support batching UDP messages.
30pub(crate) const DEFAULT_BATCH_SIZE: usize = SYS_DEFAULT_BATCH_SIZE;
31
32#[cfg(target_os = "linux")]
33const SYS_BATCH_SIZE_CAP: usize = libc::UIO_MAXIOV as usize;
34
35#[cfg(target_os = "freebsd")]
36const SYS_BATCH_SIZE_CAP: usize = 1024 as usize;
37
38#[cfg(not(any(target_os = "linux", target_os = "freebsd")))]
39pub const SYS_BATCH_SIZE_CAP: usize = 1;
40
41#[cfg(any(target_os = "linux", target_os = "freebsd"))]
42pub const SYS_DEFAULT_BATCH_SIZE: usize = 128;
43
44#[cfg(not(any(target_os = "linux", target_os = "freebsd")))]
45pub const SYS_DEFAULT_BATCH_SIZE: usize = 1;
46
47#[cfg(target_os = "freebsd")]
48type IpTosTy = libc::c_uchar;
49#[cfg(not(target_os = "freebsd"))]
50type IpTosTy = libc::c_int;
51
/// Tokio-compatible UDP socket with some useful specializations.
///
/// Unlike a standard tokio UDP socket, this allows ECN bits to be read and written on some
/// platforms.
#[derive(Debug)]
pub struct UdpSocket {
    /// Underlying tokio socket; every operation is forwarded to it.
    io: tokio::net::UdpSocket,
    /// Shared timestamp that is cloned into the batched `send` helper,
    /// presumably to rate-limit send-error logging (see `LastSendError`).
    last_send_error: LastSendError,
}
61
62impl AsRawFd for UdpSocket {
63    fn as_raw_fd(&self) -> std::os::unix::prelude::RawFd {
64        self.io.as_raw_fd()
65    }
66}
67
68impl AsFd for UdpSocket {
69    fn as_fd(&self) -> std::os::unix::prelude::BorrowedFd<'_> {
70        self.io.as_fd()
71    }
72}
73
/// Shared, atomically updated timestamp in whole seconds since the Unix epoch.
///
/// Cloning copies the `Arc`, so all clones observe and update the same value.
#[derive(Clone, Debug)]
pub(crate) struct LastSendError(Arc<AtomicU64>);
76
77impl Default for LastSendError {
78    fn default() -> Self {
79        let now = Self::now();
80        Self(Arc::new(AtomicU64::new(
81            now.checked_sub(2 * IO_ERROR_LOG_INTERVAL).unwrap_or(now),
82        )))
83    }
84}
85
86impl LastSendError {
87    fn now() -> u64 {
88        SystemTime::now()
89            .duration_since(SystemTime::UNIX_EPOCH)
90            .unwrap()
91            .as_secs()
92    }
93
94    /// Determine whether the last error was more than `IO_ERROR_LOG_INTERVAL`
95    /// seconds ago. If so, update the last error time and return true.
96    ///
97    /// Note: if the system clock regresses more tha `IO_ERROR_LOG_INTERVAL`,
98    /// this function may impose an additional delay on log message emission.
99    /// Similarly, if it advances, messages may be emitted prematurely.
100    pub(crate) fn should_log(&self) -> bool {
101        let now = Self::now();
102        self.0
103            .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |cur| {
104                (now.saturating_sub(cur) > IO_ERROR_LOG_INTERVAL).then_some(now)
105            })
106            .is_ok()
107    }
108}
109
110impl UdpSocket {
111    /// Creates a new UDP socket from a previously created `std::net::UdpSocket`
112    pub fn from_std(socket: std::net::UdpSocket) -> io::Result<UdpSocket> {
113        socket.set_nonblocking(true)?;
114
115        init(SockRef::from(&socket))?;
116        Ok(UdpSocket {
117            io: tokio::net::UdpSocket::from_std(socket)?,
118            last_send_error: LastSendError::default(),
119        })
120    }
121
122    pub fn into_std(self) -> io::Result<std::net::UdpSocket> {
123        self.io.into_std()
124    }
125
126    /// create a new UDP socket and attempt to bind to `addr`
127    pub async fn bind<A: ToSocketAddrs>(addr: A) -> io::Result<UdpSocket> {
128        let io = tokio::net::UdpSocket::bind(addr).await?;
129        init(SockRef::from(&io))?;
130        Ok(UdpSocket {
131            io,
132            last_send_error: LastSendError::default(),
133        })
134    }
135
136    /// sets the value of SO_BROADCAST for this socket
137    pub fn set_broadcast(&self, broadcast: bool) -> io::Result<()> {
138        self.io.set_broadcast(broadcast)
139    }
140
141    /// Opportunistically try to enable GRO support for this socket. This is
142    /// only supported on Linux platforms.
143    #[cfg(target_os = "linux")]
144    pub fn set_gro(&self, enable: bool) -> io::Result<()> {
145        // See gro::gro_segments().
146        const OPTION_OFF: libc::c_int = 0;
147
148        let value = if enable { OPTION_ON } else { OPTION_OFF };
149        set_socket_option(&self.io, libc::SOL_UDP, libc::UDP_GRO, value)
150    }
151
152    pub async fn connect<A: ToSocketAddrs>(&self, addrs: A) -> io::Result<()> {
153        self.io.connect(addrs).await
154    }
155    pub async fn join_multicast_v4(
156        &self,
157        multiaddr: Ipv4Addr,
158        interface: Ipv4Addr,
159    ) -> io::Result<()> {
160        self.io.join_multicast_v4(multiaddr, interface)
161    }
162    pub async fn join_multicast_v6(&self, multiaddr: &Ipv6Addr, interface: u32) -> io::Result<()> {
163        self.io.join_multicast_v6(multiaddr, interface)
164    }
165    pub async fn leave_multicast_v4(
166        &self,
167        multiaddr: Ipv4Addr,
168        interface: Ipv4Addr,
169    ) -> io::Result<()> {
170        self.io.leave_multicast_v4(multiaddr, interface)
171    }
172    pub async fn leave_multicast_v6(&self, multiaddr: &Ipv6Addr, interface: u32) -> io::Result<()> {
173        self.io.leave_multicast_v6(multiaddr, interface)
174    }
175    pub async fn set_multicast_loop_v4(&self, on: bool) -> io::Result<()> {
176        self.io.set_multicast_loop_v4(on)
177    }
178    pub async fn set_multicast_loop_v6(&self, on: bool) -> io::Result<()> {
179        self.io.set_multicast_loop_v6(on)
180    }
181    /// Sends data on the socket to the given address. On success, returns the
182    /// number of bytes written.
183    ///
184    /// calls underlying tokio [`send_to`]
185    ///
186    /// [`send_to`]: method@tokio::net::UdpSocket::send_to
187    pub async fn send_to(&self, buf: &[u8], target: SocketAddr) -> io::Result<usize> {
188        self.io.send_to(buf, target).await
189    }
190    /// Sends data on the socket to the given address. On success, returns the
191    /// number of bytes written.
192    ///
193    /// calls underlying tokio [`poll_send_to`]
194    ///
195    /// [`poll_send_to`]: method@tokio::net::UdpSocket::poll_send_to
196    pub fn poll_send_to(
197        &self,
198        cx: &mut Context<'_>,
199        buf: &[u8],
200        target: SocketAddr,
201    ) -> Poll<io::Result<usize>> {
202        self.io.poll_send_to(cx, buf, target)
203    }
204    /// Sends data on the socket to the remote address that the socket is
205    /// connected to.
206    ///
207    /// See tokio [`send`]
208    ///
209    /// [`send`]: method@tokio::net::UdpSocket::send
210    pub async fn send(&self, buf: &[u8]) -> io::Result<usize> {
211        self.io.send(buf).await
212    }
213    /// Sends data on the socket to the remote address that the socket is
214    /// connected to.
215    ///
216    /// See tokio [`poll_send`]
217    ///
218    /// [`poll_send`]: method@tokio::net::UdpSocket::poll_send
219    pub async fn poll_send(&self, cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
220        self.io.poll_send(cx, buf)
221    }
222    /// Receives a single datagram message on the socket. On success, returns
223    /// the number of bytes read and the origin.
224    ///
225    /// See tokio [`recv_from`]
226    ///
227    /// [`recv_from`]: method@tokio::net::UdpSocket::recv_from
228    pub async fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
229        self.io.recv_from(buf).await
230    }
231    /// Receives a single datagram message on the socket. On success, returns
232    /// the number of bytes read and the origin.
233    ///
234    /// See tokio [`poll_recv_from`]
235    ///
236    /// [`poll_recv_from`]: method@tokio::net::UdpSocket::poll_recv_from
237    pub fn poll_recv_from(
238        &self,
239        cx: &mut Context<'_>,
240        buf: &mut ReadBuf<'_>,
241    ) -> Poll<io::Result<SocketAddr>> {
242        self.io.poll_recv_from(cx, buf)
243    }
244    /// Receives a single datagram message on the socket from the remote address
245    /// to which it is connected. On success, returns the number of bytes read.
246    ///
247    /// See tokio [`recv`]
248    ///
249    /// [`recv`]: method@tokio::net::UdpSocket::recv
250    pub async fn recv(&self, buf: &mut [u8]) -> io::Result<usize> {
251        self.io.recv(buf).await
252    }
253    /// Receives a single datagram message on the socket from the remote address
254    /// to which it is connected. On success, returns the number of bytes read.
255    ///
256    /// See tokio [`poll_recv`]
257    ///
258    /// [`poll_recv`]: method@tokio::net::UdpSocket::poll_recv
259    pub fn poll_recv(&self, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<io::Result<()>> {
260        self.io.poll_recv(cx, buf)
261    }
262
263    /// Calls syscall [`sendmmsg`]. With a given `state` configured GSO and
264    /// `transmits` with information on the data and metadata about outgoing
265    /// packets.
266    ///
267    /// Utilizes the default batch size (`DEFAULT_BATCH_SIZE`), and will send no
268    /// more than that number of messages. The caller must call this fuction
269    /// again after modifying `transmits` to continue sending the entire set of
270    /// messages.
271    ///
272    /// [`sendmmsg`]: https://linux.die.net/man/2/sendmmsg
273    pub async fn send_mmsg<B: AsPtr<u8>>(
274        &self,
275        state: &UdpState,
276        transmits: &[Transmit<B>],
277    ) -> Result<usize, io::Error> {
278        self.send_mmsg_with_batch_size::<_, DEFAULT_BATCH_SIZE>(state, transmits)
279            .await
280    }
281
282    /// Calls syscall [`sendmmsg`]. With a given `state` configured GSO and
283    /// `transmits` with information on the data and metadata about outgoing packets.
284    ///
285    /// Sends no more than `BATCH_SIZE` messages. The caller must call this
286    /// fuction again after modifying `transmits` to continue sending the entire
287    /// set of messages.  `BATCH_SIZE_CAP` defines the maximum that will be
288    /// sent, regardless of the specified `BATCH_SIZE`
289    ///
290    /// [`sendmmsg`]: https://linux.die.net/man/2/sendmmsg
291    pub async fn send_mmsg_with_batch_size<B: AsPtr<u8>, const BATCH_SIZE: usize>(
292        &self,
293        state: &UdpState,
294        transmits: &[Transmit<B>],
295    ) -> Result<usize, io::Error> {
296        let n = loop {
297            self.io.writable().await?;
298            let last_send_error = self.last_send_error.clone();
299            let io = &self.io;
300            match io.try_io(Interest::WRITABLE, || {
301                send::<_, BATCH_SIZE>(state, SockRef::from(io), last_send_error, transmits)
302            }) {
303                Ok(res) => break res,
304                Err(_would_block) => continue,
305            }
306        };
307        // if n == transmits.len() {}
308        Ok(n)
309    }
310
311    /// Calls syscall [`sendmsg`]. With a given `state` configured GSO and
312    /// `transmit` with information on the data and metadata about outgoing packet.
313    ///
314    /// [`sendmsg`]: https://linux.die.net/man/2/sendmsg
315    pub async fn send_msg<B: AsPtr<u8>>(
316        &self,
317        state: &UdpState,
318        transmits: Transmit<B>,
319    ) -> io::Result<usize> {
320        let n = loop {
321            self.io.writable().await?;
322            let io = &self.io;
323            match io.try_io(Interest::WRITABLE, || {
324                send_msg(state, SockRef::from(io), &transmits)
325            }) {
326                Ok(res) => break res,
327                Err(_would_block) => continue,
328            }
329        };
330        Ok(n)
331    }
332
333    /// async version of `recvmmsg` with compile-time configurable batch size
334    pub async fn recv_mmsg(
335        &self,
336        bufs: &mut [IoSliceMut<'_>],
337        meta: &mut [RecvMeta],
338    ) -> io::Result<usize> {
339        self.recv_mmsg_with_batch_size::<DEFAULT_BATCH_SIZE>(bufs, meta)
340            .await
341    }
342
343    pub async fn recv_mmsg_with_batch_size<const BATCH_SIZE: usize>(
344        &self,
345        bufs: &mut [IoSliceMut<'_>],
346        meta: &mut [RecvMeta],
347    ) -> io::Result<usize> {
348        debug_assert!(!bufs.is_empty());
349        loop {
350            self.io.readable().await?;
351            let io = &self.io;
352            match io.try_io(Interest::READABLE, || {
353                recv::<BATCH_SIZE>(SockRef::from(io), bufs, meta)
354            }) {
355                Ok(res) => return Ok(res),
356                Err(_would_block) => continue,
357            }
358        }
359    }
360
361    /// `recv_msg` is similar to `recv_from` but returns extra information
362    /// about the packet in [`RecvMeta`].
363    ///
364    /// [`RecvMeta`]: crate::RecvMeta
365    pub async fn recv_msg(&self, buf: &mut [u8]) -> io::Result<RecvMeta> {
366        let mut iov = IoSliceMut::new(buf);
367        debug_assert!(!iov.is_empty());
368        loop {
369            self.io.readable().await?;
370            let io = &self.io;
371            match io.try_io(Interest::READABLE, || recv_msg(SockRef::from(io), &mut iov)) {
372                Ok(res) => return Ok(res),
373                Err(_would_block) => continue,
374            }
375        }
376    }
377
378    /// calls `sendmmsg`
379    pub fn poll_send_mmsg<B: AsPtr<u8>>(
380        &mut self,
381        state: &UdpState,
382        cx: &mut Context,
383        transmits: &[Transmit<B>],
384    ) -> Poll<io::Result<usize>> {
385        self.poll_send_mmsg_with_batch_size::<_, DEFAULT_BATCH_SIZE>(state, cx, transmits)
386    }
387
388    /// calls `sendmmsg`
389    pub fn poll_send_mmsg_with_batch_size<B: AsPtr<u8>, const BATCH_SIZE: usize>(
390        &mut self,
391        state: &UdpState,
392        cx: &mut Context,
393        transmits: &[Transmit<B>],
394    ) -> Poll<io::Result<usize>> {
395        loop {
396            ready!(self.io.poll_send_ready(cx))?;
397            let io = &self.io;
398            if let Ok(res) = io.try_io(Interest::WRITABLE, || {
399                send::<_, BATCH_SIZE>(
400                    state,
401                    SockRef::from(io),
402                    self.last_send_error.clone(),
403                    transmits,
404                )
405            }) {
406                return Poll::Ready(Ok(res));
407            }
408        }
409    }
410
411    /// calls `sendmsg` with compile-time configurable batch size
412    pub fn poll_send_msg<B: AsPtr<u8>>(
413        &self,
414        state: &UdpState,
415        cx: &mut Context,
416        transmits: Transmit<B>,
417    ) -> Poll<io::Result<usize>> {
418        loop {
419            ready!(self.io.poll_send_ready(cx))?;
420            let io = &self.io;
421            if let Ok(res) = io.try_io(Interest::WRITABLE, || {
422                send_msg(state, SockRef::from(io), &transmits)
423            }) {
424                return Poll::Ready(Ok(res));
425            }
426        }
427    }
428
429    /// calls `recvmsg`
430    pub fn poll_recv_msg(
431        &self,
432        cx: &mut Context,
433        buf: &mut IoSliceMut<'_>,
434    ) -> Poll<io::Result<RecvMeta>> {
435        loop {
436            ready!(self.io.poll_recv_ready(cx))?;
437            let io = &self.io;
438            if let Ok(res) = io.try_io(Interest::READABLE, || recv_msg(SockRef::from(io), buf)) {
439                return Poll::Ready(Ok(res));
440            }
441        }
442    }
443
444    /// calls `recvmmsg`
445    pub fn poll_recv_mmsg(
446        &self,
447        cx: &mut Context,
448        bufs: &mut [IoSliceMut<'_>],
449        meta: &mut [RecvMeta],
450    ) -> Poll<io::Result<usize>> {
451        self.poll_recv_mmsg_with_batch_size::<DEFAULT_BATCH_SIZE>(cx, bufs, meta)
452    }
453
454    /// calls `recvmmsg` with compile-time configurable batch size
455    pub fn poll_recv_mmsg_with_batch_size<const BATCH_SIZE: usize>(
456        &self,
457        cx: &mut Context,
458        bufs: &mut [IoSliceMut<'_>],
459        meta: &mut [RecvMeta],
460    ) -> Poll<io::Result<usize>> {
461        debug_assert!(!bufs.is_empty());
462        loop {
463            ready!(self.io.poll_recv_ready(cx))?;
464            let io = &self.io;
465            if let Ok(res) = io.try_io(Interest::READABLE, || {
466                recv::<BATCH_SIZE>(SockRef::from(io), bufs, meta)
467            }) {
468                return Poll::Ready(Ok(res));
469            }
470        }
471    }
472
473    /// Returns local address this socket is bound to.
474    pub fn local_addr(&self) -> io::Result<SocketAddr> {
475        self.io.local_addr()
476    }
477}
478
479pub mod sync {
480
481    use std::os::unix::prelude::IntoRawFd;
482
483    use super::*;
484
485    #[derive(Debug)]
486    pub struct UdpSocket {
487        io: std::net::UdpSocket,
488        last_send_error: LastSendError,
489    }
490
491    impl AsRawFd for UdpSocket {
492        fn as_raw_fd(&self) -> std::os::unix::prelude::RawFd {
493            self.io.as_raw_fd()
494        }
495    }
496    impl IntoRawFd for UdpSocket {
497        fn into_raw_fd(self) -> std::os::unix::prelude::RawFd {
498            self.io.into_raw_fd()
499        }
500    }
501
502    impl UdpSocket {
503        /// Creates a new UDP socket from a previously created `std::net::UdpSocket`
504        pub fn from_std(socket: std::net::UdpSocket) -> io::Result<Self> {
505            init(SockRef::from(&socket))?;
506            socket.set_nonblocking(false)?;
507            Ok(Self {
508                io: socket,
509                last_send_error: LastSendError::default(),
510            })
511        }
512        /// create a new UDP socket and attempt to bind to `addr`
513        pub fn bind<A: std::net::ToSocketAddrs>(addr: A) -> io::Result<Self> {
514            let io = std::net::UdpSocket::bind(addr)?;
515            init(SockRef::from(&io))?;
516            io.set_nonblocking(false)?;
517            Ok(Self {
518                io,
519                last_send_error: LastSendError::default(),
520            })
521        }
522        /// sets nonblocking mode
523        pub fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> {
524            self.io.set_nonblocking(nonblocking)
525        }
526        /// sets the value of SO_BROADCAST for this socket
527        pub fn set_broadcast(&self, broadcast: bool) -> io::Result<()> {
528            self.io.set_broadcast(broadcast)
529        }
530        pub fn connect<A: std::net::ToSocketAddrs>(&self, addrs: A) -> io::Result<()> {
531            self.io.connect(addrs)
532        }
533        pub fn join_multicast_v4(
534            &self,
535            multiaddr: Ipv4Addr,
536            interface: Ipv4Addr,
537        ) -> io::Result<()> {
538            self.io.join_multicast_v4(&multiaddr, &interface)
539        }
540        pub fn join_multicast_v6(&self, multiaddr: &Ipv6Addr, interface: u32) -> io::Result<()> {
541            self.io.join_multicast_v6(multiaddr, interface)
542        }
543        pub fn leave_multicast_v4(
544            &self,
545            multiaddr: Ipv4Addr,
546            interface: Ipv4Addr,
547        ) -> io::Result<()> {
548            self.io.leave_multicast_v4(&multiaddr, &interface)
549        }
550        pub fn leave_multicast_v6(&self, multiaddr: &Ipv6Addr, interface: u32) -> io::Result<()> {
551            self.io.leave_multicast_v6(multiaddr, interface)
552        }
553        pub fn set_multicast_loop_v4(&self, on: bool) -> io::Result<()> {
554            self.io.set_multicast_loop_v4(on)
555        }
556        pub fn set_multicast_loop_v6(&self, on: bool) -> io::Result<()> {
557            self.io.set_multicast_loop_v6(on)
558        }
559        /// Sends data on the socket to the given address. On success, returns the
560        /// number of bytes written.
561        ///
562        /// calls underlying tokio [`send_to`]
563        ///
564        /// [`send_to`]: method@tokio::net::UdpSocket::send_to
565        pub fn send_to(&self, buf: &[u8], target: SocketAddr) -> io::Result<usize> {
566            self.io.send_to(buf, target)
567        }
568        /// Sends data on the socket to the remote address that the socket is
569        /// connected to.
570        ///
571        /// See tokio [`send`]
572        ///
573        /// [`send`]: method@tokio::net::UdpSocket::send
574        pub fn send(&self, buf: &[u8]) -> io::Result<usize> {
575            self.io.send(buf)
576        }
577        /// Receives a single datagram message on the socket. On success, returns
578        /// the number of bytes read and the origin.
579        ///
580        /// See tokio [`recv_from`]
581        ///
582        /// [`recv_from`]: method@tokio::net::UdpSocket::recv_from
583        pub fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
584            self.io.recv_from(buf)
585        }
586        /// Receives a single datagram message on the socket from the remote address
587        /// to which it is connected. On success, returns the number of bytes read.
588        ///
589        /// See tokio [`recv`]
590        ///
591        /// [`recv`]: method@tokio::net::UdpSocket::recv
592        pub fn recv(&self, buf: &mut [u8]) -> io::Result<usize> {
593            self.io.recv(buf)
594        }
595
596        /// Calls syscall [`sendmmsg`]. With a given `state` configured GSO and
597        /// `transmits` with information on the data and metadata about outgoing
598        /// packets.
599        ///
600        /// Utilizes the default batch size (`DEFAULT_BATCH_SIZE`), and will
601        /// send no more than that number of messages. The caller must call this
602        /// fuction again after modifying `transmits` to continue sending the
603        /// entire set of messages.
604        ///
605        /// [`sendmmsg`]: https://linux.die.net/man/2/sendmmsg
606        pub fn send_mmsg<B: AsPtr<u8>>(
607            &mut self,
608            state: &UdpState,
609            transmits: &[Transmit<B>],
610        ) -> Result<usize, io::Error> {
611            self.send_mmsg_with_batch_size::<_, DEFAULT_BATCH_SIZE>(state, transmits)
612        }
613
614        /// Calls syscall [`sendmmsg`]. With a given `state` configured GSO and
615        /// `transmits` with information on the data and metadata about outgoing packets.
616        ///
617        /// Sends no more than `BATCH_SIZE` messages. The caller must call this
618        /// fuction again after modifying `transmits` to continue sending the
619        /// entire set of messages. `BATCH_SIZE_CAP` defines the maximum that
620        /// will be sent, regardless of the specified `BATCH_SIZE`
621        ///
622        /// [`sendmmsg`]: https://linux.die.net/man/2/sendmmsg
623        pub fn send_mmsg_with_batch_size<B: AsPtr<u8>, const BATCH_SIZE: usize>(
624            &mut self,
625            state: &UdpState,
626            transmits: &[Transmit<B>],
627        ) -> Result<usize, io::Error> {
628            send::<_, BATCH_SIZE>(
629                state,
630                SockRef::from(&self.io),
631                self.last_send_error.clone(),
632                transmits,
633            )
634        }
635
636        /// Calls syscall [`sendmsg`]. With a given `state` configured GSO and
637        /// `transmit` with information on the data and metadata about outgoing packet.
638        ///
639        /// [`sendmsg`]: https://linux.die.net/man/2/sendmsg
640        pub fn send_msg<B: AsPtr<u8>>(
641            &self,
642            state: &UdpState,
643            transmits: Transmit<B>,
644        ) -> io::Result<usize> {
645            send_msg(state, SockRef::from(&self.io), &transmits)
646        }
647
648        /// async version of `recvmmsg`
649        pub fn recv_mmsg(
650            &self,
651            bufs: &mut [IoSliceMut<'_>],
652            meta: &mut [RecvMeta],
653        ) -> io::Result<usize> {
654            self.recv_mmsg_with_batch_size::<DEFAULT_BATCH_SIZE>(bufs, meta)
655        }
656
657        /// async version of `recvmmsg`, with compile-time configurable batch size
658        pub fn recv_mmsg_with_batch_size<const BATCH_SIZE: usize>(
659            &self,
660            bufs: &mut [IoSliceMut<'_>],
661            meta: &mut [RecvMeta],
662        ) -> io::Result<usize> {
663            debug_assert!(!bufs.is_empty());
664            recv::<BATCH_SIZE>(SockRef::from(&self.io), bufs, meta)
665        }
666
667        /// `recv_msg` is similar to `recv_from` but returns extra information
668        /// about the packet in [`RecvMeta`].
669        ///
670        /// [`RecvMeta`]: crate::RecvMeta
671        pub fn recv_msg(&self, buf: &mut [u8]) -> io::Result<RecvMeta> {
672            let mut iov = IoSliceMut::new(buf);
673            debug_assert!(!iov.is_empty());
674
675            recv_msg(SockRef::from(&self.io), &mut iov)
676        }
677        /// Returns local address this socket is bound to.
678        pub fn local_addr(&self) -> io::Result<SocketAddr> {
679            self.io.local_addr()
680        }
681    }
682}
683
684fn set_socket_option<Fd: AsRawFd>(
685    socket: &Fd,
686    level: libc::c_int,
687    name: libc::c_int,
688    value: libc::c_int,
689) -> Result<(), io::Error> {
690    let rc = unsafe {
691        libc::setsockopt(
692            socket.as_raw_fd(),
693            level,
694            name,
695            &value as *const _ as _,
696            mem::size_of_val(&value) as _,
697        )
698    };
699
700    if rc != -1 {
701        Ok(())
702    } else {
703        Err(io::Error::last_os_error())
704    }
705}
706
/// Value passed to `set_socket_option` to switch a boolean socket option on.
const OPTION_ON: libc::c_int = 1;
708
709fn init(io: SockRef<'_>) -> io::Result<()> {
710    let mut cmsg_platform_space = 0;
711    if cfg!(target_os = "linux") || cfg!(target_os = "freebsd") || cfg!(target_os = "macos") {
712        cmsg_platform_space +=
713            unsafe { libc::CMSG_SPACE(mem::size_of::<libc::in6_pktinfo>() as _) as usize };
714    }
715
716    assert!(
717        CMSG_LEN
718            >= unsafe { libc::CMSG_SPACE(mem::size_of::<libc::c_int>() as _) as usize }
719                + cmsg_platform_space
720    );
721    assert!(
722        mem::align_of::<libc::cmsghdr>() <= mem::align_of::<cmsg::Aligned<[u8; 0]>>(),
723        "control message buffers will be misaligned"
724    );
725
726    io.set_nonblocking(true)?;
727
728    let addr = io.local_addr()?;
729    let is_ipv4 = addr.family() == libc::AF_INET as libc::sa_family_t;
730
731    // macos and ios do not support IP_RECVTOS on dual-stack sockets :(
732    if is_ipv4 || ((!cfg!(any(target_os = "macos", target_os = "ios"))) && !io.only_v6()?) {
733        set_socket_option(&*io, libc::IPPROTO_IP, libc::IP_RECVTOS, OPTION_ON)?;
734    }
735    #[cfg(target_os = "linux")]
736    {
737        // Forbid IPv4 fragmentation. Set even for IPv6 to account for IPv6 mapped IPv4 addresses.
738        set_socket_option(
739            &*io,
740            libc::IPPROTO_IP,
741            libc::IP_MTU_DISCOVER,
742            libc::IP_PMTUDISC_PROBE,
743        )?;
744
745        if is_ipv4 {
746            set_socket_option(&*io, libc::IPPROTO_IP, libc::IP_PKTINFO, OPTION_ON)?;
747        } else {
748            set_socket_option(
749                &*io,
750                libc::IPPROTO_IPV6,
751                libc::IPV6_MTU_DISCOVER,
752                libc::IP_PMTUDISC_PROBE,
753            )?;
754        }
755    }
756    #[cfg(target_os = "macos")]
757    {
758        if is_ipv4 {
759            set_socket_option(&*io, libc::IPPROTO_IP, libc::IP_PKTINFO, OPTION_ON)?;
760        }
761    }
762    #[cfg(target_os = "freebsd")]
763    // IP_RECVDSTADDR == IP_SENDSRCADDR on FreeBSD
764    // macOS uses only IP_RECVDSTADDR, no IP_SENDSRCADDR on macOS
765    // macOS also supports IP_PKTINFO
766    {
767        if is_ipv4 {
768            set_socket_option(&*io, libc::IPPROTO_IP, libc::IP_RECVDSTADDR, OPTION_ON)?;
769            set_socket_option(&*io, libc::IPPROTO_IP, libc::IP_RECVIF, OPTION_ON)?;
770        }
771    }
772
773    // IPV6_RECVPKTINFO is standardized
774    if !is_ipv4 {
775        set_socket_option(&*io, libc::IPPROTO_IPV6, libc::IPV6_RECVPKTINFO, OPTION_ON)?;
776        set_socket_option(&*io, libc::IPPROTO_IPV6, libc::IPV6_RECVTCLASS, OPTION_ON)?;
777    }
778    Ok(())
779}
780
781#[cfg(not(any(target_os = "macos", target_os = "ios")))]
782fn send_msg<B: AsPtr<u8>>(
783    // `state` is not presently used on FreeBSD
784    #[allow(unused_variables)] state: &UdpState,
785    io: SockRef<'_>,
786    transmit: &Transmit<B>,
787) -> io::Result<usize> {
788    let mut msg_hdr: libc::msghdr = unsafe { mem::zeroed() };
789    let mut iovec: libc::iovec = unsafe { mem::zeroed() };
790    let mut cmsg = cmsg::Aligned([0u8; CMSG_LEN]);
791
792    let addr = socket2::SockAddr::from(transmit.dst);
793    let dst_addr = &addr;
794    prepare_msg(transmit, dst_addr, &mut msg_hdr, &mut iovec, &mut cmsg);
795
796    loop {
797        let n = unsafe { libc::sendmsg(io.as_raw_fd(), &msg_hdr, 0) };
798        if n == -1 {
799            let e = io::Error::last_os_error();
800            match e.kind() {
801                io::ErrorKind::Interrupted => {
802                    // Retry the transmission
803                    continue;
804                }
805                io::ErrorKind::WouldBlock => return Err(e),
806                _ => {
807                    // Some network adapters do not support GSO. Unfortunately, Linux offers no easy way
808                    // for us to detect this short of an I/O error when we try to actually send
809                    // datagrams using it.
810                    #[cfg(target_os = "linux")]
811                    if e.raw_os_error() == Some(libc::EIO) {
812                        // Prevent new transmits from being scheduled using GSO. Existing GSO transmits
813                        // may already be in the pipeline, so we need to tolerate additional failures.
814                        if state.max_gso_segments() > 1 {
815                            tracing::error!("got EIO, halting segmentation offload");
816                            state
817                                .max_gso_segments
818                                .store(1, std::sync::atomic::Ordering::Relaxed);
819                        }
820                    }
821
822                    // Other errors are ignored, since they will ususally be handled
823                    // by higher level retransmits and timeouts.
824                    // - PermissionDenied errors have been observed due to iptable rules.
825                    //   Those are not fatal errors, since the
826                    //   configuration can be dynamically changed.
827                    // - Destination unreachable errors have been observed for other
828                    // log_sendmsg_error(last_send_error, e, &transmits[0]);
829
830                    // The ERRORS section in https://man7.org/linux/man-pages/man2/sendmmsg.2.html
831                    // describes that errors will only be returned if no message could be transmitted
832                    // at all. Therefore drop the first (problematic) message,
833                    // and retry the remaining ones.
834                    return Ok(n as usize);
835                }
836            }
837        }
838        return Ok(n as usize);
839    }
840}
841
842#[cfg(not(any(target_os = "macos", target_os = "ios")))]
843fn send<B: AsPtr<u8>, const BATCH_SIZE: usize>(
844    // `state` is not presently used on FreeBSD
845    #[allow(unused_variables)] state: &UdpState,
846    io: SockRef<'_>,
847    last_send_error: LastSendError,
848    transmits: &[Transmit<B>],
849) -> io::Result<usize> {
850    use std::ptr;
851
852    let mut msgs: [libc::mmsghdr; BATCH_SIZE] = unsafe { mem::zeroed() };
853    let mut iovecs: [libc::iovec; BATCH_SIZE] = unsafe { mem::zeroed() };
854    let mut cmsgs = [cmsg::Aligned([0u8; CMSG_LEN]); BATCH_SIZE];
855    // This assume_init looks a bit weird because one might think it
856    // assumes the SockAddr data to be initialized, but that call
857    // refers to the whole array, which itself is made up of MaybeUninit
858    // containers. Their presence protects the SockAddr inside from
859    // being assumed as initialized by the assume_init call.
860    // TODO: Replace this with uninit_array once it becomes MSRV-stable
861    let mut addrs: [MaybeUninit<socket2::SockAddr>; BATCH_SIZE] =
862        unsafe { MaybeUninit::uninit().assume_init() };
863    for (i, transmit) in transmits.iter().enumerate().take(BATCH_SIZE) {
864        let dst_addr = unsafe {
865            ptr::write(addrs[i].as_mut_ptr(), socket2::SockAddr::from(transmit.dst));
866            &*addrs[i].as_ptr()
867        };
868        prepare_msg(
869            transmit,
870            dst_addr,
871            &mut msgs[i].msg_hdr,
872            &mut iovecs[i],
873            &mut cmsgs[i],
874        );
875    }
876    let num_transmits = transmits.len().min(BATCH_SIZE);
877
878    loop {
879        #[cfg(target_os = "linux")]
880        let n =
881            unsafe { libc::sendmmsg(io.as_raw_fd(), msgs.as_mut_ptr(), num_transmits as u32, 0) };
882        #[cfg(target_os = "freebsd")]
883        let n =
884            unsafe { libc::sendmmsg(io.as_raw_fd(), msgs.as_mut_ptr(), num_transmits as usize, 0) };
885        if n == -1 {
886            let e = io::Error::last_os_error();
887            match e.kind() {
888                io::ErrorKind::Interrupted => {
889                    // Retry the transmission
890                    continue;
891                }
892                io::ErrorKind::WouldBlock => return Err(e),
893                _ => {
894                    // Some network adapters do not support GSO. Unfortunately, Linux offers no easy way
895                    // for us to detect this short of an I/O error when we try to actually send
896                    // datagrams using it.
897                    #[cfg(target_os = "linux")]
898                    if e.raw_os_error() == Some(libc::EIO) {
899                        // Prevent new transmits from being scheduled using GSO. Existing GSO transmits
900                        // may already be in the pipeline, so we need to tolerate additional failures.
901                        if state.max_gso_segments() > 1 {
902                            tracing::error!("got EIO, halting segmentation offload");
903                            state
904                                .max_gso_segments
905                                .store(1, std::sync::atomic::Ordering::Relaxed);
906                        }
907                    }
908
909                    // Other errors are ignored, since they will ususally be handled
910                    // by higher level retransmits and timeouts.
911                    // - PermissionDenied errors have been observed due to iptable rules.
912                    //   Those are not fatal errors, since the
913                    //   configuration can be dynamically changed.
914                    // - Destination unreachable errors have been observed for other
915                    log_sendmsg_error(last_send_error, e, &transmits[0]);
916
917                    // The ERRORS section in https://man7.org/linux/man-pages/man2/sendmmsg.2.html
918                    // describes that errors will only be returned if no message could be transmitted
919                    // at all. Therefore drop the first (problematic) message,
920                    // and retry the remaining ones.
921                    return Ok(num_transmits.min(1));
922                }
923            }
924        }
925        return Ok(n as usize);
926    }
927}
928
#[cfg(any(target_os = "macos", target_os = "ios"))]
/// Sends a single datagram described by `transmit` using `sendmsg`.
///
/// Returns the byte count reported by `sendmsg`. When a send error is
/// deliberately ignored (see below), the full payload length is reported so
/// the caller treats the datagram as handled.
///
/// # Errors
///
/// Only `WouldBlock` is surfaced to the caller; an interrupted call is retried.
fn send_msg<B: AsPtr<u8>>(
    _state: &UdpState,
    io: SockRef<'_>,
    transmit: &Transmit<B>,
) -> io::Result<usize> {
    let mut hdr: libc::msghdr = unsafe { mem::zeroed() };
    let mut iov: libc::iovec = unsafe { mem::zeroed() };
    let mut ctrl = cmsg::Aligned([0u8; CMSG_LEN]);

    let addr = socket2::SockAddr::from(transmit.dst);
    prepare_msg(transmit, &addr, &mut hdr, &mut iov, &mut ctrl);

    loop {
        let n = unsafe { libc::sendmsg(io.as_raw_fd(), &hdr, 0) };
        if n != -1 {
            return Ok(n as usize);
        }

        let e = io::Error::last_os_error();
        match e.kind() {
            // Retry the transmission
            io::ErrorKind::Interrupted => continue,
            io::ErrorKind::WouldBlock => return Err(e),
            _ => {
                // Other errors are ignored, since they will usually be handled
                // by higher level retransmits and timeouts.
                // - PermissionDenied errors have been observed due to iptable rules.
                //   Those are not fatal errors, since the
                //   configuration can be dynamically changed.
                // - Destination unreachable errors have also been observed.
                //
                // `n` is -1 here, so it must not be cast to `usize` (that would
                // yield `usize::MAX`). `prepare_msg` stored the payload length in
                // `iov.iov_len`; report that so the datagram counts as handled.
                return Ok(iov.iov_len);
            }
        }
    }
}
967
#[cfg(any(target_os = "macos", target_os = "ios"))]
/// Sends every entry of `transmits` in order, one `sendmsg` call per datagram.
///
/// Returns the number of entries handled. Entries whose send failed with an
/// ignored error are logged and counted as handled.
///
/// # Errors
///
/// `WouldBlock` is returned as an error only if nothing was handled yet;
/// otherwise the count handled so far is returned.
fn send<B: AsPtr<u8>, const BATCH_SIZE: usize>(
    _state: &UdpState,
    io: SockRef<'_>,
    last_send_error: LastSendError,
    transmits: &[Transmit<B>],
) -> io::Result<usize> {
    // The header, iovec and control buffer are reused for every datagram.
    let mut hdr: libc::msghdr = unsafe { mem::zeroed() };
    let mut iov: libc::iovec = unsafe { mem::zeroed() };
    let mut ctrl = cmsg::Aligned([0u8; CMSG_LEN]);

    let mut sent = 0;
    while let Some(transmit) = transmits.get(sent) {
        let addr = socket2::SockAddr::from(transmit.dst);
        prepare_msg(transmit, &addr, &mut hdr, &mut iov, &mut ctrl);

        if unsafe { libc::sendmsg(io.as_raw_fd(), &hdr, 0) } != -1 {
            sent += 1;
            continue;
        }

        let e = io::Error::last_os_error();
        match e.kind() {
            // Leave `sent` untouched so the same datagram is retried.
            io::ErrorKind::Interrupted => {}
            io::ErrorKind::WouldBlock => {
                return if sent == 0 { Err(e) } else { Ok(sent) };
            }
            _ => {
                // Other errors are ignored, since they will usually be handled
                // by higher level retransmits and timeouts.
                // - PermissionDenied errors have been observed due to iptable rules.
                //   Those are not fatal errors, since the
                //   configuration can be dynamically changed.
                // - Destination unreachable errors have also been observed.
                log_sendmsg_error(last_send_error.clone(), e, transmit);
                sent += 1;
            }
        }
    }
    Ok(sent)
}
1008
1009#[cfg(not(any(target_os = "macos", target_os = "ios")))]
1010fn recv<const BATCH_SIZE: usize>(
1011    io: SockRef<'_>,
1012    bufs: &mut [IoSliceMut<'_>],
1013    meta: &mut [RecvMeta],
1014) -> io::Result<usize> {
1015    use std::ptr;
1016
1017    let mut names = [MaybeUninit::<libc::sockaddr_storage>::uninit(); BATCH_SIZE];
1018    let mut ctrls = [cmsg::Aligned(MaybeUninit::<[u8; CMSG_LEN]>::uninit()); BATCH_SIZE];
1019    let mut hdrs = unsafe { mem::zeroed::<[libc::mmsghdr; BATCH_SIZE]>() };
1020    let max_msg_count = bufs.len().min(BATCH_SIZE);
1021    for i in 0..max_msg_count {
1022        prepare_recv(
1023            &mut bufs[i],
1024            &mut names[i],
1025            &mut ctrls[i],
1026            &mut hdrs[i].msg_hdr,
1027        );
1028    }
1029    let msg_count = loop {
1030        #[cfg(target_os = "linux")]
1031        let n = unsafe {
1032            libc::recvmmsg(
1033                io.as_raw_fd(),
1034                hdrs.as_mut_ptr(),
1035                bufs.len().min(BATCH_SIZE) as libc::c_uint,
1036                0,
1037                ptr::null_mut(),
1038            )
1039        };
1040        #[cfg(target_os = "freebsd")]
1041        let n = unsafe {
1042            libc::recvmmsg(
1043                io.as_raw_fd(),
1044                hdrs.as_mut_ptr(),
1045                bufs.len().min(BATCH_SIZE) as usize,
1046                0,
1047                ptr::null_mut(),
1048            )
1049        };
1050        if n == -1 {
1051            let e = io::Error::last_os_error();
1052            if e.kind() == io::ErrorKind::Interrupted {
1053                continue;
1054            }
1055            return Err(e);
1056        }
1057        break n;
1058    };
1059    for i in 0..(msg_count as usize) {
1060        meta[i] = decode_recv(&names[i], &hdrs[i].msg_hdr, hdrs[i].msg_len as usize);
1061    }
1062    Ok(msg_count as usize)
1063}
1064
1065#[cfg(not(any(target_os = "macos", target_os = "ios")))]
1066fn recv_msg(io: SockRef<'_>, bufs: &mut IoSliceMut<'_>) -> io::Result<RecvMeta> {
1067    let mut name = MaybeUninit::<libc::sockaddr_storage>::uninit();
1068    let mut ctrl = cmsg::Aligned(MaybeUninit::<[u8; CMSG_LEN]>::uninit());
1069    let mut hdr = unsafe { mem::zeroed::<libc::msghdr>() };
1070
1071    prepare_recv(bufs, &mut name, &mut ctrl, &mut hdr);
1072
1073    let n = loop {
1074        let n = unsafe { libc::recvmsg(io.as_raw_fd(), &mut hdr, 0) };
1075        if n == -1 {
1076            let e = io::Error::last_os_error();
1077            if e.kind() == io::ErrorKind::Interrupted {
1078                continue;
1079            }
1080            return Err(e);
1081        }
1082        if hdr.msg_flags & libc::MSG_TRUNC != 0 {
1083            continue;
1084        }
1085        break n;
1086    };
1087    Ok(decode_recv(&name, &hdr, n as usize))
1088}
1089
#[cfg(any(target_os = "macos", target_os = "ios"))]
/// Receives one datagram into `bufs[0]` with `recvmsg` and stores its
/// metadata in `meta[0]`; batching is not used on these platforms.
///
/// Datagrams flagged `MSG_TRUNC` are discarded and the receive is attempted
/// again. On success the return value is always `1`.
///
/// # Panics
///
/// Panics if `bufs` or `meta` is empty.
///
/// # Errors
///
/// Any `recvmsg` failure other than an interrupted call is returned;
/// interrupted calls are retried.
fn recv<const BATCH_SIZE: usize>(
    io: SockRef<'_>,
    bufs: &mut [IoSliceMut<'_>],
    meta: &mut [RecvMeta],
) -> io::Result<usize> {
    let mut name = MaybeUninit::<libc::sockaddr_storage>::uninit();
    let mut ctrl = cmsg::Aligned(MaybeUninit::<[u8; CMSG_LEN]>::uninit());
    let mut hdr = unsafe { mem::zeroed::<libc::msghdr>() };

    let n = loop {
        // A successful `recvmsg` overwrites `msg_namelen`, `msg_controllen`
        // and `msg_flags` with the values of the datagram it returned. Reset
        // them before every attempt so that a retry after a truncated datagram
        // again offers the full address and control buffers.
        prepare_recv(&mut bufs[0], &mut name, &mut ctrl, &mut hdr);

        let n = unsafe { libc::recvmsg(io.as_raw_fd(), &mut hdr, 0) };
        if n == -1 {
            let e = io::Error::last_os_error();
            if e.kind() == io::ErrorKind::Interrupted {
                continue;
            }
            return Err(e);
        }
        if hdr.msg_flags & libc::MSG_TRUNC != 0 {
            continue;
        }
        break n;
    };
    meta[0] = decode_recv(&name, &hdr, n as usize);
    Ok(1)
}
1117
#[cfg(any(target_os = "macos", target_os = "ios"))]
/// Receives one datagram into `bufs` with `recvmsg` and returns its metadata.
///
/// Datagrams flagged `MSG_TRUNC` are discarded and the receive is attempted
/// again.
///
/// # Errors
///
/// Any `recvmsg` failure other than an interrupted call is returned;
/// interrupted calls are retried.
fn recv_msg(io: SockRef<'_>, bufs: &mut IoSliceMut<'_>) -> io::Result<RecvMeta> {
    let mut name = MaybeUninit::<libc::sockaddr_storage>::uninit();
    let mut ctrl = cmsg::Aligned(MaybeUninit::<[u8; CMSG_LEN]>::uninit());
    let mut hdr = unsafe { mem::zeroed::<libc::msghdr>() };

    let n = loop {
        // A successful `recvmsg` overwrites `msg_namelen`, `msg_controllen`
        // and `msg_flags` with the values of the datagram it returned. Reset
        // them before every attempt so that a retry after a truncated datagram
        // again offers the full address and control buffers.
        prepare_recv(&mut *bufs, &mut name, &mut ctrl, &mut hdr);

        let n = unsafe { libc::recvmsg(io.as_raw_fd(), &mut hdr, 0) };
        if n == -1 {
            let e = io::Error::last_os_error();
            if e.kind() == io::ErrorKind::Interrupted {
                continue;
            }
            return Err(e);
        }
        if hdr.msg_flags & libc::MSG_TRUNC != 0 {
            continue;
        }
        break n;
    };
    Ok(decode_recv(&name, &hdr, n as usize))
}
1140
1141/// Returns the platforms UDP socket capabilities
1142pub fn udp_state() -> UdpState {
1143    UdpState {
1144        max_gso_segments: AtomicUsize::new(gso::max_gso_segments()),
1145        gro_segments: gro::gro_segments(),
1146    }
1147}
1148
/// Size in bytes of each control-message (ancillary data) buffer.
///
/// Used both as the length of the `cmsg::Aligned` buffers handed to
/// `sendmsg`/`recvmsg` and as the initial `msg_controllen` of their headers.
const CMSG_LEN: usize = 88;
1150
1151fn prepare_msg<B: AsPtr<u8>>(
1152    transmit: &Transmit<B>,
1153    dst_addr: &socket2::SockAddr,
1154    hdr: &mut libc::msghdr,
1155    iov: &mut libc::iovec,
1156    ctrl: &mut cmsg::Aligned<[u8; CMSG_LEN]>,
1157) {
1158    iov.iov_base = transmit.contents.as_ptr() as *const _ as *mut _;
1159    iov.iov_len = transmit.contents.len();
1160
1161    // SAFETY: Casting the pointer to a mutable one is legal,
1162    // as sendmsg is guaranteed to not alter the mutable pointer
1163    // as per the POSIX spec. See the section on the sys/socket.h
1164    // header for details. The type is only mutable in the first
1165    // place because it is reused by recvmsg as well.
1166    let name = dst_addr.as_ptr() as *mut libc::c_void;
1167    let namelen = dst_addr.len();
1168    hdr.msg_name = name as *mut _;
1169    hdr.msg_namelen = namelen;
1170    hdr.msg_iov = iov;
1171    hdr.msg_iovlen = 1;
1172
1173    hdr.msg_control = ctrl.0.as_mut_ptr() as _;
1174    hdr.msg_controllen = CMSG_LEN as _;
1175    let mut encoder = unsafe { cmsg::Encoder::new(hdr) };
1176    let ecn = transmit.ecn.map_or(0, |x| x as libc::c_int);
1177    if transmit.dst.is_ipv4() {
1178        encoder.push(libc::IPPROTO_IP, libc::IP_TOS, ecn as IpTosTy);
1179    } else {
1180        encoder.push(libc::IPPROTO_IPV6, libc::IPV6_TCLASS, ecn);
1181    }
1182
1183    if let Some(segment_size) = transmit.segment_size {
1184        gso::set_segment_size(&mut encoder, segment_size as u16);
1185    }
1186
1187    if let Some(ip) = &transmit.src {
1188        match ip {
1189            Source::Ip(IpAddr::V4(v4)) => {
1190                #[cfg(any(target_os = "linux", target_os = "macos"))]
1191                {
1192                    let pktinfo = libc::in_pktinfo {
1193                        ipi_ifindex: 0,
1194                        ipi_spec_dst: libc::in_addr {
1195                            s_addr: u32::from_ne_bytes(v4.octets()),
1196                        },
1197                        ipi_addr: libc::in_addr { s_addr: 0 },
1198                    };
1199                    encoder.push(libc::IPPROTO_IP, libc::IP_PKTINFO, pktinfo);
1200                }
1201                #[cfg(target_os = "freebsd")]
1202                {
1203                    let addr = libc::in_addr {
1204                        s_addr: u32::from_ne_bytes(v4.octets()),
1205                    };
1206                    encoder.push(libc::IPPROTO_IP, libc::IP_RECVDSTADDR, addr);
1207                }
1208            }
1209            Source::Ip(IpAddr::V6(v6)) => {
1210                let pktinfo = libc::in6_pktinfo {
1211                    ipi6_ifindex: 0,
1212                    ipi6_addr: libc::in6_addr {
1213                        s6_addr: v6.octets(),
1214                    },
1215                };
1216                encoder.push(libc::IPPROTO_IPV6, libc::IPV6_PKTINFO, pktinfo);
1217            }
1218            #[cfg(not(target_os = "freebsd"))]
1219            Source::Interface(i) => {
1220                let pktinfo = libc::in_pktinfo {
1221                    ipi_ifindex: *i as _, // i32 linux, u32 mac
1222                    ipi_spec_dst: libc::in_addr { s_addr: 0 },
1223                    ipi_addr: libc::in_addr { s_addr: 0 },
1224                };
1225                encoder.push(libc::IPPROTO_IP, libc::IP_PKTINFO, pktinfo);
1226            }
1227            #[cfg(target_os = "freebsd")]
1228            Source::Interface(_) => (), // Not yet supported on FreeBSD
1229            Source::InterfaceV6(i, ip) => {
1230                let pktinfo = libc::in6_pktinfo {
1231                    ipi6_ifindex: *i,
1232                    ipi6_addr: libc::in6_addr {
1233                        s6_addr: ip.octets(),
1234                    },
1235                };
1236                encoder.push(libc::IPPROTO_IPV6, libc::IPV6_PKTINFO, pktinfo);
1237            }
1238        }
1239    }
1240
1241    encoder.finish();
1242}
1243
1244fn prepare_recv(
1245    buf: &mut IoSliceMut,
1246    name: &mut MaybeUninit<libc::sockaddr_storage>,
1247    ctrl: &mut cmsg::Aligned<MaybeUninit<[u8; CMSG_LEN]>>,
1248    hdr: &mut libc::msghdr,
1249) {
1250    hdr.msg_name = name.as_mut_ptr() as _;
1251    hdr.msg_namelen = mem::size_of::<libc::sockaddr_storage>() as _;
1252    hdr.msg_iov = buf as *mut IoSliceMut as *mut libc::iovec;
1253    hdr.msg_iovlen = 1;
1254    hdr.msg_control = ctrl.0.as_mut_ptr() as _;
1255    hdr.msg_controllen = CMSG_LEN as _;
1256    hdr.msg_flags = 0;
1257}
1258
1259fn decode_recv(
1260    name: &MaybeUninit<libc::sockaddr_storage>,
1261    hdr: &libc::msghdr,
1262    len: usize,
1263) -> RecvMeta {
1264    let name = unsafe { name.assume_init() };
1265    let mut ecn_bits = 0;
1266    let mut dst_ip = None;
1267    // Only mutated on Linux
1268    #[allow(unused_mut)]
1269    let mut dst_local_ip = None;
1270    let mut ifindex = 0;
1271    // Only mutated on Linux
1272    #[allow(unused_mut)]
1273    let mut stride = len;
1274
1275    let cmsg_iter = unsafe { cmsg::Iter::new(hdr) };
1276    for cmsg in cmsg_iter {
1277        match (cmsg.cmsg_level, cmsg.cmsg_type) {
1278            // FreeBSD uses IP_RECVTOS here, and we can be liberal because cmsgs are opt-in.
1279            (libc::IPPROTO_IP, libc::IP_TOS) | (libc::IPPROTO_IP, libc::IP_RECVTOS) => unsafe {
1280                ecn_bits = cmsg::decode::<u8>(cmsg);
1281            },
1282            (libc::IPPROTO_IPV6, libc::IPV6_TCLASS) => unsafe {
1283                // Temporary hack around broken macos ABI. Remove once upstream fixes it.
1284                // https://bugreport.apple.com/web/?problemID=48761855
1285                if cfg!(target_os = "macos")
1286                    && cmsg.cmsg_len as usize == libc::CMSG_LEN(mem::size_of::<u8>() as _) as usize
1287                {
1288                    ecn_bits = cmsg::decode::<u8>(cmsg);
1289                } else {
1290                    ecn_bits = cmsg::decode::<libc::c_int>(cmsg) as u8;
1291                }
1292            },
1293            #[cfg(not(target_os = "freebsd"))]
1294            (libc::IPPROTO_IP, libc::IP_PKTINFO) => {
1295                let pktinfo = unsafe { cmsg::decode::<libc::in_pktinfo>(cmsg) };
1296                dst_ip = Some(IpAddr::V4(Ipv4Addr::from(
1297                    pktinfo.ipi_addr.s_addr.to_ne_bytes(),
1298                )));
1299                dst_local_ip = Some(IpAddr::V4(Ipv4Addr::from(
1300                    pktinfo.ipi_spec_dst.s_addr.to_ne_bytes(),
1301                )));
1302                ifindex = pktinfo.ipi_ifindex as _;
1303            }
1304            (libc::IPPROTO_IPV6, libc::IPV6_PKTINFO) => {
1305                let pktinfo = unsafe { cmsg::decode::<libc::in6_pktinfo>(cmsg) };
1306                dst_ip = Some(IpAddr::V6(Ipv6Addr::from(pktinfo.ipi6_addr.s6_addr)));
1307                ifindex = pktinfo.ipi6_ifindex;
1308            }
1309            // freebsd doesn't have PKTINFO
1310            #[cfg(target_os = "freebsd")]
1311            (libc::IPPROTO_IP, libc::IP_RECVIF) => {
1312                let info = unsafe { cmsg::decode::<libc::sockaddr_dl>(cmsg) };
1313                ifindex = info.sdl_index as _;
1314            }
1315            #[cfg(target_os = "linux")]
1316            (libc::SOL_UDP, libc::UDP_GRO) => unsafe {
1317                stride = cmsg::decode::<libc::c_int>(cmsg) as usize;
1318            },
1319            _ => {}
1320        }
1321    }
1322
1323    let addr = match libc::c_int::from(name.ss_family) {
1324        libc::AF_INET => {
1325            // Safety: if the ss_family field is AF_INET then storage must be a sockaddr_in.
1326            let addr = unsafe { &*(&name as *const _ as *const libc::sockaddr_in) };
1327            let ip = Ipv4Addr::from(addr.sin_addr.s_addr.to_ne_bytes());
1328            let port = u16::from_be(addr.sin_port);
1329            SocketAddr::V4(SocketAddrV4::new(ip, port))
1330        }
1331        libc::AF_INET6 => {
1332            // Safety: if the ss_family field is AF_INET6 then storage must be a sockaddr_in6.
1333            let addr = unsafe { &*(&name as *const _ as *const libc::sockaddr_in6) };
1334            let ip = Ipv6Addr::from(addr.sin6_addr.s6_addr);
1335            let port = u16::from_be(addr.sin6_port);
1336            SocketAddr::V6(SocketAddrV6::new(
1337                ip,
1338                port,
1339                addr.sin6_flowinfo,
1340                addr.sin6_scope_id,
1341            ))
1342        }
1343        _ => unreachable!(),
1344    };
1345
1346    RecvMeta {
1347        len,
1348        stride,
1349        addr,
1350        ecn: EcnCodepoint::from_bits(ecn_bits),
1351        dst_ip,
1352        dst_local_ip,
1353        ifindex,
1354    }
1355}
1356
1357#[cfg(target_os = "linux")]
1358mod gso {
1359    use super::*;
1360
1361    /// Checks whether GSO support is available by setting the UDP_SEGMENT
1362    /// option on a socket
1363    pub fn max_gso_segments() -> usize {
1364        const GSO_SIZE: libc::c_int = 1500;
1365
1366        let socket = match std::net::UdpSocket::bind("[::]:0")
1367            .or_else(|_| std::net::UdpSocket::bind("127.0.0.1:0"))
1368        {
1369            Ok(socket) => socket,
1370            Err(_) => return 1,
1371        };
1372
1373        // As defined in linux/udp.h
1374        // #define UDP_MAX_SEGMENTS        (1 << 6UL)
1375        match set_socket_option(&socket, libc::SOL_UDP, libc::UDP_SEGMENT, GSO_SIZE) {
1376            Ok(()) => 64,
1377            Err(_) => 1,
1378        }
1379    }
1380
1381    pub fn set_segment_size(encoder: &mut cmsg::Encoder, segment_size: u16) {
1382        encoder.push(libc::SOL_UDP, libc::UDP_SEGMENT, segment_size);
1383    }
1384}
1385
#[cfg(not(target_os = "linux"))]
mod gso {
    use super::*;

    /// Segment limit reported on this platform: always a single segment per transmit.
    pub fn max_gso_segments() -> usize {
        const SINGLE_SEGMENT: usize = 1;
        SINGLE_SEGMENT
    }

    /// Always panics on this platform.
    ///
    /// # Panics
    ///
    /// Unconditionally, whatever the arguments.
    pub fn set_segment_size(_encoder: &mut cmsg::Encoder, _segment_size: u16) {
        panic!("Setting a segment size is not supported on current platform");
    }
}
1398
1399#[cfg(target_os = "linux")]
1400mod gro {
1401    use super::*;
1402
1403    pub fn gro_segments() -> usize {
1404        let socket = match std::net::UdpSocket::bind("[::]:0") {
1405            Ok(socket) => socket,
1406            Err(_) => return 1,
1407        };
1408
1409        let on: libc::c_int = 1;
1410        let rc = unsafe {
1411            libc::setsockopt(
1412                socket.as_raw_fd(),
1413                libc::SOL_UDP,
1414                libc::UDP_GRO,
1415                &on as *const _ as _,
1416                mem::size_of_val(&on) as _,
1417            )
1418        };
1419
1420        if rc != -1 {
1421            // As defined in net/ipv4/udp_offload.c
1422            // #define UDP_GRO_CNT_MAX 64
1423            //
1424            // NOTE: this MUST be set to UDP_GRO_CNT_MAX to ensure that the receive buffer size
1425            // (get_max_udp_payload_size() * gro_segments()) is large enough to hold the largest GRO
1426            // list the kernel might potentially produce. See
1427            // https://github.com/quinn-rs/quinn/pull/1354.
1428            64
1429        } else {
1430            1
1431        }
1432    }
1433}
1434
#[cfg(not(target_os = "linux"))]
mod gro {
    /// Segment count reported on this platform: always one datagram per receive.
    pub fn gro_segments() -> usize {
        const SINGLE_DATAGRAM: usize = 1;
        SINGLE_DATAGRAM
    }
}