rust_p2p_core/tunnel/udp.rs

1use std::io;
2use std::io::IoSlice;
3use std::net::SocketAddr;
4use std::sync::Arc;
5
6#[cfg(any(target_os = "linux", target_os = "android"))]
7use tokio::io::Interest;
8
9#[cfg(any(target_os = "linux", target_os = "android"))]
10pub async fn read_with<R>(udp: &UdpSocket, op: impl FnMut() -> io::Result<R>) -> io::Result<R> {
11    udp.async_io(Interest::READABLE, op).await
12}
13#[cfg(any(target_os = "linux", target_os = "android"))]
14pub async fn write_with<R>(udp: &UdpSocket, op: impl FnMut() -> io::Result<R>) -> io::Result<R> {
15    udp.async_io(Interest::WRITABLE, op).await
16}
17
18use bytes::BytesMut;
19use dashmap::DashMap;
20use parking_lot::{Mutex, RwLock};
21use tachyonix::{Receiver, Sender, TrySendError};
22use tokio::net::UdpSocket;
23
24use crate::route::{Index, RouteKey};
25use crate::socket::{bind_udp, LocalInterface};
26use crate::tunnel::config::UdpTunnelConfig;
27use crate::tunnel::recycle::RecycleBuf;
28use crate::tunnel::{DEFAULT_ADDRESS_V4, DEFAULT_ADDRESS_V6};
29
/// Maximum number of datagrams gathered into one `sendmmsg(2)` batch.
#[cfg(any(target_os = "linux", target_os = "android"))]
const MAX_MESSAGES: usize = 16;
32#[cfg(any(target_os = "linux", target_os = "android"))]
33use libc::{c_uint, iovec, mmsghdr, sockaddr_storage, socklen_t};
34#[cfg(any(target_os = "linux", target_os = "android"))]
35use std::os::fd::AsRawFd;
36
/// Socket usage model of the UDP tunnel layer.
///
/// `Low` (the default) uses only the main sockets; switching to `High` additionally
/// binds the configured number of sub sockets (see `UdpSocketManager::switch_model`).
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum Model {
    /// Main sockets plus sub sockets.
    High,
    /// Main sockets only.
    #[default]
    Low,
}

impl Model {
    /// Returns `true` for [`Model::Low`].
    pub fn is_low(&self) -> bool {
        matches!(self, Model::Low)
    }
    /// Returns `true` for [`Model::High`].
    pub fn is_high(&self) -> bool {
        matches!(self, Model::High)
    }
}
52
/// Identifies one UDP socket owned by the socket manager: the variant selects the
/// socket group and the payload is the position inside that group.
#[derive(Copy, Clone, Ord, PartialOrd, Eq, PartialEq, Hash, Debug)]
pub enum UDPIndex {
    /// Entry of the main IPv4 socket list.
    MainV4(usize),
    /// Entry of the main IPv6 socket list.
    MainV6(usize),
    /// Entry of the sub IPv4 socket list (only populated in `Model::High`).
    SubV4(usize),
}

