1#![allow(clippy::unnecessary_cast)]
2use std::{
3 io::{self, IoSliceMut},
4 mem::{self, MaybeUninit},
5 net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6},
6 os::{fd::AsFd, unix::io::AsRawFd},
7 sync::{
8 atomic::{AtomicU64, AtomicUsize, Ordering},
9 Arc,
10 },
11 task::{Context, Poll},
12 time::SystemTime,
13};
14
15use crate::cmsg::{AsPtr, EcnCodepoint, Source, Transmit};
16use futures_core::ready;
17use socket2::SockRef;
18use tokio::{
19 io::{Interest, ReadBuf},
20 net::ToSocketAddrs,
21};
22
23use super::{cmsg, log_sendmsg_error, RecvMeta, UdpState, IO_ERROR_LOG_INTERVAL};
24
/// Largest batch a single batched send/receive call may be asked to handle on
/// this platform (re-exported from the per-platform `SYS_BATCH_SIZE_CAP`).
pub(crate) const BATCH_SIZE_CAP: usize = SYS_BATCH_SIZE_CAP;

/// Batch size used by the `*_mmsg` helpers when the caller does not choose one
/// through the corresponding `*_with_batch_size` variant.
pub(crate) const DEFAULT_BATCH_SIZE: usize = SYS_DEFAULT_BATCH_SIZE;

// Linux: taken from `libc::UIO_MAXIOV`.
#[cfg(target_os = "linux")]
const SYS_BATCH_SIZE_CAP: usize = libc::UIO_MAXIOV as usize;

// FreeBSD: fixed cap of 1024 messages per call.
#[cfg(target_os = "freebsd")]
const SYS_BATCH_SIZE_CAP: usize = 1024 as usize;

// Other platforms have no `sendmmsg`/`recvmmsg` path in this file, so the cap is 1.
#[cfg(not(any(target_os = "linux", target_os = "freebsd")))]
pub const SYS_BATCH_SIZE_CAP: usize = 1;

#[cfg(any(target_os = "linux", target_os = "freebsd"))]
pub const SYS_DEFAULT_BATCH_SIZE: usize = 128;

#[cfg(not(any(target_os = "linux", target_os = "freebsd")))]
pub const SYS_DEFAULT_BATCH_SIZE: usize = 1;

// Width of the `IP_TOS` control-message payload written by `prepare_msg`:
// a single byte on FreeBSD, a C `int` elsewhere.
#[cfg(target_os = "freebsd")]
type IpTosTy = libc::c_uchar;
#[cfg(not(target_os = "freebsd"))]
type IpTosTy = libc::c_int;
51
/// Tokio-backed UDP socket that adds batched (`sendmmsg`/`recvmmsg`) I/O and
/// control-message support (ECN, packet info, GSO/GRO) on top of
/// `tokio::net::UdpSocket`.
#[derive(Debug)]
pub struct UdpSocket {
    // Wrapped tokio socket; every constructor leaves it in non-blocking mode.
    io: tokio::net::UdpSocket,
    // Shared rate limiter used by the batched send paths when logging send errors.
    last_send_error: LastSendError,
}
61
62impl AsRawFd for UdpSocket {
63 fn as_raw_fd(&self) -> std::os::unix::prelude::RawFd {
64 self.io.as_raw_fd()
65 }
66}
67
68impl AsFd for UdpSocket {
69 fn as_fd(&self) -> std::os::unix::prelude::BorrowedFd<'_> {
70 self.io.as_fd()
71 }
72}
73
74#[derive(Clone, Debug)]
75pub(crate) struct LastSendError(Arc<AtomicU64>);
76
77impl Default for LastSendError {
78 fn default() -> Self {
79 let now = Self::now();
80 Self(Arc::new(AtomicU64::new(
81 now.checked_sub(2 * IO_ERROR_LOG_INTERVAL).unwrap_or(now),
82 )))
83 }
84}
85
86impl LastSendError {
87 fn now() -> u64 {
88 SystemTime::now()
89 .duration_since(SystemTime::UNIX_EPOCH)
90 .unwrap()
91 .as_secs()
92 }
93
94 pub(crate) fn should_log(&self) -> bool {
101 let now = Self::now();
102 self.0
103 .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |cur| {
104 (now.saturating_sub(cur) > IO_ERROR_LOG_INTERVAL).then_some(now)
105 })
106 .is_ok()
107 }
108}
109
110impl UdpSocket {
111 pub fn from_std(socket: std::net::UdpSocket) -> io::Result<UdpSocket> {
113 socket.set_nonblocking(true)?;
114
115 init(SockRef::from(&socket))?;
116 Ok(UdpSocket {
117 io: tokio::net::UdpSocket::from_std(socket)?,
118 last_send_error: LastSendError::default(),
119 })
120 }
121
122 pub fn into_std(self) -> io::Result<std::net::UdpSocket> {
123 self.io.into_std()
124 }
125
126 pub async fn bind<A: ToSocketAddrs>(addr: A) -> io::Result<UdpSocket> {
128 let io = tokio::net::UdpSocket::bind(addr).await?;
129 init(SockRef::from(&io))?;
130 Ok(UdpSocket {
131 io,
132 last_send_error: LastSendError::default(),
133 })
134 }
135
136 pub fn set_broadcast(&self, broadcast: bool) -> io::Result<()> {
138 self.io.set_broadcast(broadcast)
139 }
140
141 #[cfg(target_os = "linux")]
144 pub fn set_gro(&self, enable: bool) -> io::Result<()> {
145 const OPTION_OFF: libc::c_int = 0;
147
148 let value = if enable { OPTION_ON } else { OPTION_OFF };
149 set_socket_option(&self.io, libc::SOL_UDP, libc::UDP_GRO, value)
150 }
151
152 pub async fn connect<A: ToSocketAddrs>(&self, addrs: A) -> io::Result<()> {
153 self.io.connect(addrs).await
154 }
155 pub async fn join_multicast_v4(
156 &self,
157 multiaddr: Ipv4Addr,
158 interface: Ipv4Addr,
159 ) -> io::Result<()> {
160 self.io.join_multicast_v4(multiaddr, interface)
161 }
162 pub async fn join_multicast_v6(&self, multiaddr: &Ipv6Addr, interface: u32) -> io::Result<()> {
163 self.io.join_multicast_v6(multiaddr, interface)
164 }
165 pub async fn leave_multicast_v4(
166 &self,
167 multiaddr: Ipv4Addr,
168 interface: Ipv4Addr,
169 ) -> io::Result<()> {
170 self.io.leave_multicast_v4(multiaddr, interface)
171 }
172 pub async fn leave_multicast_v6(&self, multiaddr: &Ipv6Addr, interface: u32) -> io::Result<()> {
173 self.io.leave_multicast_v6(multiaddr, interface)
174 }
175 pub async fn set_multicast_loop_v4(&self, on: bool) -> io::Result<()> {
176 self.io.set_multicast_loop_v4(on)
177 }
178 pub async fn set_multicast_loop_v6(&self, on: bool) -> io::Result<()> {
179 self.io.set_multicast_loop_v6(on)
180 }
181 pub async fn send_to(&self, buf: &[u8], target: SocketAddr) -> io::Result<usize> {
188 self.io.send_to(buf, target).await
189 }
190 pub fn poll_send_to(
197 &self,
198 cx: &mut Context<'_>,
199 buf: &[u8],
200 target: SocketAddr,
201 ) -> Poll<io::Result<usize>> {
202 self.io.poll_send_to(cx, buf, target)
203 }
204 pub async fn send(&self, buf: &[u8]) -> io::Result<usize> {
211 self.io.send(buf).await
212 }
213 pub async fn poll_send(&self, cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
220 self.io.poll_send(cx, buf)
221 }
222 pub async fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
229 self.io.recv_from(buf).await
230 }
231 pub fn poll_recv_from(
238 &self,
239 cx: &mut Context<'_>,
240 buf: &mut ReadBuf<'_>,
241 ) -> Poll<io::Result<SocketAddr>> {
242 self.io.poll_recv_from(cx, buf)
243 }
244 pub async fn recv(&self, buf: &mut [u8]) -> io::Result<usize> {
251 self.io.recv(buf).await
252 }
253 pub fn poll_recv(&self, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<io::Result<()>> {
260 self.io.poll_recv(cx, buf)
261 }
262
263 pub async fn send_mmsg<B: AsPtr<u8>>(
274 &self,
275 state: &UdpState,
276 transmits: &[Transmit<B>],
277 ) -> Result<usize, io::Error> {
278 self.send_mmsg_with_batch_size::<_, DEFAULT_BATCH_SIZE>(state, transmits)
279 .await
280 }
281
282 pub async fn send_mmsg_with_batch_size<B: AsPtr<u8>, const BATCH_SIZE: usize>(
292 &self,
293 state: &UdpState,
294 transmits: &[Transmit<B>],
295 ) -> Result<usize, io::Error> {
296 let n = loop {
297 self.io.writable().await?;
298 let last_send_error = self.last_send_error.clone();
299 let io = &self.io;
300 match io.try_io(Interest::WRITABLE, || {
301 send::<_, BATCH_SIZE>(state, SockRef::from(io), last_send_error, transmits)
302 }) {
303 Ok(res) => break res,
304 Err(_would_block) => continue,
305 }
306 };
307 Ok(n)
309 }
310
311 pub async fn send_msg<B: AsPtr<u8>>(
316 &self,
317 state: &UdpState,
318 transmits: Transmit<B>,
319 ) -> io::Result<usize> {
320 let n = loop {
321 self.io.writable().await?;
322 let io = &self.io;
323 match io.try_io(Interest::WRITABLE, || {
324 send_msg(state, SockRef::from(io), &transmits)
325 }) {
326 Ok(res) => break res,
327 Err(_would_block) => continue,
328 }
329 };
330 Ok(n)
331 }
332
333 pub async fn recv_mmsg(
335 &self,
336 bufs: &mut [IoSliceMut<'_>],
337 meta: &mut [RecvMeta],
338 ) -> io::Result<usize> {
339 self.recv_mmsg_with_batch_size::<DEFAULT_BATCH_SIZE>(bufs, meta)
340 .await
341 }
342
343 pub async fn recv_mmsg_with_batch_size<const BATCH_SIZE: usize>(
344 &self,
345 bufs: &mut [IoSliceMut<'_>],
346 meta: &mut [RecvMeta],
347 ) -> io::Result<usize> {
348 debug_assert!(!bufs.is_empty());
349 loop {
350 self.io.readable().await?;
351 let io = &self.io;
352 match io.try_io(Interest::READABLE, || {
353 recv::<BATCH_SIZE>(SockRef::from(io), bufs, meta)
354 }) {
355 Ok(res) => return Ok(res),
356 Err(_would_block) => continue,
357 }
358 }
359 }
360
361 pub async fn recv_msg(&self, buf: &mut [u8]) -> io::Result<RecvMeta> {
366 let mut iov = IoSliceMut::new(buf);
367 debug_assert!(!iov.is_empty());
368 loop {
369 self.io.readable().await?;
370 let io = &self.io;
371 match io.try_io(Interest::READABLE, || recv_msg(SockRef::from(io), &mut iov)) {
372 Ok(res) => return Ok(res),
373 Err(_would_block) => continue,
374 }
375 }
376 }
377
378 pub fn poll_send_mmsg<B: AsPtr<u8>>(
380 &mut self,
381 state: &UdpState,
382 cx: &mut Context,
383 transmits: &[Transmit<B>],
384 ) -> Poll<io::Result<usize>> {
385 self.poll_send_mmsg_with_batch_size::<_, DEFAULT_BATCH_SIZE>(state, cx, transmits)
386 }
387
388 pub fn poll_send_mmsg_with_batch_size<B: AsPtr<u8>, const BATCH_SIZE: usize>(
390 &mut self,
391 state: &UdpState,
392 cx: &mut Context,
393 transmits: &[Transmit<B>],
394 ) -> Poll<io::Result<usize>> {
395 loop {
396 ready!(self.io.poll_send_ready(cx))?;
397 let io = &self.io;
398 if let Ok(res) = io.try_io(Interest::WRITABLE, || {
399 send::<_, BATCH_SIZE>(
400 state,
401 SockRef::from(io),
402 self.last_send_error.clone(),
403 transmits,
404 )
405 }) {
406 return Poll::Ready(Ok(res));
407 }
408 }
409 }
410
411 pub fn poll_send_msg<B: AsPtr<u8>>(
413 &self,
414 state: &UdpState,
415 cx: &mut Context,
416 transmits: Transmit<B>,
417 ) -> Poll<io::Result<usize>> {
418 loop {
419 ready!(self.io.poll_send_ready(cx))?;
420 let io = &self.io;
421 if let Ok(res) = io.try_io(Interest::WRITABLE, || {
422 send_msg(state, SockRef::from(io), &transmits)
423 }) {
424 return Poll::Ready(Ok(res));
425 }
426 }
427 }
428
429 pub fn poll_recv_msg(
431 &self,
432 cx: &mut Context,
433 buf: &mut IoSliceMut<'_>,
434 ) -> Poll<io::Result<RecvMeta>> {
435 loop {
436 ready!(self.io.poll_recv_ready(cx))?;
437 let io = &self.io;
438 if let Ok(res) = io.try_io(Interest::READABLE, || recv_msg(SockRef::from(io), buf)) {
439 return Poll::Ready(Ok(res));
440 }
441 }
442 }
443
444 pub fn poll_recv_mmsg(
446 &self,
447 cx: &mut Context,
448 bufs: &mut [IoSliceMut<'_>],
449 meta: &mut [RecvMeta],
450 ) -> Poll<io::Result<usize>> {
451 self.poll_recv_mmsg_with_batch_size::<DEFAULT_BATCH_SIZE>(cx, bufs, meta)
452 }
453
454 pub fn poll_recv_mmsg_with_batch_size<const BATCH_SIZE: usize>(
456 &self,
457 cx: &mut Context,
458 bufs: &mut [IoSliceMut<'_>],
459 meta: &mut [RecvMeta],
460 ) -> Poll<io::Result<usize>> {
461 debug_assert!(!bufs.is_empty());
462 loop {
463 ready!(self.io.poll_recv_ready(cx))?;
464 let io = &self.io;
465 if let Ok(res) = io.try_io(Interest::READABLE, || {
466 recv::<BATCH_SIZE>(SockRef::from(io), bufs, meta)
467 }) {
468 return Poll::Ready(Ok(res));
469 }
470 }
471 }
472
473 pub fn local_addr(&self) -> io::Result<SocketAddr> {
475 self.io.local_addr()
476 }
477}
478
479pub mod sync {
480
481 use std::os::unix::prelude::IntoRawFd;
482
483 use super::*;
484
485 #[derive(Debug)]
486 pub struct UdpSocket {
487 io: std::net::UdpSocket,
488 last_send_error: LastSendError,
489 }
490
491 impl AsRawFd for UdpSocket {
492 fn as_raw_fd(&self) -> std::os::unix::prelude::RawFd {
493 self.io.as_raw_fd()
494 }
495 }
496 impl IntoRawFd for UdpSocket {
497 fn into_raw_fd(self) -> std::os::unix::prelude::RawFd {
498 self.io.into_raw_fd()
499 }
500 }
501
502 impl UdpSocket {
503 pub fn from_std(socket: std::net::UdpSocket) -> io::Result<Self> {
505 init(SockRef::from(&socket))?;
506 socket.set_nonblocking(false)?;
507 Ok(Self {
508 io: socket,
509 last_send_error: LastSendError::default(),
510 })
511 }
512 pub fn bind<A: std::net::ToSocketAddrs>(addr: A) -> io::Result<Self> {
514 let io = std::net::UdpSocket::bind(addr)?;
515 init(SockRef::from(&io))?;
516 io.set_nonblocking(false)?;
517 Ok(Self {
518 io,
519 last_send_error: LastSendError::default(),
520 })
521 }
522 pub fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> {
524 self.io.set_nonblocking(nonblocking)
525 }
526 pub fn set_broadcast(&self, broadcast: bool) -> io::Result<()> {
528 self.io.set_broadcast(broadcast)
529 }
530 pub fn connect<A: std::net::ToSocketAddrs>(&self, addrs: A) -> io::Result<()> {
531 self.io.connect(addrs)
532 }
533 pub fn join_multicast_v4(
534 &self,
535 multiaddr: Ipv4Addr,
536 interface: Ipv4Addr,
537 ) -> io::Result<()> {
538 self.io.join_multicast_v4(&multiaddr, &interface)
539 }
540 pub fn join_multicast_v6(&self, multiaddr: &Ipv6Addr, interface: u32) -> io::Result<()> {
541 self.io.join_multicast_v6(multiaddr, interface)
542 }
543 pub fn leave_multicast_v4(
544 &self,
545 multiaddr: Ipv4Addr,
546 interface: Ipv4Addr,
547 ) -> io::Result<()> {
548 self.io.leave_multicast_v4(&multiaddr, &interface)
549 }
550 pub fn leave_multicast_v6(&self, multiaddr: &Ipv6Addr, interface: u32) -> io::Result<()> {
551 self.io.leave_multicast_v6(multiaddr, interface)
552 }
553 pub fn set_multicast_loop_v4(&self, on: bool) -> io::Result<()> {
554 self.io.set_multicast_loop_v4(on)
555 }
556 pub fn set_multicast_loop_v6(&self, on: bool) -> io::Result<()> {
557 self.io.set_multicast_loop_v6(on)
558 }
559 pub fn send_to(&self, buf: &[u8], target: SocketAddr) -> io::Result<usize> {
566 self.io.send_to(buf, target)
567 }
568 pub fn send(&self, buf: &[u8]) -> io::Result<usize> {
575 self.io.send(buf)
576 }
577 pub fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
584 self.io.recv_from(buf)
585 }
586 pub fn recv(&self, buf: &mut [u8]) -> io::Result<usize> {
593 self.io.recv(buf)
594 }
595
596 pub fn send_mmsg<B: AsPtr<u8>>(
607 &mut self,
608 state: &UdpState,
609 transmits: &[Transmit<B>],
610 ) -> Result<usize, io::Error> {
611 self.send_mmsg_with_batch_size::<_, DEFAULT_BATCH_SIZE>(state, transmits)
612 }
613
614 pub fn send_mmsg_with_batch_size<B: AsPtr<u8>, const BATCH_SIZE: usize>(
624 &mut self,
625 state: &UdpState,
626 transmits: &[Transmit<B>],
627 ) -> Result<usize, io::Error> {
628 send::<_, BATCH_SIZE>(
629 state,
630 SockRef::from(&self.io),
631 self.last_send_error.clone(),
632 transmits,
633 )
634 }
635
636 pub fn send_msg<B: AsPtr<u8>>(
641 &self,
642 state: &UdpState,
643 transmits: Transmit<B>,
644 ) -> io::Result<usize> {
645 send_msg(state, SockRef::from(&self.io), &transmits)
646 }
647
648 pub fn recv_mmsg(
650 &self,
651 bufs: &mut [IoSliceMut<'_>],
652 meta: &mut [RecvMeta],
653 ) -> io::Result<usize> {
654 self.recv_mmsg_with_batch_size::<DEFAULT_BATCH_SIZE>(bufs, meta)
655 }
656
657 pub fn recv_mmsg_with_batch_size<const BATCH_SIZE: usize>(
659 &self,
660 bufs: &mut [IoSliceMut<'_>],
661 meta: &mut [RecvMeta],
662 ) -> io::Result<usize> {
663 debug_assert!(!bufs.is_empty());
664 recv::<BATCH_SIZE>(SockRef::from(&self.io), bufs, meta)
665 }
666
667 pub fn recv_msg(&self, buf: &mut [u8]) -> io::Result<RecvMeta> {
672 let mut iov = IoSliceMut::new(buf);
673 debug_assert!(!iov.is_empty());
674
675 recv_msg(SockRef::from(&self.io), &mut iov)
676 }
677 pub fn local_addr(&self) -> io::Result<SocketAddr> {
679 self.io.local_addr()
680 }
681 }
682}
683
684fn set_socket_option<Fd: AsRawFd>(
685 socket: &Fd,
686 level: libc::c_int,
687 name: libc::c_int,
688 value: libc::c_int,
689) -> Result<(), io::Error> {
690 let rc = unsafe {
691 libc::setsockopt(
692 socket.as_raw_fd(),
693 level,
694 name,
695 &value as *const _ as _,
696 mem::size_of_val(&value) as _,
697 )
698 };
699
700 if rc != -1 {
701 Ok(())
702 } else {
703 Err(io::Error::last_os_error())
704 }
705}
706
/// Value passed to boolean socket options to enable them.
const OPTION_ON: libc::c_int = 1;
708
/// Prepares a socket for use by this module.
///
/// - Asserts that `CMSG_LEN` has room for the control messages this module uses
///   and that control buffers are aligned for `cmsghdr`. Panics otherwise.
/// - Switches the socket to non-blocking mode.
/// - Enables the per-platform options whose control messages `decode_recv` reads:
///   TOS / traffic class for ECN, and packet info / interface.
///
/// # Errors
/// Returns the first failing `local_addr`, `only_v6` or `setsockopt` call.
fn init(io: SockRef<'_>) -> io::Result<()> {
    // Extra control-buffer room for an `in6_pktinfo` message on platforms that use one.
    let mut cmsg_platform_space = 0;
    if cfg!(target_os = "linux") || cfg!(target_os = "freebsd") || cfg!(target_os = "macos") {
        cmsg_platform_space +=
            unsafe { libc::CMSG_SPACE(mem::size_of::<libc::in6_pktinfo>() as _) as usize };
    }

    // One `c_int`-sized message (e.g. TOS/TCLASS) plus the platform space must fit.
    assert!(
        CMSG_LEN
            >= unsafe { libc::CMSG_SPACE(mem::size_of::<libc::c_int>() as _) as usize }
                + cmsg_platform_space
    );
    assert!(
        mem::align_of::<libc::cmsghdr>() <= mem::align_of::<cmsg::Aligned<[u8; 0]>>(),
        "control message buffers will be misaligned"
    );

    io.set_nonblocking(true)?;

    let addr = io.local_addr()?;
    let is_ipv4 = addr.family() == libc::AF_INET as libc::sa_family_t;

    // Ask for the IPv4 TOS byte on received datagrams. Also applied to dual-stack
    // IPv6 sockets, except on macOS/iOS.
    if is_ipv4 || ((!cfg!(any(target_os = "macos", target_os = "ios"))) && !io.only_v6()?) {
        set_socket_option(&*io, libc::IPPROTO_IP, libc::IP_RECVTOS, OPTION_ON)?;
    }
    #[cfg(target_os = "linux")]
    {
        // IPv4 path-MTU discovery mode `IP_PMTUDISC_PROBE`, set on every socket
        // (including IPv6 ones).
        set_socket_option(
            &*io,
            libc::IPPROTO_IP,
            libc::IP_MTU_DISCOVER,
            libc::IP_PMTUDISC_PROBE,
        )?;

        if is_ipv4 {
            set_socket_option(&*io, libc::IPPROTO_IP, libc::IP_PKTINFO, OPTION_ON)?;
        } else {
            set_socket_option(
                &*io,
                libc::IPPROTO_IPV6,
                libc::IPV6_MTU_DISCOVER,
                libc::IP_PMTUDISC_PROBE,
            )?;
        }
    }
    #[cfg(target_os = "macos")]
    {
        if is_ipv4 {
            set_socket_option(&*io, libc::IPPROTO_IP, libc::IP_PKTINFO, OPTION_ON)?;
        }
    }
    #[cfg(target_os = "freebsd")]
    {
        // FreeBSD has no IP_PKTINFO path here. It reports the destination address
        // and the receiving interface (`IP_RECVIF`, decoded in `decode_recv`) instead.
        if is_ipv4 {
            set_socket_option(&*io, libc::IPPROTO_IP, libc::IP_RECVDSTADDR, OPTION_ON)?;
            set_socket_option(&*io, libc::IPPROTO_IP, libc::IP_RECVIF, OPTION_ON)?;
        }
    }

    // IPv6 sockets: packet info (destination + interface) and traffic class.
    if !is_ipv4 {
        set_socket_option(&*io, libc::IPPROTO_IPV6, libc::IPV6_RECVPKTINFO, OPTION_ON)?;
        set_socket_option(&*io, libc::IPPROTO_IPV6, libc::IPV6_RECVTCLASS, OPTION_ON)?;
    }
    Ok(())
}
780
781#[cfg(not(any(target_os = "macos", target_os = "ios")))]
782fn send_msg<B: AsPtr<u8>>(
783 #[allow(unused_variables)] state: &UdpState,
785 io: SockRef<'_>,
786 transmit: &Transmit<B>,
787) -> io::Result<usize> {
788 let mut msg_hdr: libc::msghdr = unsafe { mem::zeroed() };
789 let mut iovec: libc::iovec = unsafe { mem::zeroed() };
790 let mut cmsg = cmsg::Aligned([0u8; CMSG_LEN]);
791
792 let addr = socket2::SockAddr::from(transmit.dst);
793 let dst_addr = &addr;
794 prepare_msg(transmit, dst_addr, &mut msg_hdr, &mut iovec, &mut cmsg);
795
796 loop {
797 let n = unsafe { libc::sendmsg(io.as_raw_fd(), &msg_hdr, 0) };
798 if n == -1 {
799 let e = io::Error::last_os_error();
800 match e.kind() {
801 io::ErrorKind::Interrupted => {
802 continue;
804 }
805 io::ErrorKind::WouldBlock => return Err(e),
806 _ => {
807 #[cfg(target_os = "linux")]
811 if e.raw_os_error() == Some(libc::EIO) {
812 if state.max_gso_segments() > 1 {
815 tracing::error!("got EIO, halting segmentation offload");
816 state
817 .max_gso_segments
818 .store(1, std::sync::atomic::Ordering::Relaxed);
819 }
820 }
821
822 return Ok(n as usize);
835 }
836 }
837 }
838 return Ok(n as usize);
839 }
840}
841
842#[cfg(not(any(target_os = "macos", target_os = "ios")))]
843fn send<B: AsPtr<u8>, const BATCH_SIZE: usize>(
844 #[allow(unused_variables)] state: &UdpState,
846 io: SockRef<'_>,
847 last_send_error: LastSendError,
848 transmits: &[Transmit<B>],
849) -> io::Result<usize> {
850 use std::ptr;
851
852 let mut msgs: [libc::mmsghdr; BATCH_SIZE] = unsafe { mem::zeroed() };
853 let mut iovecs: [libc::iovec; BATCH_SIZE] = unsafe { mem::zeroed() };
854 let mut cmsgs = [cmsg::Aligned([0u8; CMSG_LEN]); BATCH_SIZE];
855 let mut addrs: [MaybeUninit<socket2::SockAddr>; BATCH_SIZE] =
862 unsafe { MaybeUninit::uninit().assume_init() };
863 for (i, transmit) in transmits.iter().enumerate().take(BATCH_SIZE) {
864 let dst_addr = unsafe {
865 ptr::write(addrs[i].as_mut_ptr(), socket2::SockAddr::from(transmit.dst));
866 &*addrs[i].as_ptr()
867 };
868 prepare_msg(
869 transmit,
870 dst_addr,
871 &mut msgs[i].msg_hdr,
872 &mut iovecs[i],
873 &mut cmsgs[i],
874 );
875 }
876 let num_transmits = transmits.len().min(BATCH_SIZE);
877
878 loop {
879 #[cfg(target_os = "linux")]
880 let n =
881 unsafe { libc::sendmmsg(io.as_raw_fd(), msgs.as_mut_ptr(), num_transmits as u32, 0) };
882 #[cfg(target_os = "freebsd")]
883 let n =
884 unsafe { libc::sendmmsg(io.as_raw_fd(), msgs.as_mut_ptr(), num_transmits as usize, 0) };
885 if n == -1 {
886 let e = io::Error::last_os_error();
887 match e.kind() {
888 io::ErrorKind::Interrupted => {
889 continue;
891 }
892 io::ErrorKind::WouldBlock => return Err(e),
893 _ => {
894 #[cfg(target_os = "linux")]
898 if e.raw_os_error() == Some(libc::EIO) {
899 if state.max_gso_segments() > 1 {
902 tracing::error!("got EIO, halting segmentation offload");
903 state
904 .max_gso_segments
905 .store(1, std::sync::atomic::Ordering::Relaxed);
906 }
907 }
908
909 log_sendmsg_error(last_send_error, e, &transmits[0]);
916
917 return Ok(num_transmits.min(1));
922 }
923 }
924 }
925 return Ok(n as usize);
926 }
927}
928
/// Sends one `transmit`, with its control messages, through a single `sendmsg`
/// call (macOS/iOS).
///
/// Returns the number of bytes the kernel accepted.
/// - `EINTR` is retried.
/// - `WouldBlock` is returned as an error so the caller can wait for writability.
/// - Any other error is not surfaced. The datagram is treated as dropped and its
///   full length is reported.
#[cfg(any(target_os = "macos", target_os = "ios"))]
fn send_msg<B: AsPtr<u8>>(
    _state: &UdpState,
    io: SockRef<'_>,
    transmit: &Transmit<B>,
) -> io::Result<usize> {
    let mut hdr: libc::msghdr = unsafe { mem::zeroed() };
    let mut iov: libc::iovec = unsafe { mem::zeroed() };
    let mut ctrl = cmsg::Aligned([0u8; CMSG_LEN]);

    // `hdr` keeps raw pointers into `addr`, `iov` and `ctrl`; all outlive the loop.
    let addr = socket2::SockAddr::from(transmit.dst);
    prepare_msg(transmit, &addr, &mut hdr, &mut iov, &mut ctrl);

    loop {
        let n = unsafe { libc::sendmsg(io.as_raw_fd(), &hdr, 0) };
        if n != -1 {
            return Ok(n as usize);
        }
        let e = io::Error::last_os_error();
        match e.kind() {
            io::ErrorKind::Interrupted => continue,
            io::ErrorKind::WouldBlock => return Err(e),
            // `n` is -1 here; casting it would report `usize::MAX` bytes sent.
            _ => return Ok(transmit.contents.len()),
        }
    }
}
967
/// Sends `transmits` one `sendmsg` call at a time (macOS/iOS have no `sendmmsg`
/// path here, so `BATCH_SIZE` is unused).
///
/// Returns how many transmits were consumed.
/// - `EINTR` retries the same transmit.
/// - `WouldBlock` stops early: the count so far is returned, or the error if
///   nothing was sent yet.
/// - Any other error is logged (rate-limited) and the transmit is skipped.
#[cfg(any(target_os = "macos", target_os = "ios"))]
fn send<B: AsPtr<u8>, const BATCH_SIZE: usize>(
    _state: &UdpState,
    io: SockRef<'_>,
    last_send_error: LastSendError,
    transmits: &[Transmit<B>],
) -> io::Result<usize> {
    let mut hdr: libc::msghdr = unsafe { mem::zeroed() };
    let mut iov: libc::iovec = unsafe { mem::zeroed() };
    let mut ctrl = cmsg::Aligned([0u8; CMSG_LEN]);
    let fd = io.as_raw_fd();
    let mut done = 0;
    while let Some(transmit) = transmits.get(done) {
        let dst = socket2::SockAddr::from(transmit.dst);
        prepare_msg(transmit, &dst, &mut hdr, &mut iov, &mut ctrl);
        if unsafe { libc::sendmsg(fd, &hdr, 0) } != -1 {
            done += 1;
            continue;
        }
        let err = io::Error::last_os_error();
        match err.kind() {
            // Retry the same transmit.
            io::ErrorKind::Interrupted => {}
            io::ErrorKind::WouldBlock => {
                return if done == 0 { Err(err) } else { Ok(done) };
            }
            _ => {
                log_sendmsg_error(last_send_error.clone(), err, transmit);
                done += 1;
            }
        }
    }
    Ok(done)
}
1008
1009#[cfg(not(any(target_os = "macos", target_os = "ios")))]
1010fn recv<const BATCH_SIZE: usize>(
1011 io: SockRef<'_>,
1012 bufs: &mut [IoSliceMut<'_>],
1013 meta: &mut [RecvMeta],
1014) -> io::Result<usize> {
1015 use std::ptr;
1016
1017 let mut names = [MaybeUninit::<libc::sockaddr_storage>::uninit(); BATCH_SIZE];
1018 let mut ctrls = [cmsg::Aligned(MaybeUninit::<[u8; CMSG_LEN]>::uninit()); BATCH_SIZE];
1019 let mut hdrs = unsafe { mem::zeroed::<[libc::mmsghdr; BATCH_SIZE]>() };
1020 let max_msg_count = bufs.len().min(BATCH_SIZE);
1021 for i in 0..max_msg_count {
1022 prepare_recv(
1023 &mut bufs[i],
1024 &mut names[i],
1025 &mut ctrls[i],
1026 &mut hdrs[i].msg_hdr,
1027 );
1028 }
1029 let msg_count = loop {
1030 #[cfg(target_os = "linux")]
1031 let n = unsafe {
1032 libc::recvmmsg(
1033 io.as_raw_fd(),
1034 hdrs.as_mut_ptr(),
1035 bufs.len().min(BATCH_SIZE) as libc::c_uint,
1036 0,
1037 ptr::null_mut(),
1038 )
1039 };
1040 #[cfg(target_os = "freebsd")]
1041 let n = unsafe {
1042 libc::recvmmsg(
1043 io.as_raw_fd(),
1044 hdrs.as_mut_ptr(),
1045 bufs.len().min(BATCH_SIZE) as usize,
1046 0,
1047 ptr::null_mut(),
1048 )
1049 };
1050 if n == -1 {
1051 let e = io::Error::last_os_error();
1052 if e.kind() == io::ErrorKind::Interrupted {
1053 continue;
1054 }
1055 return Err(e);
1056 }
1057 break n;
1058 };
1059 for i in 0..(msg_count as usize) {
1060 meta[i] = decode_recv(&names[i], &hdrs[i].msg_hdr, hdrs[i].msg_len as usize);
1061 }
1062 Ok(msg_count as usize)
1063}
1064
1065#[cfg(not(any(target_os = "macos", target_os = "ios")))]
1066fn recv_msg(io: SockRef<'_>, bufs: &mut IoSliceMut<'_>) -> io::Result<RecvMeta> {
1067 let mut name = MaybeUninit::<libc::sockaddr_storage>::uninit();
1068 let mut ctrl = cmsg::Aligned(MaybeUninit::<[u8; CMSG_LEN]>::uninit());
1069 let mut hdr = unsafe { mem::zeroed::<libc::msghdr>() };
1070
1071 prepare_recv(bufs, &mut name, &mut ctrl, &mut hdr);
1072
1073 let n = loop {
1074 let n = unsafe { libc::recvmsg(io.as_raw_fd(), &mut hdr, 0) };
1075 if n == -1 {
1076 let e = io::Error::last_os_error();
1077 if e.kind() == io::ErrorKind::Interrupted {
1078 continue;
1079 }
1080 return Err(e);
1081 }
1082 if hdr.msg_flags & libc::MSG_TRUNC != 0 {
1083 continue;
1084 }
1085 break n;
1086 };
1087 Ok(decode_recv(&name, &hdr, n as usize))
1088}
1089
/// Receives a single datagram into `bufs[0]` and describes it in `meta[0]`
/// (macOS/iOS have no `recvmmsg` path here). Always returns `Ok(1)` on success.
///
/// Datagrams truncated by the kernel (`MSG_TRUNC`) are discarded and the next
/// one is read.
///
/// # Errors
/// `EINTR` is retried; any other failure (including `WouldBlock`) is returned.
#[cfg(any(target_os = "macos", target_os = "ios"))]
fn recv<const BATCH_SIZE: usize>(
    io: SockRef<'_>,
    bufs: &mut [IoSliceMut<'_>],
    meta: &mut [RecvMeta],
) -> io::Result<usize> {
    let mut name = MaybeUninit::<libc::sockaddr_storage>::uninit();
    let mut ctrl = cmsg::Aligned(MaybeUninit::<[u8; CMSG_LEN]>::uninit());
    let mut hdr = unsafe { mem::zeroed::<libc::msghdr>() };
    let n = loop {
        // Re-arm the header before every attempt. `recvmsg` overwrites
        // `msg_namelen`, `msg_controllen` and `msg_flags`, so a retry after a
        // truncated datagram must not reuse the shrunken lengths.
        prepare_recv(&mut bufs[0], &mut name, &mut ctrl, &mut hdr);
        let n = unsafe { libc::recvmsg(io.as_raw_fd(), &mut hdr, 0) };
        if n == -1 {
            let e = io::Error::last_os_error();
            if e.kind() == io::ErrorKind::Interrupted {
                continue;
            }
            return Err(e);
        }
        if hdr.msg_flags & libc::MSG_TRUNC != 0 {
            continue;
        }
        break n;
    };
    meta[0] = decode_recv(&name, &hdr, n as usize);
    Ok(1)
}
1117
/// Receives one datagram into `bufs` with `recvmsg` and decodes its metadata
/// (macOS/iOS).
///
/// Datagrams truncated by the kernel (`MSG_TRUNC`) are discarded and the next
/// one is read.
///
/// # Errors
/// `EINTR` is retried; any other failure (including `WouldBlock`) is returned.
#[cfg(any(target_os = "macos", target_os = "ios"))]
fn recv_msg(io: SockRef<'_>, bufs: &mut IoSliceMut<'_>) -> io::Result<RecvMeta> {
    let mut name = MaybeUninit::<libc::sockaddr_storage>::uninit();
    let mut ctrl = cmsg::Aligned(MaybeUninit::<[u8; CMSG_LEN]>::uninit());
    let mut hdr = unsafe { mem::zeroed::<libc::msghdr>() };
    let n = loop {
        // Re-arm the header before every attempt. `recvmsg` overwrites
        // `msg_namelen`, `msg_controllen` and `msg_flags`, so a retry after a
        // truncated datagram must not reuse the shrunken lengths.
        prepare_recv(bufs, &mut name, &mut ctrl, &mut hdr);
        let n = unsafe { libc::recvmsg(io.as_raw_fd(), &mut hdr, 0) };
        if n == -1 {
            let e = io::Error::last_os_error();
            if e.kind() == io::ErrorKind::Interrupted {
                continue;
            }
            return Err(e);
        }
        if hdr.msg_flags & libc::MSG_TRUNC != 0 {
            continue;
        }
        break n;
    };
    Ok(decode_recv(&name, &hdr, n as usize))
}
1140
1141pub fn udp_state() -> UdpState {
1143 UdpState {
1144 max_gso_segments: AtomicUsize::new(gso::max_gso_segments()),
1145 gro_segments: gro::gro_segments(),
1146 }
1147}
1148
/// Size in bytes of each per-message control (ancillary data) buffer used for
/// sending and receiving; `init` asserts that it is large enough.
const CMSG_LEN: usize = 88;
1150
/// Fills `hdr`, plus the `iov` and `ctrl` buffers it points into, so that
/// `sendmsg`/`sendmmsg` sends `transmit` to `dst_addr`.
///
/// Control messages written:
/// - the ECN codepoint (`IP_TOS` for IPv4 destinations, `IPV6_TCLASS` otherwise);
/// - the GSO segment size, when `transmit.segment_size` is set;
/// - the source address or interface, when `transmit.src` is set.
///
/// `hdr` stores raw pointers into `dst_addr`, `iov`, `ctrl` and
/// `transmit.contents`. All of them must stay alive and unmoved until the
/// syscall that uses `hdr` returns.
fn prepare_msg<B: AsPtr<u8>>(
    transmit: &Transmit<B>,
    dst_addr: &socket2::SockAddr,
    hdr: &mut libc::msghdr,
    iov: &mut libc::iovec,
    ctrl: &mut cmsg::Aligned<[u8; CMSG_LEN]>,
) {
    // Single data segment covering the whole payload.
    iov.iov_base = transmit.contents.as_ptr() as *const _ as *mut _;
    iov.iov_len = transmit.contents.len();

    // Destination address and the one-entry iovec.
    let name = dst_addr.as_ptr() as *mut libc::c_void;
    let namelen = dst_addr.len();
    hdr.msg_name = name as *mut _;
    hdr.msg_namelen = namelen;
    hdr.msg_iov = iov;
    hdr.msg_iovlen = 1;

    // Offer the whole control buffer to the encoder. Presumably `finish()` trims
    // `msg_controllen` to the bytes actually written — TODO confirm in `cmsg::Encoder`.
    hdr.msg_control = ctrl.0.as_mut_ptr() as _;
    hdr.msg_controllen = CMSG_LEN as _;
    let mut encoder = unsafe { cmsg::Encoder::new(hdr) };
    // A transmit without an ECN codepoint is sent with ECN bits 0.
    let ecn = transmit.ecn.map_or(0, |x| x as libc::c_int);
    if transmit.dst.is_ipv4() {
        encoder.push(libc::IPPROTO_IP, libc::IP_TOS, ecn as IpTosTy);
    } else {
        encoder.push(libc::IPPROTO_IPV6, libc::IPV6_TCLASS, ecn);
    }

    // Only supported on Linux; the non-Linux `gso::set_segment_size` panics.
    if let Some(segment_size) = transmit.segment_size {
        gso::set_segment_size(&mut encoder, segment_size as u16);
    }

    // Optional source selection.
    if let Some(ip) = &transmit.src {
        match ip {
            Source::Ip(IpAddr::V4(v4)) => {
                // Source IPv4 address via `in_pktinfo.ipi_spec_dst`.
                #[cfg(any(target_os = "linux", target_os = "macos"))]
                {
                    let pktinfo = libc::in_pktinfo {
                        ipi_ifindex: 0,
                        ipi_spec_dst: libc::in_addr {
                            s_addr: u32::from_ne_bytes(v4.octets()),
                        },
                        ipi_addr: libc::in_addr { s_addr: 0 },
                    };
                    encoder.push(libc::IPPROTO_IP, libc::IP_PKTINFO, pktinfo);
                }
                // NOTE(review): FreeBSD documents `IP_SENDSRCADDR` for this and
                // (presumably) defines it equal to `IP_RECVDSTADDR` — confirm.
                #[cfg(target_os = "freebsd")]
                {
                    let addr = libc::in_addr {
                        s_addr: u32::from_ne_bytes(v4.octets()),
                    };
                    encoder.push(libc::IPPROTO_IP, libc::IP_RECVDSTADDR, addr);
                }
            }
            Source::Ip(IpAddr::V6(v6)) => {
                let pktinfo = libc::in6_pktinfo {
                    ipi6_ifindex: 0,
                    ipi6_addr: libc::in6_addr {
                        s6_addr: v6.octets(),
                    },
                };
                encoder.push(libc::IPPROTO_IPV6, libc::IPV6_PKTINFO, pktinfo);
            }
            // Outgoing interface index only, via IPv4 packet info.
            #[cfg(not(target_os = "freebsd"))]
            Source::Interface(i) => {
                let pktinfo = libc::in_pktinfo {
                    ipi_ifindex: *i as _,
                    ipi_spec_dst: libc::in_addr { s_addr: 0 },
                    ipi_addr: libc::in_addr { s_addr: 0 },
                };
                encoder.push(libc::IPPROTO_IP, libc::IP_PKTINFO, pktinfo);
            }
            // Interface-only source selection is ignored on FreeBSD.
            #[cfg(target_os = "freebsd")]
            Source::Interface(_) => (),
            // IPv6 source address plus interface index.
            Source::InterfaceV6(i, ip) => {
                let pktinfo = libc::in6_pktinfo {
                    ipi6_ifindex: *i,
                    ipi6_addr: libc::in6_addr {
                        s6_addr: ip.octets(),
                    },
                };
                encoder.push(libc::IPPROTO_IPV6, libc::IPV6_PKTINFO, pktinfo);
            }
        }
    }

    encoder.finish();
}
1243
1244fn prepare_recv(
1245 buf: &mut IoSliceMut,
1246 name: &mut MaybeUninit<libc::sockaddr_storage>,
1247 ctrl: &mut cmsg::Aligned<MaybeUninit<[u8; CMSG_LEN]>>,
1248 hdr: &mut libc::msghdr,
1249) {
1250 hdr.msg_name = name.as_mut_ptr() as _;
1251 hdr.msg_namelen = mem::size_of::<libc::sockaddr_storage>() as _;
1252 hdr.msg_iov = buf as *mut IoSliceMut as *mut libc::iovec;
1253 hdr.msg_iovlen = 1;
1254 hdr.msg_control = ctrl.0.as_mut_ptr() as _;
1255 hdr.msg_controllen = CMSG_LEN as _;
1256 hdr.msg_flags = 0;
1257}
1258
/// Builds the `RecvMeta` for one received datagram.
///
/// Parameters:
/// - `name` must have been filled in by the kernel; it is `assume_init`-ed here.
/// - `hdr` is the header passed to `recvmsg`/`recvmmsg`.
/// - `len` is the number of payload bytes received.
///
/// Control messages are scanned for:
/// - the ECN bits (IPv4 TOS / IPv6 traffic class);
/// - the destination address and interface index (packet info, or `IP_RECVIF`
///   on FreeBSD);
/// - on Linux, the GRO segment size (`stride`).
///
/// # Panics
/// If the source address family is neither `AF_INET` nor `AF_INET6`.
fn decode_recv(
    name: &MaybeUninit<libc::sockaddr_storage>,
    hdr: &libc::msghdr,
    len: usize,
) -> RecvMeta {
    let name = unsafe { name.assume_init() };
    let mut ecn_bits = 0;
    let mut dst_ip = None;
    // Only written on platforms that handle `IP_PKTINFO` below.
    #[allow(unused_mut)]
    let mut dst_local_ip = None;
    let mut ifindex = 0;
    // Segment size; stays `len` unless a GRO control message says otherwise (Linux only).
    #[allow(unused_mut)]
    let mut stride = len;

    let cmsg_iter = unsafe { cmsg::Iter::new(hdr) };
    for cmsg in cmsg_iter {
        match (cmsg.cmsg_level, cmsg.cmsg_type) {
            // IPv4 TOS byte, which carries the ECN bits.
            (libc::IPPROTO_IP, libc::IP_TOS) | (libc::IPPROTO_IP, libc::IP_RECVTOS) => unsafe {
                ecn_bits = cmsg::decode::<u8>(cmsg);
            },
            // IPv6 traffic class. On macOS the payload may be a single byte (detected
            // by its length); otherwise it is read as a C int and truncated to u8.
            (libc::IPPROTO_IPV6, libc::IPV6_TCLASS) => unsafe {
                if cfg!(target_os = "macos")
                    && cmsg.cmsg_len as usize == libc::CMSG_LEN(mem::size_of::<u8>() as _) as usize
                {
                    ecn_bits = cmsg::decode::<u8>(cmsg);
                } else {
                    ecn_bits = cmsg::decode::<libc::c_int>(cmsg) as u8;
                }
            },
            // IPv4 packet info: `ipi_addr` -> `dst_ip`, `ipi_spec_dst` -> `dst_local_ip`.
            #[cfg(not(target_os = "freebsd"))]
            (libc::IPPROTO_IP, libc::IP_PKTINFO) => {
                let pktinfo = unsafe { cmsg::decode::<libc::in_pktinfo>(cmsg) };
                dst_ip = Some(IpAddr::V4(Ipv4Addr::from(
                    pktinfo.ipi_addr.s_addr.to_ne_bytes(),
                )));
                dst_local_ip = Some(IpAddr::V4(Ipv4Addr::from(
                    pktinfo.ipi_spec_dst.s_addr.to_ne_bytes(),
                )));
                ifindex = pktinfo.ipi_ifindex as _;
            }
            // IPv6 packet info: destination address and interface index.
            (libc::IPPROTO_IPV6, libc::IPV6_PKTINFO) => {
                let pktinfo = unsafe { cmsg::decode::<libc::in6_pktinfo>(cmsg) };
                dst_ip = Some(IpAddr::V6(Ipv6Addr::from(pktinfo.ipi6_addr.s6_addr)));
                ifindex = pktinfo.ipi6_ifindex;
            }
            // FreeBSD: receiving interface, delivered as a link-level sockaddr.
            #[cfg(target_os = "freebsd")]
            (libc::IPPROTO_IP, libc::IP_RECVIF) => {
                let info = unsafe { cmsg::decode::<libc::sockaddr_dl>(cmsg) };
                ifindex = info.sdl_index as _;
            }
            // Linux GRO: size of each coalesced segment within the payload.
            #[cfg(target_os = "linux")]
            (libc::SOL_UDP, libc::UDP_GRO) => unsafe {
                stride = cmsg::decode::<libc::c_int>(cmsg) as usize;
            },
            _ => {}
        }
    }

    // Convert the kernel-provided source address; ports are in network byte order.
    let addr = match libc::c_int::from(name.ss_family) {
        libc::AF_INET => {
            let addr = unsafe { &*(&name as *const _ as *const libc::sockaddr_in) };
            let ip = Ipv4Addr::from(addr.sin_addr.s_addr.to_ne_bytes());
            let port = u16::from_be(addr.sin_port);
            SocketAddr::V4(SocketAddrV4::new(ip, port))
        }
        libc::AF_INET6 => {
            let addr = unsafe { &*(&name as *const _ as *const libc::sockaddr_in6) };
            let ip = Ipv6Addr::from(addr.sin6_addr.s6_addr);
            let port = u16::from_be(addr.sin6_port);
            SocketAddr::V6(SocketAddrV6::new(
                ip,
                port,
                addr.sin6_flowinfo,
                addr.sin6_scope_id,
            ))
        }
        _ => unreachable!(),
    };

    RecvMeta {
        len,
        stride,
        addr,
        ecn: EcnCodepoint::from_bits(ecn_bits),
        dst_ip,
        dst_local_ip,
        ifindex,
    }
}
1356
1357#[cfg(target_os = "linux")]
1358mod gso {
1359 use super::*;
1360
1361 pub fn max_gso_segments() -> usize {
1364 const GSO_SIZE: libc::c_int = 1500;
1365
1366 let socket = match std::net::UdpSocket::bind("[::]:0")
1367 .or_else(|_| std::net::UdpSocket::bind("127.0.0.1:0"))
1368 {
1369 Ok(socket) => socket,
1370 Err(_) => return 1,
1371 };
1372
1373 match set_socket_option(&socket, libc::SOL_UDP, libc::UDP_SEGMENT, GSO_SIZE) {
1376 Ok(()) => 64,
1377 Err(_) => 1,
1378 }
1379 }
1380
1381 pub fn set_segment_size(encoder: &mut cmsg::Encoder, segment_size: u16) {
1382 encoder.push(libc::SOL_UDP, libc::UDP_SEGMENT, segment_size);
1383 }
1384}
1385
#[cfg(not(target_os = "linux"))]
mod gso {
    use super::*;

    /// No GSO outside Linux: every send carries a single segment.
    pub fn max_gso_segments() -> usize {
        const SINGLE_SEGMENT: usize = 1;
        SINGLE_SEGMENT
    }

    /// Unsupported outside Linux; always panics. Callers must leave
    /// `Transmit::segment_size` unset on this platform.
    pub fn set_segment_size(_encoder: &mut cmsg::Encoder, _segment_size: u16) {
        panic!("Setting a segment size is not supported on current platform");
    }
}
1398
1399#[cfg(target_os = "linux")]
1400mod gro {
1401 use super::*;
1402
1403 pub fn gro_segments() -> usize {
1404 let socket = match std::net::UdpSocket::bind("[::]:0") {
1405 Ok(socket) => socket,
1406 Err(_) => return 1,
1407 };
1408
1409 let on: libc::c_int = 1;
1410 let rc = unsafe {
1411 libc::setsockopt(
1412 socket.as_raw_fd(),
1413 libc::SOL_UDP,
1414 libc::UDP_GRO,
1415 &on as *const _ as _,
1416 mem::size_of_val(&on) as _,
1417 )
1418 };
1419
1420 if rc != -1 {
1421 64
1429 } else {
1430 1
1431 }
1432 }
1433}
1434
1435#[cfg(not(target_os = "linux"))]
1436mod gro {
1437 pub fn gro_segments() -> usize {
1438 1
1439 }
1440}