impl UDPIndex {
    /// Position of the socket within its group, regardless of the group.
    pub(crate) fn index(&self) -> usize {
        match *self {
            UDPIndex::MainV4(position) | UDPIndex::MainV6(position) | UDPIndex::SubV4(position) => {
                position
            }
        }
    }
}
69
70pub trait ToRouteKeyForUdp<T> {
71    fn route_key(socket_manager: &UdpSocketManager, dest: Self) -> io::Result<RouteKey>;
72}
73
74impl ToRouteKeyForUdp<()> for RouteKey {
75    fn route_key(_: &UdpSocketManager, dest: Self) -> io::Result<RouteKey> {
76        Ok(dest)
77    }
78}
79
80impl ToRouteKeyForUdp<()> for &RouteKey {
81    fn route_key(_: &UdpSocketManager, dest: Self) -> io::Result<RouteKey> {
82        Ok(*dest)
83    }
84}
85
86impl ToRouteKeyForUdp<()> for &mut RouteKey {
87    fn route_key(_: &UdpSocketManager, dest: Self) -> io::Result<RouteKey> {
88        Ok(*dest)
89    }
90}
91
92impl<S: Into<SocketAddr>> ToRouteKeyForUdp<()> for S {
93    fn route_key(socket_manager: &UdpSocketManager, dest: Self) -> io::Result<RouteKey> {
94        let addr = dest.into();
95        socket_manager.generate_route_key_from_addr(0, addr)
96    }
97}
98
99impl<S: Into<SocketAddr>> ToRouteKeyForUdp<usize> for (usize, S) {
100    fn route_key(socket_manager: &UdpSocketManager, dest: Self) -> io::Result<RouteKey> {
101        let (index, addr) = dest;
102        socket_manager.generate_route_key_from_addr(index, addr.into())
103    }
104}
105
106/// initialize udp tunnel by config
107pub(crate) fn create_tunnel_dispatcher(config: UdpTunnelConfig) -> io::Result<UdpTunnelDispatcher> {
108    config.check()?;
109    let mut udp_ports = config.udp_ports;
110    udp_ports.resize(config.main_udp_count, 0);
111    let mut main_udp_v4: Vec<Arc<UdpSocket>> = Vec::with_capacity(config.main_udp_count);
112    let mut main_udp_v6: Vec<Arc<UdpSocket>> = Vec::with_capacity(config.main_udp_count);
113    // 因为在mac上v4和v6的对绑定网卡的处理不同,所以这里分开监听,并且分开监听更容易处理发送目标为v4的情况,因为双协议栈下发送v4目标需要转换成v6
114    for port in &udp_ports {
115        loop {
116            let mut addr_v4 = DEFAULT_ADDRESS_V4;
117            addr_v4.set_port(*port);
118            let socket_v4 = bind_udp(addr_v4, config.default_interface.as_ref())?;
119            let udp_v4: std::net::UdpSocket = socket_v4.into();
120            if config.use_v6 {
121                let mut addr_v6 = DEFAULT_ADDRESS_V6;
122                let socket_v6 = if *port == 0 {
123                    let port = udp_v4.local_addr()?.port();
124                    addr_v6.set_port(port);
125                    match bind_udp(addr_v6, config.default_interface.as_ref()) {
126                        Ok(socket_v6) => socket_v6,
127                        Err(_) => continue,
128                    }
129                } else {
130                    addr_v6.set_port(*port);
131                    bind_udp(addr_v6, config.default_interface.as_ref())?
132                };
133                let udp_v6: std::net::UdpSocket = socket_v6.into();
134                main_udp_v6.push(Arc::new(UdpSocket::from_std(udp_v6)?))
135            }
136            main_udp_v4.push(Arc::new(UdpSocket::from_std(udp_v4)?));
137            break;
138        }
139    }
140    let (tunnel_sender, tunnel_receiver) =
141        tachyonix::channel(config.main_udp_count * 2 + config.sub_udp_count * 2);
142    let socket_manager = Arc::new(UdpSocketManager {
143        main_udp_v4,
144        main_udp_v6,
145        sub_udp: RwLock::new(Vec::with_capacity(config.sub_udp_count)),
146        sub_close_notify: Default::default(),
147        tunnel_dispatcher: tunnel_sender,
148        sub_udp_num: config.sub_udp_count,
149        default_interface: config.default_interface,
150        sender_map: Default::default(),
151    });
152    let tunnel_factory = UdpTunnelDispatcher {
153        tunnel_receiver,
154        socket_manager,
155        recycle_buf: config.recycle_buf,
156    };
157    tunnel_factory.init()?;
158    tunnel_factory.socket_manager.switch_model(config.model)?;
159    Ok(tunnel_factory)
160}
161
162pub struct UdpSocketManager {
163    main_udp_v4: Vec<Arc<UdpSocket>>,
164    main_udp_v6: Vec<Arc<UdpSocket>>,
165    sub_udp: RwLock<Vec<Arc<UdpSocket>>>,
166    sub_close_notify: Mutex<Option<async_broadcast::Sender<()>>>,
167    tunnel_dispatcher: Sender<InactiveUdpTunnel>,
168    sub_udp_num: usize,
169    default_interface: Option<LocalInterface>,
170    sender_map: DashMap<Index, Sender<(BytesMut, SocketAddr)>>,
171}
172
173impl UdpSocketManager {
174    pub(crate) fn try_sub_batch_send_to(&self, buf: &[u8], addr: SocketAddr) {
175        for (i, udp) in self.sub_udp.read().iter().enumerate() {
176            if let Err(e) = udp.try_send_to(buf, addr) {
177                log::info!("try_sub_send_to_addr_v4: {e:?},{i},{addr}")
178            }
179        }
180    }
181    pub(crate) fn try_main_v4_batch_send_to(&self, buf: &[u8], addr: &[SocketAddr]) {
182        let len = self.main_udp_v4_count();
183        self.try_main_batch_send_to_impl(buf, addr, len);
184    }
185    pub(crate) fn try_main_v6_batch_send_to(&self, buf: &[u8], addr: &[SocketAddr]) {
186        let len = self.main_udp_v6_count();
187        self.try_main_batch_send_to_impl(buf, addr, len);
188    }
189
190    pub(crate) fn try_main_batch_send_to_impl(&self, buf: &[u8], addr: &[SocketAddr], len: usize) {
191        for (i, addr) in addr.iter().enumerate() {
192            if let Err(e) = self.try_send_to(buf, (i % len, *addr)) {
193                log::info!("try_main_send_to_addr: {e:?},{},{addr}", i % len);
194            }
195        }
196    }
197    pub(crate) fn generate_route_key_from_addr(
198        &self,
199        index: usize,
200        addr: SocketAddr,
201    ) -> io::Result<RouteKey> {
202        let route_key = if addr.is_ipv4() {
203            let len = self.main_udp_v4.len();
204            if index >= len {
205                return Err(io::Error::other("index out of bounds"));
206            }
207            RouteKey::new(Index::Udp(UDPIndex::MainV4(index)), addr)
208        } else {
209            let len = self.main_udp_v6.len();
210            if len == 0 {
211                return Err(io::Error::other("Not support IPV6"));
212            }
213            if index >= len {
214                return Err(io::Error::other("index out of bounds"));
215            }
216            RouteKey::new(Index::Udp(UDPIndex::MainV6(index)), addr)
217        };
218        Ok(route_key)
219    }
220    pub(crate) fn switch_low(&self) {
221        let mut guard = self.sub_udp.write();
222        if guard.is_empty() {
223            return;
224        }
225        guard.clear();
226        if let Some(sub_close_notify) = self.sub_close_notify.lock().take() {
227            let _ = sub_close_notify.close();
228        }
229    }
230    pub(crate) fn switch_high(&self) -> io::Result<()> {
231        let mut guard = self.sub_udp.write();
232        if !guard.is_empty() {
233            return Ok(());
234        }
235        let mut sub_close_notify_guard = self.sub_close_notify.lock();
236        if let Some(sender) = sub_close_notify_guard.take() {
237            let _ = sender.close();
238        }
239        let (sub_close_notify_sender, sub_close_notify_receiver) = async_broadcast::broadcast(2);
240        let mut sub_udp_list = Vec::with_capacity(self.sub_udp_num);
241        for _ in 0..self.sub_udp_num {
242            let udp = bind_udp(DEFAULT_ADDRESS_V4, self.default_interface.as_ref())?;
243            let udp: std::net::UdpSocket = udp.into();
244            sub_udp_list.push(Arc::new(UdpSocket::from_std(udp)?));
245        }
246        for (index, udp) in sub_udp_list.iter().enumerate() {
247            let udp = udp.clone();
248            let udp_tunnel = InactiveUdpTunnel::new(
249                false,
250                Index::Udp(UDPIndex::SubV4(index)),
251                udp,
252                Some(sub_close_notify_receiver.clone()),
253            );
254            if self.tunnel_dispatcher.try_send(udp_tunnel).is_err() {
255                Err(io::Error::other("tunnel channel error"))?
256            }
257        }
258        sub_close_notify_guard.replace(sub_close_notify_sender);
259        *guard = sub_udp_list;
260        Ok(())
261    }
262
263    #[inline]
264    fn get_udp(&self, udp_index: UDPIndex) -> io::Result<Arc<UdpSocket>> {
265        Ok(match udp_index {
266            UDPIndex::MainV4(index) => self
267                .main_udp_v4
268                .get(index)
269                .ok_or(io::Error::other("index out of bounds"))?
270                .clone(),
271            UDPIndex::MainV6(index) => self
272                .main_udp_v6
273                .get(index)
274                .ok_or(io::Error::other("index out of bounds"))?
275                .clone(),
276            UDPIndex::SubV4(index) => {
277                let guard = self.sub_udp.read();
278                let len = guard.len();
279                if len <= index {
280                    return Err(io::Error::other("index out of bounds"));
281                } else {
282                    guard[index].clone()
283                }
284            }
285        })
286    }
287
288    #[inline]
289    fn get_udp_from_route(&self, route_key: &RouteKey) -> io::Result<Arc<UdpSocket>> {
290        Ok(match route_key.index() {
291            Index::Udp(index) => self.get_udp(index)?,
292            _ => return Err(io::Error::from(io::ErrorKind::InvalidInput)),
293        })
294    }
295}
296
297impl UdpSocketManager {
298    pub fn model(&self) -> Model {
299        if self.sub_udp.read().is_empty() {
300            Model::Low
301        } else {
302            Model::High
303        }
304    }
305
306    #[inline]
307    pub fn main_udp_v4_count(&self) -> usize {
308        self.main_udp_v4.len()
309    }
310    #[inline]
311    pub fn main_udp_v6_count(&self) -> usize {
312        self.main_udp_v6.len()
313    }
314
315    pub fn switch_model(&self, model: Model) -> io::Result<()> {
316        match model {
317            Model::High => self.switch_high(),
318            Model::Low => {
319                self.switch_low();
320                Ok(())
321            }
322        }
323    }
324    /// Acquire the local ports `UDP` sockets bind on
325    pub fn local_ports(&self) -> io::Result<Vec<u16>> {
326        let mut ports = Vec::with_capacity(self.main_udp_v4_count());
327        for udp in &self.main_udp_v4 {
328            ports.push(udp.local_addr()?.port());
329        }
330        Ok(ports)
331    }
332    /// Writing `buf` to the target denoted by `route_key`
333    pub async fn send_to<T, D: ToRouteKeyForUdp<T>>(&self, buf: &[u8], dest: D) -> io::Result<()> {
334        let route_key = ToRouteKeyForUdp::route_key(self, dest)?;
335        let len = self
336            .get_udp_from_route(&route_key)?
337            .send_to(buf, route_key.addr())
338            .await?;
339        if len == 0 {
340            return Err(std::io::Error::from(io::ErrorKind::WriteZero));
341        }
342        Ok(())
343    }
344
345    /// Try to write `buf` to the target denoted by `route_key`
346    pub fn try_send_to<T, D: ToRouteKeyForUdp<T>>(&self, buf: &[u8], dest: D) -> io::Result<()> {
347        let route_key = ToRouteKeyForUdp::route_key(self, dest)?;
348        let len = self
349            .get_udp_from_route(&route_key)?
350            .try_send_to(buf, route_key.addr())?;
351        if len == 0 {
352            return Err(std::io::Error::from(io::ErrorKind::WriteZero));
353        }
354        Ok(())
355    }
356
357    pub async fn batch_send_to<T, D: ToRouteKeyForUdp<T>>(
358        &self,
359        bufs: &[IoSlice<'_>],
360        dest: D,
361    ) -> io::Result<()> {
362        let route_key = ToRouteKeyForUdp::route_key(self, dest)?;
363        let udp = self.get_udp_from_route(&route_key)?;
364        for buf in bufs {
365            let len = udp.send_to(buf, route_key.addr()).await?;
366            if len == 0 {
367                return Err(std::io::Error::from(io::ErrorKind::WriteZero));
368            }
369        }
370
371        Ok(())
372    }
373    fn get_sender(&self, route_key: &RouteKey) -> io::Result<Sender<(BytesMut, SocketAddr)>> {
374        if let Some(sender) = self.sender_map.get(&route_key.index()) {
375            Ok(sender.value().clone())
376        } else {
377            Err(io::Error::new(io::ErrorKind::NotFound, "route not found"))
378        }
379    }
380    pub async fn send_bytes_to<T, D: ToRouteKeyForUdp<T>>(
381        &self,
382        buf: BytesMut,
383        dest: D,
384    ) -> io::Result<()> {
385        let route_key = ToRouteKeyForUdp::route_key(self, dest)?;
386        let sender = self.get_sender(&route_key)?;
387        if let Err(_e) = sender.send((buf, route_key.addr())).await {
388            Err(io::Error::from(io::ErrorKind::WriteZero))
389        } else {
390            Ok(())
391        }
392    }
393    pub fn try_send_bytes_to<T, D: ToRouteKeyForUdp<T>>(
394        &self,
395        buf: BytesMut,
396        dest: D,
397    ) -> io::Result<()> {
398        let route_key = ToRouteKeyForUdp::route_key(self, dest)?;
399        let sender = self.get_sender(&route_key)?;
400        if let Err(e) = sender.try_send((buf, route_key.addr())) {
401            match e {
402                TrySendError::Full(_) => Err(io::Error::from(io::ErrorKind::WouldBlock)),
403                TrySendError::Closed(_) => Err(io::Error::from(io::ErrorKind::WriteZero)),
404            }
405        } else {
406            Ok(())
407        }
408    }
409
410    /// Send bytes to the target denoted by SocketAddr with every main underlying socket
411    pub async fn detect_pub_addrs<A: Into<SocketAddr>>(
412        &self,
413        buf: &[u8],
414        addr: A,
415    ) -> io::Result<()> {
416        let addr: SocketAddr = addr.into();
417        for index in 0..self.main_udp_v4_count() {
418            self.send_to(buf, (index, addr)).await?
419        }
420        Ok(())
421    }
422}
423
424pub struct UdpTunnelDispatcher {
425    tunnel_receiver: Receiver<InactiveUdpTunnel>,
426    pub(crate) socket_manager: Arc<UdpSocketManager>,
427    recycle_buf: Option<RecycleBuf>,
428}
429
430impl UdpTunnelDispatcher {
431    pub(crate) fn init(&self) -> io::Result<()> {
432        for (index, udp) in self.socket_manager.main_udp_v4.iter().enumerate() {
433            let udp = udp.clone();
434            let tunnel =
435                InactiveUdpTunnel::new(true, Index::Udp(UDPIndex::MainV4(index)), udp, None);
436            if self
437                .socket_manager
438                .tunnel_dispatcher
439                .try_send(tunnel)
440                .is_err()
441            {
442                Err(io::Error::other("tunnel channel error"))?
443            }
444        }
445        for (index, udp) in self.socket_manager.main_udp_v6.iter().enumerate() {
446            let udp = udp.clone();
447            let tunnel =
448                InactiveUdpTunnel::new(true, Index::Udp(UDPIndex::MainV6(index)), udp, None);
449            if self
450                .socket_manager
451                .tunnel_dispatcher
452                .try_send(tunnel)
453                .is_err()
454            {
455                Err(io::Error::other("tunnel channel error"))?
456            }
457        }
458        Ok(())
459    }
460}
461
462impl UdpTunnelDispatcher {
463    /// Construct a `UDP` tunnel with the specified configuration
464    pub fn new(config: UdpTunnelConfig) -> io::Result<UdpTunnelDispatcher> {
465        create_tunnel_dispatcher(config)
466    }
467    /// Dispatch `UDP` tunnel from this kind dispatcher
468    pub async fn dispatch(&mut self) -> io::Result<UdpTunnel> {
469        let mut udp_tunnel = self
470            .tunnel_receiver
471            .recv()
472            .await
473            .map_err(|_| io::Error::other("Udp tunnel close"))?;
474        let option = self
475            .socket_manager
476            .sender_map
477            .get(&udp_tunnel.index)
478            .map(|v| v.value().clone());
479        let sender = if let Some(v) = option {
480            v
481        } else {
482            let (s, mut r) = tachyonix::channel(128);
483            let index = udp_tunnel.index;
484            let sender = s.clone();
485            self.socket_manager.sender_map.insert(index, s);
486
487            let socket_manager = self.socket_manager.clone();
488            let udp = udp_tunnel.udp.clone();
489            let recycle_buf = self.recycle_buf.clone();
490            tokio::spawn(async move {
491                #[cfg(any(target_os = "linux", target_os = "android"))]
492                let mut vec_buf = Vec::with_capacity(16);
493
494                while let Ok((buf, addr)) = r.recv().await {
495                    #[cfg(any(target_os = "linux", target_os = "android"))]
496                    {
497                        vec_buf.push((buf, addr));
498                        while let Ok(tup) = r.try_recv() {
499                            vec_buf.push(tup);
500                            if vec_buf.len() == MAX_MESSAGES {
501                                break;
502                            }
503                        }
504                        let mut bufs = &mut vec_buf[..];
505                        let fd = udp.as_raw_fd();
506                        loop {
507                            if bufs.len() == 1 {
508                                let (buf, addr) = unsafe { bufs.get_unchecked(0) };
509                                if let Err(e) = udp.send_to(buf, *addr).await {
510                                    log::warn!("send_to {addr:?},{e:?}")
511                                }
512                                break;
513                            } else {
514                                let rs = write_with(&udp, || sendmmsg(fd, bufs)).await;
515                                match rs {
516                                    Ok(size) => {
517                                        if size == 0 {
518                                            break;
519                                        }
520                                        if size < bufs.len() {
521                                            bufs = &mut bufs[size..];
522                                            continue;
523                                        }
524                                        break;
525                                    }
526                                    Err(e) => {
527                                        log::warn!("sendmmsg {e:?}");
528                                    }
529                                }
530                            }
531                        }
532                        if let Some(recycle_buf) = recycle_buf.as_ref() {
533                            while let Some((buf, _)) = vec_buf.pop() {
534                                recycle_buf.push(buf);
535                            }
536                        } else {
537                            vec_buf.clear();
538                        }
539                    }
540                    #[cfg(not(any(target_os = "linux", target_os = "android")))]
541                    {
542                        let rs = udp.send_to(&buf, addr).await;
543                        if let Some(recycle_buf) = recycle_buf.as_ref() {
544                            recycle_buf.push(buf);
545                        }
546                        if let Err(e) = rs {
547                            log::debug!("{addr:?},{e:?}")
548                        }
549                    }
550                }
551                socket_manager.sender_map.remove(&index);
552            });
553            sender
554        };
555        if udp_tunnel.sender.is_none() {
556            udp_tunnel.sender.replace(OwnedUdpTunnelSender { sender });
557        }
558        if udp_tunnel.reusable {
559            UdpTunnel::with_main(udp_tunnel, self.manager().tunnel_dispatcher.clone())
560        } else {
561            UdpTunnel::with_sub(udp_tunnel)
562        }
563    }
564    pub fn manager(&self) -> &Arc<UdpSocketManager> {
565        &self.socket_manager
566    }
567}
568
569#[cfg(any(target_os = "linux", target_os = "android"))]
570fn sendmmsg(fd: std::os::fd::RawFd, bufs: &mut [(BytesMut, SocketAddr)]) -> io::Result<usize> {
571    assert!(bufs.len() <= MAX_MESSAGES);
572    let mut iov: [iovec; MAX_MESSAGES] = unsafe { std::mem::zeroed() };
573    let mut msgs: [mmsghdr; MAX_MESSAGES] = unsafe { std::mem::zeroed() };
574    let mut addrs: [sockaddr_storage; MAX_MESSAGES] = unsafe { std::mem::zeroed() };
575    for (i, (buf, addr)) in bufs.iter_mut().enumerate() {
576        addrs[i] = socket_addr_to_sockaddr(addr);
577        iov[i].iov_base = buf.as_mut_ptr() as *mut libc::c_void;
578        iov[i].iov_len = buf.len();
579        msgs[i].msg_hdr.msg_iov = &mut iov[i];
580        msgs[i].msg_hdr.msg_iovlen = 1;
581
582        msgs[i].msg_hdr.msg_name = &mut addrs[i] as *mut _ as *mut libc::c_void;
583        msgs[i].msg_hdr.msg_namelen = std::mem::size_of::<sockaddr_storage>() as socklen_t;
584    }
585
586    unsafe {
587        let res = libc::sendmmsg(
588            fd,
589            msgs.as_mut_ptr(),
590            bufs.len() as _,
591            libc::MSG_DONTWAIT as _,
592        );
593        if res == -1 {
594            return Err(io::Error::last_os_error());
595        }
596        Ok(res as usize)
597    }
598}
599
600#[cfg(any(target_os = "linux", target_os = "android"))]
601fn socket_addr_to_sockaddr(addr: &SocketAddr) -> sockaddr_storage {
602    let mut storage: sockaddr_storage = unsafe { std::mem::zeroed() };
603
604    match addr {
605        SocketAddr::V4(v4_addr) => {
606            let sin = libc::sockaddr_in {
607                sin_family: libc::AF_INET as _,
608                sin_port: v4_addr.port().to_be(),
609                sin_addr: libc::in_addr {
610                    s_addr: u32::from_ne_bytes(v4_addr.ip().octets()), // IP 地址
611                },
612                sin_zero: [0; 8],
613            };
614
615            unsafe {
616                let sin_ptr = &sin as *const libc::sockaddr_in as *const u8;
617                let storage_ptr = &mut storage as *mut sockaddr_storage as *mut u8;
618                std::ptr::copy_nonoverlapping(
619                    sin_ptr,
620                    storage_ptr,
621                    std::mem::size_of::<libc::sockaddr>(),
622                );
623            }
624        }
625        SocketAddr::V6(v6_addr) => {
626            let sin6 = libc::sockaddr_in6 {
627                sin6_family: libc::AF_INET6 as _,
628                sin6_port: v6_addr.port().to_be(),
629                sin6_flowinfo: v6_addr.flowinfo(),
630                sin6_addr: libc::in6_addr {
631                    s6_addr: v6_addr.ip().octets(),
632                },
633                sin6_scope_id: v6_addr.scope_id(),
634            };
635
636            unsafe {
637                let sin6_ptr = &sin6 as *const libc::sockaddr_in6 as *const u8;
638                let storage_ptr = &mut storage as *mut sockaddr_storage as *mut u8;
639                std::ptr::copy_nonoverlapping(
640                    sin6_ptr,
641                    storage_ptr,
642                    std::mem::size_of::<libc::sockaddr>(),
643                );
644            }
645        }
646    }
647    storage
648}
649
650pub struct UdpTunnel {
651    index: Index,
652    local_addr: SocketAddr,
653    udp: Option<Arc<UdpSocket>>,
654    close_notify: Option<async_broadcast::Receiver<()>>,
655    re_dispatcher: Option<Sender<InactiveUdpTunnel>>,
656    sender: Option<OwnedUdpTunnelSender>,
657}
658struct OwnedUdpTunnelSender {
659    sender: Sender<(BytesMut, SocketAddr)>,
660}
661#[derive(Clone)]
662pub struct WeakUdpTunnelSender {
663    sender: Sender<(BytesMut, SocketAddr)>,
664}
665struct InactiveUdpTunnel {
666    reusable: bool,
667    index: Index,
668    udp: Arc<UdpSocket>,
669    close_notify: Option<async_broadcast::Receiver<()>>,
670    sender: Option<OwnedUdpTunnelSender>,
671}
672impl InactiveUdpTunnel {
673    fn new(
674        reusable: bool,
675        index: Index,
676        udp: Arc<UdpSocket>,
677        close_notify: Option<async_broadcast::Receiver<()>>,
678    ) -> Self {
679        Self {
680            reusable,
681            index,
682            udp,
683            close_notify,
684            sender: None,
685        }
686    }
687    fn redistribute(index: Index, udp: Arc<UdpSocket>, sender: OwnedUdpTunnelSender) -> Self {
688        Self {
689            reusable: true,
690            index,
691            udp,
692            close_notify: None,
693            sender: Some(sender),
694        }
695    }
696}
697impl OwnedUdpTunnelSender {
698    async fn send_to<A: Into<SocketAddr>>(&self, buf: BytesMut, dest: A) -> io::Result<()> {
699        if buf.is_empty() {
700            return Ok(());
701        }
702        self.sender
703            .send((buf, dest.into()))
704            .await
705            .map_err(|_| io::Error::from(io::ErrorKind::WriteZero))
706    }
707    fn is_closed(&self) -> bool {
708        self.sender.is_closed()
709    }
710}
711impl WeakUdpTunnelSender {
712    pub async fn send_to<A: Into<SocketAddr>>(&self, buf: BytesMut, dest: A) -> io::Result<()> {
713        if buf.is_empty() {
714            return Ok(());
715        }
716        self.sender
717            .send((buf, dest.into()))
718            .await
719            .map_err(|_| io::Error::from(io::ErrorKind::WriteZero))
720    }
721    pub fn try_send_to<A: Into<SocketAddr>>(&self, buf: BytesMut, dest: A) -> io::Result<()> {
722        if buf.is_empty() {
723            return Ok(());
724        }
725        self.sender
726            .try_send((buf, dest.into()))
727            .map_err(|e| match e {
728                TrySendError::Full(_) => io::Error::from(io::ErrorKind::WouldBlock),
729                TrySendError::Closed(_) => io::Error::from(io::ErrorKind::WriteZero),
730            })
731    }
732}
733impl Drop for OwnedUdpTunnelSender {
734    fn drop(&mut self) {
735        self.sender.close();
736    }
737}
738impl Drop for UdpTunnel {
739    fn drop(&mut self) {
740        let Some(sender) = self.sender.take() else {
741            return;
742        };
743        if sender.is_closed() {
744            return;
745        }
746        let Some(udp) = self.udp.take() else {
747            return;
748        };
749        let Some(re_dispatcher) = self.re_dispatcher.take() else {
750            return;
751        };
752        let rs = re_dispatcher.try_send(InactiveUdpTunnel::redistribute(self.index, udp, sender));
753        if let Err(TrySendError::Full(_)) = rs {
754            log::warn!("Udp Tunnel TrySendError full");
755        }
756    }
757}
758
759impl UdpTunnel {
760    fn with_sub(inactive_udp_tunnel: InactiveUdpTunnel) -> io::Result<Self> {
761        let local_addr = inactive_udp_tunnel.udp.local_addr()?;
762        Ok(Self {
763            index: inactive_udp_tunnel.index,
764            local_addr,
765            udp: Some(inactive_udp_tunnel.udp),
766            close_notify: inactive_udp_tunnel.close_notify,
767            re_dispatcher: None,
768            sender: inactive_udp_tunnel.sender,
769        })
770    }
771    fn with_main(
772        inactive_udp_tunnel: InactiveUdpTunnel,
773        re_sender: Sender<InactiveUdpTunnel>,
774    ) -> io::Result<Self> {
775        let local_addr = inactive_udp_tunnel.udp.local_addr()?;
776        Ok(Self {
777            local_addr,
778            index: inactive_udp_tunnel.index,
779            udp: Some(inactive_udp_tunnel.udp),
780            close_notify: None,
781            re_dispatcher: Some(re_sender),
782            sender: inactive_udp_tunnel.sender,
783        })
784    }
785    pub fn done(&mut self) {
786        _ = self.udp.take();
787        _ = self.close_notify.take();
788        _ = self.re_dispatcher.take();
789        _ = self.re_dispatcher.take();
790        _ = self.sender.take();
791    }
792    pub fn local_addr(&self) -> SocketAddr {
793        self.local_addr
794    }
795    pub fn sender(&self) -> io::Result<WeakUdpTunnelSender> {
796        if let Some(v) = &self.sender {
797            Ok(WeakUdpTunnelSender {
798                sender: v.sender.clone(),
799            })
800        } else {
801            Err(io::Error::other("closed"))
802        }
803    }
804}
805
806impl UdpTunnel {
807    /// Writing `buf` to the target denoted by SocketAddr via this tunnel
808    pub async fn send_to<A: Into<SocketAddr>>(&self, buf: &[u8], addr: A) -> io::Result<()> {
809        if let Some(udp) = &self.udp {
810            udp.send_to(buf, addr.into()).await?;
811            Ok(())
812        } else {
813            Err(io::Error::other("closed"))
814        }
815    }
816    /// Try to write `buf` to the target denoted by SocketAddr via this tunnel
817    pub fn try_send_to<A: Into<SocketAddr>>(&self, buf: &[u8], addr: A) -> io::Result<()> {
818        if let Some(udp) = &self.udp {
819            udp.try_send_to(buf, addr.into())?;
820            Ok(())
821        } else {
822            Err(io::Error::other("closed"))
823        }
824    }
825    pub async fn send_bytes_to<A: Into<SocketAddr>>(
826        &self,
827        buf: BytesMut,
828        addr: A,
829    ) -> io::Result<()> {
830        if let Some(sender) = &self.sender {
831            sender.send_to(buf, addr).await
832        } else {
833            Err(io::Error::other("closed"))
834        }
835    }
836
837    /// Receving buf from this tunnel
838    /// `usize` in the `Ok` branch indicates how many bytes are received
839    /// `RouteKey` in the `Ok` branch denotes the source where these bytes are received from
840    pub async fn recv_from(&mut self, buf: &mut [u8]) -> Option<io::Result<(usize, RouteKey)>> {
841        let udp = if let Some(udp) = &self.udp {
842            udp
843        } else {
844            return None;
845        };
846        loop {
847            if let Some(close_notify) = &mut self.close_notify {
848                tokio::select! {
849                    _rs=close_notify.recv()=>{
850                         self.done();
851                         return None
852                    }
853                    result=udp.recv_from(buf)=>{
854                         let (len, addr) = match result {
855                            Ok(rs) => rs,
856                            Err(e) => {
857                                if should_ignore_error(&e) {
858                                    continue;
859                                }
860                                return Some(Err(e))
861                            }
862                         };
863                         return Some(Ok((len, RouteKey::new(self.index, addr))))
864                    }
865                }
866            } else {
867                let (len, addr) = match udp.recv_from(buf).await {
868                    Ok(rs) => rs,
869                    Err(e) => {
870                        if should_ignore_error(&e) {
871                            continue;
872                        }
873                        return Some(Err(e));
874                    }
875                };
876                return Some(Ok((len, RouteKey::new(self.index, addr))));
877            }
878        }
879    }
880    #[cfg(not(any(target_os = "linux", target_os = "android")))]
881    pub async fn batch_recv_from<B: AsMut<[u8]>>(
882        &mut self,
883        bufs: &mut [B],
884        sizes: &mut [usize],
885        addrs: &mut [RouteKey],
886    ) -> Option<io::Result<usize>> {
887        if bufs.is_empty() || bufs.len() != sizes.len() || bufs.len() != addrs.len() {
888            return Some(Err(io::Error::other("bufs error")));
889        }
890        let rs = self.recv_from(bufs[0].as_mut()).await?;
891        match rs {
892            Ok((len, addr)) => {
893                let udp = self.udp.as_ref()?;
894                sizes[0] = len;
895                addrs[0] = addr;
896                let mut num = 1;
897                while num < bufs.len() {
898                    match udp.try_recv_from(bufs[num].as_mut()) {
899                        Ok((len, addr)) => {
900                            sizes[num] = len;
901                            addrs[num] = RouteKey::new(self.index, addr);
902                            num += 1;
903                        }
904                        Err(_) => break,
905                    }
906                }
907                Some(Ok(num))
908            }
909            Err(e) => Some(Err(e)),
910        }
911    }
912    #[cfg(any(target_os = "linux", target_os = "android"))]
913    pub async fn batch_recv_from<B: AsMut<[u8]>>(
914        &mut self,
915        bufs: &mut [B],
916        sizes: &mut [usize],
917        addrs: &mut [RouteKey],
918    ) -> Option<io::Result<usize>> {
919        if bufs.is_empty() || bufs.len() != sizes.len() || bufs.len() != addrs.len() {
920            return Some(Err(io::Error::other("bufs/sizes/addrs error")));
921        }
922        let udp = self.udp.as_ref()?;
923        let fd = udp.as_raw_fd();
924        loop {
925            let rs = if let Some(close_notify) = &mut self.close_notify {
926                tokio::select! {
927                    _rs=close_notify.recv()=>{
928                        self.done();
929                        return None
930                    }
931                    rs=read_with(udp,|| recvmmsg(self.index, fd, bufs, sizes, addrs))=>{
932                        rs
933                    }
934                }
935            } else {
936                read_with(udp, || recvmmsg(self.index, fd, bufs, sizes, addrs)).await
937            };
938            return match rs {
939                Ok(size) => Some(Ok(size)),
940                Err(e) => {
941                    if should_ignore_error(&e) {
942                        continue;
943                    }
944                    Some(Err(e))
945                }
946            };
947        }
948    }
949}
950
951#[cfg(any(target_os = "linux", target_os = "android"))]
952fn recvmmsg<B: AsMut<[u8]>>(
953    index: Index,
954    fd: std::os::fd::RawFd,
955    bufs: &mut [B],
956    sizes: &mut [usize],
957    route_keys: &mut [RouteKey],
958) -> io::Result<usize> {
959    let mut iov: [iovec; MAX_MESSAGES] = unsafe { std::mem::zeroed() };
960    let mut msgs: [mmsghdr; MAX_MESSAGES] = unsafe { std::mem::zeroed() };
961    let mut addrs: [sockaddr_storage; MAX_MESSAGES] = unsafe { std::mem::zeroed() };
962    let max_num = bufs.len().min(MAX_MESSAGES);
963    for i in 0..max_num {
964        iov[i].iov_base = bufs[i].as_mut().as_mut_ptr() as *mut libc::c_void;
965        iov[i].iov_len = bufs[i].as_mut().len();
966        msgs[i].msg_hdr.msg_iov = &mut iov[i];
967        msgs[i].msg_hdr.msg_iovlen = 1;
968        msgs[i].msg_hdr.msg_name = &mut addrs[i] as *const _ as *mut libc::c_void;
969        msgs[i].msg_hdr.msg_namelen = std::mem::size_of::<sockaddr_storage>() as socklen_t;
970    }
971    let res = unsafe {
972        libc::recvmmsg(
973            fd,
974            msgs.as_mut_ptr(),
975            max_num as c_uint,
976            libc::MSG_DONTWAIT as _,
977            std::ptr::null_mut(),
978        )
979    };
980    if res == -1 {
981        return Err(io::Error::last_os_error());
982    }
983    let nmsgs = res as usize;
984    if nmsgs == 0 {
985        return Err(io::Error::from(io::ErrorKind::UnexpectedEof));
986    }
987    for i in 0..nmsgs {
988        let addr = sockaddr_to_socket_addr(&addrs[i], msgs[i].msg_hdr.msg_namelen);
989        sizes[i] = msgs[i].msg_len as usize;
990        route_keys[i] = RouteKey::new(index, addr);
991    }
992    Ok(nmsgs)
993}
994#[cfg(any(target_os = "linux", target_os = "android"))]
995fn sockaddr_to_socket_addr(addr: &sockaddr_storage, _len: socklen_t) -> SocketAddr {
996    match addr.ss_family as libc::c_int {
997        libc::AF_INET => {
998            let addr_in = unsafe { *(addr as *const _ as *const libc::sockaddr_in) };
999            let ip = u32::from_be(addr_in.sin_addr.s_addr);
1000            let port = u16::from_be(addr_in.sin_port);
1001            SocketAddr::V4(std::net::SocketAddrV4::new(
1002                std::net::Ipv4Addr::from(ip),
1003                port,
1004            ))
1005        }
1006        libc::AF_INET6 => {
1007            let addr_in6 = unsafe { *(addr as *const _ as *const libc::sockaddr_in6) };
1008            let ip = std::net::Ipv6Addr::from(addr_in6.sin6_addr.s6_addr);
1009            let port = u16::from_be(addr_in6.sin6_port);
1010            SocketAddr::V6(std::net::SocketAddrV6::new(ip, port, 0, 0))
1011        }
1012        _ => panic!("Unsupported address family"),
1013    }
1014}
1015
/// Reports whether a socket error is benign and the receive should simply be retried.
///
/// On Windows this is true only for an OS error equal to `WSAECONNRESET`. This is
/// presumably the reset a UDP socket reports after an earlier send was rejected by
/// the peer; confirm before relying on it. On every other platform nothing is ignored.
fn should_ignore_error(e: &io::Error) -> bool {
    #[cfg(windows)]
    {
        use windows_sys::Win32::Networking::WinSock::WSAECONNRESET;
        e.raw_os_error() == Some(WSAECONNRESET)
    }
    #[cfg(not(windows))]
    {
        _ = e;
        false
    }
}
1027
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use crate::tunnel::udp::{Model, UdpTunnel};

    /// Upper bound on each wait: for the next dispatched tunnel, and for a spawned
    /// tunnel task to finish.
    const WAIT: Duration = Duration::from_secs(1);

    /// With `Model::Low` and 2 main / 10 sub sockets configured, expects exactly two
    /// tunnels to be dispatched before the dispatcher stays quiet for `WAIT`.
    #[tokio::test]
    pub async fn create_udp_tunnel() {
        let config = crate::tunnel::config::UdpTunnelConfig::default()
            .set_main_udp_count(2)
            .set_sub_udp_count(10)
            .set_model(Model::Low)
            .set_use_v6(false);
        let mut dispatcher = crate::tunnel::udp::create_tunnel_dispatcher(config).unwrap();
        let mut handles = Vec::new();
        while let Ok(next) = tokio::time::timeout(WAIT, dispatcher.dispatch()).await {
            handles.push(tokio::spawn(tunnel_recv(next.unwrap())));
        }
        assert_eq!(handles.len(), 2)
    }

    /// With `Model::High`, expects all 12 tunnels to be dispatched. After
    /// `switch_low()`, exactly 10 of them are expected to report closure within `WAIT`.
    #[tokio::test]
    pub async fn create_sub_udp_tunnel() {
        let config = crate::tunnel::config::UdpTunnelConfig::default()
            .set_main_udp_count(2)
            .set_sub_udp_count(10)
            .set_use_v6(false)
            .set_model(Model::High);
        let mut dispatcher = crate::tunnel::udp::create_tunnel_dispatcher(config).unwrap();
        let mut handles = Vec::new();
        while let Ok(next) = tokio::time::timeout(WAIT, dispatcher.dispatch()).await {
            handles.push(tokio::spawn(tunnel_recv(next.unwrap())));
        }
        let dispatched = handles.len();
        dispatcher.manager().switch_low();

        let mut closed = 0;
        for handle in handles {
            // A timed-out task is still receiving; only finished tasks that saw
            // closure are counted.
            if let Ok(joined) = tokio::time::timeout(WAIT, handle).await {
                if joined.unwrap() {
                    closed += 1;
                }
            }
        }
        assert_eq!(dispatched, 12);
        assert_eq!(closed, 10);
    }

    /// Returns `true` when the tunnel's first `recv_from` yields `None` (the tunnel
    /// was closed) rather than a datagram or an error.
    async fn tunnel_recv(mut tunnel: UdpTunnel) -> bool {
        let mut buf = [0u8; 1400];
        let outcome = tunnel.recv_from(&mut buf).await;
        outcome.is_none()
    }
}