rust_p2p_core/tunnel/
udp.rs

1use std::io;
2use std::io::IoSlice;
3use std::net::SocketAddr;
4use std::sync::Arc;
5
6#[cfg(any(target_os = "linux", target_os = "android"))]
7use tokio::io::Interest;
8
9#[cfg(any(target_os = "linux", target_os = "android"))]
10pub async fn read_with<R>(udp: &UdpSocket, op: impl FnMut() -> io::Result<R>) -> io::Result<R> {
11    udp.async_io(Interest::READABLE, op).await
12}
13#[cfg(any(target_os = "linux", target_os = "android"))]
14pub async fn write_with<R>(udp: &UdpSocket, op: impl FnMut() -> io::Result<R>) -> io::Result<R> {
15    udp.async_io(Interest::WRITABLE, op).await
16}
17
18use bytes::BytesMut;
19use dashmap::DashMap;
20use parking_lot::{Mutex, RwLock};
21use tachyonix::{Receiver, Sender, TrySendError};
22use tokio::net::UdpSocket;
23
24use crate::route::{Index, RouteKey};
25use crate::socket::{bind_udp, LocalInterface};
26use crate::tunnel::config::UdpTunnelConfig;
27use crate::tunnel::recycle::RecycleBuf;
28use crate::tunnel::{DEFAULT_ADDRESS_V4, DEFAULT_ADDRESS_V6};
29
/// Maximum number of queued datagrams gathered into one `sendmmsg(2)` batch.
#[cfg(any(target_os = "linux", target_os = "android"))]
const MAX_MESSAGES: usize = 16;
32#[cfg(any(target_os = "linux", target_os = "android"))]
33use libc::{c_uint, iovec, mmsghdr, sockaddr_storage, socklen_t};
34#[cfg(any(target_os = "linux", target_os = "android"))]
35use std::os::fd::AsRawFd;
36
/// Socket usage model of the UDP socket manager.
///
/// `High` additionally binds a set of sub sockets, while `Low` only uses the
/// main sockets. `Low` is the default.
#[derive(Debug, PartialEq, Eq, Clone, Copy, Default)]
pub enum Model {
    High,
    #[default]
    Low,
}

impl Model {
    /// Returns `true` for [`Model::Low`].
    pub fn is_low(&self) -> bool {
        matches!(self, Model::Low)
    }
    /// Returns `true` for [`Model::High`].
    pub fn is_high(&self) -> bool {
        matches!(self, Model::High)
    }
}
52
/// Position of a UDP socket inside the socket manager.
#[derive(Copy, Clone, Ord, PartialOrd, Eq, PartialEq, Hash, Debug)]
pub enum UDPIndex {
    /// Index into the main IPv4 sockets.
    MainV4(usize),
    /// Index into the main IPv6 sockets.
    MainV6(usize),
    /// Index into the sub IPv4 sockets (only bound in the high model).
    SubV4(usize),
}

impl UDPIndex {
    /// Position within the socket list selected by the variant.
    pub(crate) fn index(&self) -> usize {
        match *self {
            UDPIndex::MainV4(position) | UDPIndex::MainV6(position) | UDPIndex::SubV4(position) => {
                position
            }
        }
    }
}
69
70pub trait ToRouteKeyForUdp<T> {
71    fn route_key(socket_manager: &UdpSocketManager, dest: Self) -> io::Result<RouteKey>;
72}
73
74impl ToRouteKeyForUdp<()> for RouteKey {
75    fn route_key(_: &UdpSocketManager, dest: Self) -> io::Result<RouteKey> {
76        Ok(dest)
77    }
78}
79
80impl ToRouteKeyForUdp<()> for &RouteKey {
81    fn route_key(_: &UdpSocketManager, dest: Self) -> io::Result<RouteKey> {
82        Ok(*dest)
83    }
84}
85
86impl ToRouteKeyForUdp<()> for &mut RouteKey {
87    fn route_key(_: &UdpSocketManager, dest: Self) -> io::Result<RouteKey> {
88        Ok(*dest)
89    }
90}
91
92impl<S: Into<SocketAddr>> ToRouteKeyForUdp<()> for S {
93    fn route_key(socket_manager: &UdpSocketManager, dest: Self) -> io::Result<RouteKey> {
94        let addr = dest.into();
95        socket_manager.generate_route_key_from_addr(0, addr)
96    }
97}
98
99impl<S: Into<SocketAddr>> ToRouteKeyForUdp<usize> for (usize, S) {
100    fn route_key(socket_manager: &UdpSocketManager, dest: Self) -> io::Result<RouteKey> {
101        let (index, addr) = dest;
102        socket_manager.generate_route_key_from_addr(index, addr.into())
103    }
104}
105
106/// initialize udp tunnel by config
107pub(crate) fn create_tunnel_dispatcher(config: UdpTunnelConfig) -> io::Result<UdpTunnelDispatcher> {
108    config.check()?;
109    let mut udp_ports = config.udp_ports;
110    udp_ports.resize(config.main_udp_count, 0);
111    let mut main_udp_v4: Vec<Arc<UdpSocket>> = Vec::with_capacity(config.main_udp_count);
112    let mut main_udp_v6: Vec<Arc<UdpSocket>> = Vec::with_capacity(config.main_udp_count);
113    // 因为在mac上v4和v6的对绑定网卡的处理不同,所以这里分开监听,并且分开监听更容易处理发送目标为v4的情况,因为双协议栈下发送v4目标需要转换成v6
114    for port in &udp_ports {
115        loop {
116            let mut addr_v4 = DEFAULT_ADDRESS_V4;
117            addr_v4.set_port(*port);
118            let socket_v4 = bind_udp(addr_v4, config.default_interface.as_ref())?;
119            let udp_v4: std::net::UdpSocket = socket_v4.into();
120            if config.use_v6 {
121                let mut addr_v6 = DEFAULT_ADDRESS_V6;
122                let socket_v6 = if *port == 0 {
123                    let port = udp_v4.local_addr()?.port();
124                    addr_v6.set_port(port);
125                    match bind_udp(addr_v6, config.default_interface.as_ref()) {
126                        Ok(socket_v6) => socket_v6,
127                        Err(_) => continue,
128                    }
129                } else {
130                    addr_v6.set_port(*port);
131                    bind_udp(addr_v6, config.default_interface.as_ref())?
132                };
133                let udp_v6: std::net::UdpSocket = socket_v6.into();
134                main_udp_v6.push(Arc::new(UdpSocket::from_std(udp_v6)?))
135            }
136            main_udp_v4.push(Arc::new(UdpSocket::from_std(udp_v4)?));
137            break;
138        }
139    }
140    let (tunnel_sender, tunnel_receiver) =
141        tachyonix::channel(config.main_udp_count * 2 + config.sub_udp_count * 2);
142    let socket_manager = Arc::new(UdpSocketManager {
143        main_udp_v4,
144        main_udp_v6,
145        sub_udp: RwLock::new(Vec::with_capacity(config.sub_udp_count)),
146        sub_close_notify: Default::default(),
147        tunnel_dispatcher: tunnel_sender,
148        sub_udp_num: config.sub_udp_count,
149        default_interface: config.default_interface,
150        sender_map: Default::default(),
151    });
152    let tunnel_factory = UdpTunnelDispatcher {
153        tunnel_receiver,
154        socket_manager,
155        recycle_buf: config.recycle_buf,
156    };
157    tunnel_factory.init()?;
158    tunnel_factory.socket_manager.switch_model(config.model)?;
159    Ok(tunnel_factory)
160}
161
162pub struct UdpSocketManager {
163    main_udp_v4: Vec<Arc<UdpSocket>>,
164    main_udp_v6: Vec<Arc<UdpSocket>>,
165    sub_udp: RwLock<Vec<Arc<UdpSocket>>>,
166    sub_close_notify: Mutex<Option<async_broadcast::Sender<()>>>,
167    tunnel_dispatcher: Sender<InactiveUdpTunnel>,
168    sub_udp_num: usize,
169    default_interface: Option<LocalInterface>,
170    sender_map: DashMap<Index, Sender<(BytesMut, SocketAddr)>>,
171}
172
173impl UdpSocketManager {
174    pub(crate) fn try_sub_batch_send_to(&self, buf: &[u8], addr: SocketAddr) {
175        for (i, udp) in self.sub_udp.read().iter().enumerate() {
176            if let Err(e) = udp.try_send_to(buf, addr) {
177                log::info!("try_sub_send_to_addr_v4: {e:?},{i},{addr}")
178            }
179        }
180    }
181    pub(crate) fn try_main_v4_batch_send_to(&self, buf: &[u8], addr: &[SocketAddr]) {
182        let len = self.main_udp_v4_count();
183        self.try_main_batch_send_to_impl(buf, addr, len);
184    }
185    pub(crate) fn try_main_v6_batch_send_to(&self, buf: &[u8], addr: &[SocketAddr]) {
186        let len = self.main_udp_v6_count();
187        self.try_main_batch_send_to_impl(buf, addr, len);
188    }
189
190    pub(crate) fn try_main_batch_send_to_impl(&self, buf: &[u8], addr: &[SocketAddr], len: usize) {
191        for (i, addr) in addr.iter().enumerate() {
192            if let Err(e) = self.try_send_to(buf, (i % len, *addr)) {
193                log::info!("try_main_send_to_addr: {e:?},{},{addr}", i % len);
194            }
195        }
196    }
197    pub(crate) fn generate_route_key_from_addr(
198        &self,
199        index: usize,
200        addr: SocketAddr,
201    ) -> io::Result<RouteKey> {
202        let route_key = if addr.is_ipv4() {
203            let len = self.main_udp_v4.len();
204            if index >= len {
205                return Err(io::Error::new(io::ErrorKind::Other, "index out of bounds"));
206            }
207            RouteKey::new(Index::Udp(UDPIndex::MainV4(index)), addr)
208        } else {
209            let len = self.main_udp_v6.len();
210            if len == 0 {
211                return Err(io::Error::new(io::ErrorKind::Other, "Not support IPV6"));
212            }
213            if index >= len {
214                return Err(io::Error::new(io::ErrorKind::Other, "index out of bounds"));
215            }
216            RouteKey::new(Index::Udp(UDPIndex::MainV6(index)), addr)
217        };
218        Ok(route_key)
219    }
220    pub(crate) fn switch_low(&self) {
221        let mut guard = self.sub_udp.write();
222        if guard.is_empty() {
223            return;
224        }
225        guard.clear();
226        if let Some(sub_close_notify) = self.sub_close_notify.lock().take() {
227            let _ = sub_close_notify.close();
228        }
229    }
230    pub(crate) fn switch_high(&self) -> io::Result<()> {
231        let mut guard = self.sub_udp.write();
232        if !guard.is_empty() {
233            return Ok(());
234        }
235        let mut sub_close_notify_guard = self.sub_close_notify.lock();
236        if let Some(sender) = sub_close_notify_guard.take() {
237            let _ = sender.close();
238        }
239        let (sub_close_notify_sender, sub_close_notify_receiver) = async_broadcast::broadcast(2);
240        let mut sub_udp_list = Vec::with_capacity(self.sub_udp_num);
241        for _ in 0..self.sub_udp_num {
242            let udp = bind_udp(DEFAULT_ADDRESS_V4, self.default_interface.as_ref())?;
243            let udp: std::net::UdpSocket = udp.into();
244            sub_udp_list.push(Arc::new(UdpSocket::from_std(udp)?));
245        }
246        for (index, udp) in sub_udp_list.iter().enumerate() {
247            let udp = udp.clone();
248            let udp_tunnel = InactiveUdpTunnel::new(
249                false,
250                Index::Udp(UDPIndex::SubV4(index)),
251                udp,
252                Some(sub_close_notify_receiver.clone()),
253            );
254            if self.tunnel_dispatcher.try_send(udp_tunnel).is_err() {
255                Err(io::Error::new(io::ErrorKind::Other, "tunnel channel error"))?
256            }
257        }
258        sub_close_notify_guard.replace(sub_close_notify_sender);
259        *guard = sub_udp_list;
260        Ok(())
261    }
262
263    #[inline]
264    fn get_udp(&self, udp_index: UDPIndex) -> io::Result<Arc<UdpSocket>> {
265        Ok(match udp_index {
266            UDPIndex::MainV4(index) => self
267                .main_udp_v4
268                .get(index)
269                .ok_or(io::Error::new(io::ErrorKind::Other, "index out of bounds"))?
270                .clone(),
271            UDPIndex::MainV6(index) => self
272                .main_udp_v6
273                .get(index)
274                .ok_or(io::Error::new(io::ErrorKind::Other, "index out of bounds"))?
275                .clone(),
276            UDPIndex::SubV4(index) => {
277                let guard = self.sub_udp.read();
278                let len = guard.len();
279                if len <= index {
280                    return Err(io::Error::new(io::ErrorKind::Other, "index out of bounds"));
281                } else {
282                    guard[index].clone()
283                }
284            }
285        })
286    }
287
288    #[inline]
289    fn get_udp_from_route(&self, route_key: &RouteKey) -> io::Result<Arc<UdpSocket>> {
290        Ok(match route_key.index() {
291            Index::Udp(index) => self.get_udp(index)?,
292            _ => return Err(io::Error::from(io::ErrorKind::InvalidInput)),
293        })
294    }
295}
296
297impl UdpSocketManager {
298    pub fn model(&self) -> Model {
299        if self.sub_udp.read().is_empty() {
300            Model::Low
301        } else {
302            Model::High
303        }
304    }
305
306    #[inline]
307    pub fn main_udp_v4_count(&self) -> usize {
308        self.main_udp_v4.len()
309    }
310    #[inline]
311    pub fn main_udp_v6_count(&self) -> usize {
312        self.main_udp_v6.len()
313    }
314
315    pub fn switch_model(&self, model: Model) -> io::Result<()> {
316        match model {
317            Model::High => self.switch_high(),
318            Model::Low => {
319                self.switch_low();
320                Ok(())
321            }
322        }
323    }
324    /// Acquire the local ports `UDP` sockets bind on
325    pub fn local_ports(&self) -> io::Result<Vec<u16>> {
326        let mut ports = Vec::with_capacity(self.main_udp_v4_count());
327        for udp in &self.main_udp_v4 {
328            ports.push(udp.local_addr()?.port());
329        }
330        Ok(ports)
331    }
332    /// Writing `buf` to the target denoted by `route_key`
333    pub async fn send_to<T, D: ToRouteKeyForUdp<T>>(&self, buf: &[u8], dest: D) -> io::Result<()> {
334        let route_key = ToRouteKeyForUdp::route_key(self, dest)?;
335        let len = self
336            .get_udp_from_route(&route_key)?
337            .send_to(buf, route_key.addr())
338            .await?;
339        if len == 0 {
340            return Err(std::io::Error::from(io::ErrorKind::WriteZero));
341        }
342        Ok(())
343    }
344
345    /// Try to write `buf` to the target denoted by `route_key`
346    pub fn try_send_to<T, D: ToRouteKeyForUdp<T>>(&self, buf: &[u8], dest: D) -> io::Result<()> {
347        let route_key = ToRouteKeyForUdp::route_key(self, dest)?;
348        let len = self
349            .get_udp_from_route(&route_key)?
350            .try_send_to(buf, route_key.addr())?;
351        if len == 0 {
352            return Err(std::io::Error::from(io::ErrorKind::WriteZero));
353        }
354        Ok(())
355    }
356
357    pub async fn batch_send_to<T, D: ToRouteKeyForUdp<T>>(
358        &self,
359        bufs: &[IoSlice<'_>],
360        dest: D,
361    ) -> io::Result<()> {
362        let route_key = ToRouteKeyForUdp::route_key(self, dest)?;
363        let udp = self.get_udp_from_route(&route_key)?;
364        for buf in bufs {
365            let len = udp.send_to(buf, route_key.addr()).await?;
366            if len == 0 {
367                return Err(std::io::Error::from(io::ErrorKind::WriteZero));
368            }
369        }
370
371        Ok(())
372    }
373    fn get_sender(&self, route_key: &RouteKey) -> io::Result<Sender<(BytesMut, SocketAddr)>> {
374        if let Some(sender) = self.sender_map.get(&route_key.index()) {
375            Ok(sender.value().clone())
376        } else {
377            Err(io::Error::new(io::ErrorKind::NotFound, "route not found"))
378        }
379    }
380    pub async fn send_bytes_to<T, D: ToRouteKeyForUdp<T>>(
381        &self,
382        buf: BytesMut,
383        dest: D,
384    ) -> io::Result<()> {
385        let route_key = ToRouteKeyForUdp::route_key(self, dest)?;
386        let sender = self.get_sender(&route_key)?;
387        if let Err(_e) = sender.send((buf, route_key.addr())).await {
388            Err(io::Error::from(io::ErrorKind::WriteZero))
389        } else {
390            Ok(())
391        }
392    }
393    pub fn try_send_bytes_to<T, D: ToRouteKeyForUdp<T>>(
394        &self,
395        buf: BytesMut,
396        dest: D,
397    ) -> io::Result<()> {
398        let route_key = ToRouteKeyForUdp::route_key(self, dest)?;
399        let sender = self.get_sender(&route_key)?;
400        if let Err(e) = sender.try_send((buf, route_key.addr())) {
401            match e {
402                TrySendError::Full(_) => Err(io::Error::from(io::ErrorKind::WouldBlock)),
403                TrySendError::Closed(_) => Err(io::Error::from(io::ErrorKind::WriteZero)),
404            }
405        } else {
406            Ok(())
407        }
408    }
409
410    /// Send bytes to the target denoted by SocketAddr with every main underlying socket
411    pub async fn detect_pub_addrs<A: Into<SocketAddr>>(
412        &self,
413        buf: &[u8],
414        addr: A,
415    ) -> io::Result<()> {
416        let addr: SocketAddr = addr.into();
417        for index in 0..self.main_udp_v4_count() {
418            self.send_to(buf, (index, addr)).await?
419        }
420        Ok(())
421    }
422}
423
424pub struct UdpTunnelDispatcher {
425    tunnel_receiver: Receiver<InactiveUdpTunnel>,
426    pub(crate) socket_manager: Arc<UdpSocketManager>,
427    recycle_buf: Option<RecycleBuf>,
428}
429
430impl UdpTunnelDispatcher {
431    pub(crate) fn init(&self) -> io::Result<()> {
432        for (index, udp) in self.socket_manager.main_udp_v4.iter().enumerate() {
433            let udp = udp.clone();
434            let tunnel =
435                InactiveUdpTunnel::new(true, Index::Udp(UDPIndex::MainV4(index)), udp, None);
436            if self
437                .socket_manager
438                .tunnel_dispatcher
439                .try_send(tunnel)
440                .is_err()
441            {
442                Err(io::Error::new(io::ErrorKind::Other, "tunnel channel error"))?
443            }
444        }
445        for (index, udp) in self.socket_manager.main_udp_v6.iter().enumerate() {
446            let udp = udp.clone();
447            let tunnel =
448                InactiveUdpTunnel::new(true, Index::Udp(UDPIndex::MainV6(index)), udp, None);
449            if self
450                .socket_manager
451                .tunnel_dispatcher
452                .try_send(tunnel)
453                .is_err()
454            {
455                Err(io::Error::new(io::ErrorKind::Other, "tunnel channel error"))?
456            }
457        }
458        Ok(())
459    }
460}
461
462impl UdpTunnelDispatcher {
463    /// Construct a `UDP` tunnel with the specified configuration
464    pub fn new(config: UdpTunnelConfig) -> io::Result<UdpTunnelDispatcher> {
465        create_tunnel_dispatcher(config)
466    }
467    /// Dispatch `UDP` tunnel from this kind dispatcher
468    pub async fn dispatch(&mut self) -> io::Result<UdpTunnel> {
469        let mut udp_tunnel = self
470            .tunnel_receiver
471            .recv()
472            .await
473            .map_err(|_| io::Error::new(io::ErrorKind::Other, "Udp tunnel close"))?;
474        let option = self
475            .socket_manager
476            .sender_map
477            .get(&udp_tunnel.index)
478            .map(|v| v.value().clone());
479        let sender = if let Some(v) = option {
480            v
481        } else {
482            let (s, mut r) = tachyonix::channel(128);
483            let index = udp_tunnel.index;
484            let sender = s.clone();
485            self.socket_manager.sender_map.insert(index, s);
486
487            let socket_manager = self.socket_manager.clone();
488            let udp = udp_tunnel.udp.clone();
489            let recycle_buf = self.recycle_buf.clone();
490            tokio::spawn(async move {
491                #[cfg(any(target_os = "linux", target_os = "android"))]
492                let mut vec_buf = Vec::with_capacity(16);
493
494                while let Ok((buf, addr)) = r.recv().await {
495                    #[cfg(any(target_os = "linux", target_os = "android"))]
496                    {
497                        vec_buf.push((buf, addr));
498                        while let Ok(tup) = r.try_recv() {
499                            vec_buf.push(tup);
500                            if vec_buf.len() == MAX_MESSAGES {
501                                break;
502                            }
503                        }
504                        let mut bufs = &mut vec_buf[..];
505                        let fd = udp.as_raw_fd();
506                        loop {
507                            if bufs.len() == 1 {
508                                let (buf, addr) = unsafe { bufs.get_unchecked(0) };
509                                if let Err(e) = udp.send_to(buf, *addr).await {
510                                    log::warn!("send_to {addr:?},{e:?}")
511                                }
512                                break;
513                            } else {
514                                let rs = write_with(&udp, || sendmmsg(fd, bufs)).await;
515                                match rs {
516                                    Ok(size) => {
517                                        if size < bufs.len() {
518                                            bufs = &mut bufs[size..];
519                                            continue;
520                                        }
521                                        break;
522                                    }
523                                    Err(e) => {
524                                        log::warn!("sendmmsg {e:?}");
525                                    }
526                                }
527                            }
528                        }
529                        if let Some(recycle_buf) = recycle_buf.as_ref() {
530                            while let Some((buf, _)) = vec_buf.pop() {
531                                recycle_buf.push(buf);
532                            }
533                        } else {
534                            vec_buf.clear();
535                        }
536                    }
537                    #[cfg(not(any(target_os = "linux", target_os = "android")))]
538                    {
539                        let rs = udp.send_to(&buf, addr).await;
540                        if let Some(recycle_buf) = recycle_buf.as_ref() {
541                            recycle_buf.push(buf);
542                        }
543                        if let Err(e) = rs {
544                            log::debug!("{addr:?},{e:?}")
545                        }
546                    }
547                }
548                socket_manager.sender_map.remove(&index);
549            });
550            sender
551        };
552        if udp_tunnel.sender.is_none() {
553            udp_tunnel.sender.replace(OwnedUdpTunnelSender { sender });
554        }
555        if udp_tunnel.reusable {
556            UdpTunnel::with_main(udp_tunnel, self.manager().tunnel_dispatcher.clone())
557        } else {
558            UdpTunnel::with_sub(udp_tunnel)
559        }
560    }
561    pub fn manager(&self) -> &Arc<UdpSocketManager> {
562        &self.socket_manager
563    }
564}
565
566#[cfg(any(target_os = "linux", target_os = "android"))]
567fn sendmmsg(fd: std::os::fd::RawFd, bufs: &mut [(BytesMut, SocketAddr)]) -> io::Result<usize> {
568    assert!(bufs.len() <= MAX_MESSAGES);
569    let mut iov: [iovec; MAX_MESSAGES] = unsafe { std::mem::zeroed() };
570    let mut msgs: [mmsghdr; MAX_MESSAGES] = unsafe { std::mem::zeroed() };
571    let mut addrs: [sockaddr_storage; MAX_MESSAGES] = unsafe { std::mem::zeroed() };
572    for (i, (buf, addr)) in bufs.iter_mut().enumerate() {
573        addrs[i] = socket_addr_to_sockaddr(addr);
574        iov[i].iov_base = buf.as_mut_ptr() as *mut libc::c_void;
575        iov[i].iov_len = buf.len();
576        msgs[i].msg_hdr.msg_iov = &mut iov[i];
577        msgs[i].msg_hdr.msg_iovlen = 1;
578
579        msgs[i].msg_hdr.msg_name = &mut addrs[i] as *mut _ as *mut libc::c_void;
580        msgs[i].msg_hdr.msg_namelen = std::mem::size_of::<sockaddr_storage>() as socklen_t;
581    }
582
583    unsafe {
584        let res = libc::sendmmsg(
585            fd,
586            msgs.as_mut_ptr(),
587            bufs.len() as _,
588            libc::MSG_DONTWAIT as _,
589        );
590        if res == -1 {
591            return Err(io::Error::last_os_error());
592        }
593        Ok(res as usize)
594    }
595}
596
597#[cfg(any(target_os = "linux", target_os = "android"))]
598fn socket_addr_to_sockaddr(addr: &SocketAddr) -> sockaddr_storage {
599    let mut storage: sockaddr_storage = unsafe { std::mem::zeroed() };
600
601    match addr {
602        SocketAddr::V4(v4_addr) => {
603            let sin = libc::sockaddr_in {
604                sin_family: libc::AF_INET as _,
605                sin_port: v4_addr.port().to_be(),
606                sin_addr: libc::in_addr {
607                    s_addr: u32::from_ne_bytes(v4_addr.ip().octets()), // IP 地址
608                },
609                sin_zero: [0; 8],
610            };
611
612            unsafe {
613                let sin_ptr = &sin as *const libc::sockaddr_in as *const u8;
614                let storage_ptr = &mut storage as *mut sockaddr_storage as *mut u8;
615                std::ptr::copy_nonoverlapping(
616                    sin_ptr,
617                    storage_ptr,
618                    std::mem::size_of::<libc::sockaddr>(),
619                );
620            }
621        }
622        SocketAddr::V6(v6_addr) => {
623            let sin6 = libc::sockaddr_in6 {
624                sin6_family: libc::AF_INET6 as _,
625                sin6_port: v6_addr.port().to_be(),
626                sin6_flowinfo: v6_addr.flowinfo(),
627                sin6_addr: libc::in6_addr {
628                    s6_addr: v6_addr.ip().octets(),
629                },
630                sin6_scope_id: v6_addr.scope_id(),
631            };
632
633            unsafe {
634                let sin6_ptr = &sin6 as *const libc::sockaddr_in6 as *const u8;
635                let storage_ptr = &mut storage as *mut sockaddr_storage as *mut u8;
636                std::ptr::copy_nonoverlapping(
637                    sin6_ptr,
638                    storage_ptr,
639                    std::mem::size_of::<libc::sockaddr>(),
640                );
641            }
642        }
643    }
644    storage
645}
646
647pub struct UdpTunnel {
648    index: Index,
649    local_addr: SocketAddr,
650    udp: Option<Arc<UdpSocket>>,
651    close_notify: Option<async_broadcast::Receiver<()>>,
652    re_dispatcher: Option<Sender<InactiveUdpTunnel>>,
653    sender: Option<OwnedUdpTunnelSender>,
654}
655struct OwnedUdpTunnelSender {
656    sender: Sender<(BytesMut, SocketAddr)>,
657}
658#[derive(Clone)]
659pub struct WeakUdpTunnelSender {
660    sender: Sender<(BytesMut, SocketAddr)>,
661}
662struct InactiveUdpTunnel {
663    reusable: bool,
664    index: Index,
665    udp: Arc<UdpSocket>,
666    close_notify: Option<async_broadcast::Receiver<()>>,
667    sender: Option<OwnedUdpTunnelSender>,
668}
669impl InactiveUdpTunnel {
670    fn new(
671        reusable: bool,
672        index: Index,
673        udp: Arc<UdpSocket>,
674        close_notify: Option<async_broadcast::Receiver<()>>,
675    ) -> Self {
676        Self {
677            reusable,
678            index,
679            udp,
680            close_notify,
681            sender: None,
682        }
683    }
684    fn redistribute(index: Index, udp: Arc<UdpSocket>, sender: OwnedUdpTunnelSender) -> Self {
685        Self {
686            reusable: true,
687            index,
688            udp,
689            close_notify: None,
690            sender: Some(sender),
691        }
692    }
693}
694impl OwnedUdpTunnelSender {
695    async fn send_to<A: Into<SocketAddr>>(&self, buf: BytesMut, dest: A) -> io::Result<()> {
696        if buf.is_empty() {
697            return Ok(());
698        }
699        self.sender
700            .send((buf, dest.into()))
701            .await
702            .map_err(|_| io::Error::from(io::ErrorKind::WriteZero))
703    }
704    fn is_closed(&self) -> bool {
705        self.sender.is_closed()
706    }
707}
708impl WeakUdpTunnelSender {
709    pub async fn send_to<A: Into<SocketAddr>>(&self, buf: BytesMut, dest: A) -> io::Result<()> {
710        if buf.is_empty() {
711            return Ok(());
712        }
713        self.sender
714            .send((buf, dest.into()))
715            .await
716            .map_err(|_| io::Error::from(io::ErrorKind::WriteZero))
717    }
718    pub fn try_send_to<A: Into<SocketAddr>>(&self, buf: BytesMut, dest: A) -> io::Result<()> {
719        if buf.is_empty() {
720            return Ok(());
721        }
722        self.sender
723            .try_send((buf, dest.into()))
724            .map_err(|e| match e {
725                TrySendError::Full(_) => io::Error::from(io::ErrorKind::WouldBlock),
726                TrySendError::Closed(_) => io::Error::from(io::ErrorKind::WriteZero),
727            })
728    }
729}
730impl Drop for OwnedUdpTunnelSender {
731    fn drop(&mut self) {
732        self.sender.close();
733    }
734}
735impl Drop for UdpTunnel {
736    fn drop(&mut self) {
737        let Some(sender) = self.sender.take() else {
738            return;
739        };
740        if sender.is_closed() {
741            return;
742        }
743        let Some(udp) = self.udp.take() else {
744            return;
745        };
746        let Some(re_dispatcher) = self.re_dispatcher.take() else {
747            return;
748        };
749        let rs = re_dispatcher.try_send(InactiveUdpTunnel::redistribute(self.index, udp, sender));
750        if let Err(TrySendError::Full(_)) = rs {
751            log::warn!("Udp Tunnel TrySendError full");
752        }
753    }
754}
755
756impl UdpTunnel {
757    fn with_sub(inactive_udp_tunnel: InactiveUdpTunnel) -> io::Result<Self> {
758        let local_addr = inactive_udp_tunnel.udp.local_addr()?;
759        Ok(Self {
760            index: inactive_udp_tunnel.index,
761            local_addr,
762            udp: Some(inactive_udp_tunnel.udp),
763            close_notify: inactive_udp_tunnel.close_notify,
764            re_dispatcher: None,
765            sender: inactive_udp_tunnel.sender,
766        })
767    }
768    fn with_main(
769        inactive_udp_tunnel: InactiveUdpTunnel,
770        re_sender: Sender<InactiveUdpTunnel>,
771    ) -> io::Result<Self> {
772        let local_addr = inactive_udp_tunnel.udp.local_addr()?;
773        Ok(Self {
774            local_addr,
775            index: inactive_udp_tunnel.index,
776            udp: Some(inactive_udp_tunnel.udp),
777            close_notify: None,
778            re_dispatcher: Some(re_sender),
779            sender: inactive_udp_tunnel.sender,
780        })
781    }
782    pub fn done(&mut self) {
783        _ = self.udp.take();
784        _ = self.close_notify.take();
785        _ = self.re_dispatcher.take();
786        _ = self.re_dispatcher.take();
787        _ = self.sender.take();
788    }
789    pub fn local_addr(&self) -> SocketAddr {
790        self.local_addr
791    }
792    pub fn sender(&self) -> io::Result<WeakUdpTunnelSender> {
793        if let Some(v) = &self.sender {
794            Ok(WeakUdpTunnelSender {
795                sender: v.sender.clone(),
796            })
797        } else {
798            Err(io::Error::new(io::ErrorKind::Other, "closed"))
799        }
800    }
801}
802
impl UdpTunnel {
    /// Sends `buf` as a single datagram to `addr` through this tunnel's UDP socket,
    /// awaiting until the socket accepts it.
    ///
    /// # Errors
    /// Returns an `Other` error with message `"closed"` if the tunnel no longer holds a
    /// socket; otherwise propagates any error from the socket's `send_to`. The byte count
    /// reported by the socket is discarded.
    pub async fn send_to<A: Into<SocketAddr>>(&self, buf: &[u8], addr: A) -> io::Result<()> {
        if let Some(udp) = &self.udp {
            udp.send_to(buf, addr.into()).await?;
            Ok(())
        } else {
            Err(io::Error::new(io::ErrorKind::Other, "closed"))
        }
    }
    /// Non-blocking variant of [`UdpTunnel::send_to`]: attempts to send `buf` to `addr`
    /// immediately, without awaiting writability.
    ///
    /// # Errors
    /// Returns the same `"closed"` error as `send_to` when there is no socket; otherwise
    /// propagates the socket's `try_send_to` error (for tokio sockets this is `WouldBlock`
    /// when the socket is not currently writable).
    pub fn try_send_to<A: Into<SocketAddr>>(&self, buf: &[u8], addr: A) -> io::Result<()> {
        if let Some(udp) = &self.udp {
            udp.try_send_to(buf, addr.into())?;
            Ok(())
        } else {
            Err(io::Error::new(io::ErrorKind::Other, "closed"))
        }
    }
    /// Hands the owned `buf` to this tunnel's sender (the `sender` field) for delivery to
    /// `addr`.
    ///
    /// NOTE(review): the sender type is defined elsewhere; presumably this enqueues the
    /// buffer rather than writing to the socket directly — confirm.
    ///
    /// # Errors
    /// Returns an `Other` error with message `"closed"` if the tunnel has no sender;
    /// otherwise whatever the sender's `send_to` returns.
    pub async fn send_bytes_to<A: Into<SocketAddr>>(
        &self,
        buf: BytesMut,
        addr: A,
    ) -> io::Result<()> {
        if let Some(sender) = &self.sender {
            sender.send_to(buf, addr).await
        } else {
            Err(io::Error::new(io::ErrorKind::Other, "closed"))
        }
    }

    /// Receives one datagram from this tunnel into `buf`.
    ///
    /// On success, the `usize` in the `Ok` branch is the number of bytes received and the
    /// `RouteKey` identifies the source (this tunnel's `index` plus the peer address).
    ///
    /// Returns `None` in two cases:
    /// - the tunnel has no socket;
    /// - the close notification fires before a datagram arrives, in which case
    ///   `self.done()` is called first.
    ///
    /// Errors for which `should_ignore_error` returns `true` are skipped and the receive
    /// is retried.
    pub async fn recv_from(&mut self, buf: &mut [u8]) -> Option<io::Result<(usize, RouteKey)>> {
        let udp = if let Some(udp) = &self.udp {
            udp
        } else {
            return None;
        };
        loop {
            // With a close channel, race it against the socket read. Without one, just read.
            // In both paths, `continue` retries after an ignorable error.
            if let Some(close_notify) = &mut self.close_notify {
                tokio::select! {
                    _rs=close_notify.recv()=>{
                         self.done();
                         return None
                    }
                    result=udp.recv_from(buf)=>{
                         let (len, addr) = match result {
                            Ok(rs) => rs,
                            Err(e) => {
                                if should_ignore_error(&e) {
                                    continue;
                                }
                                return Some(Err(e))
                            }
                         };
                         return Some(Ok((len, RouteKey::new(self.index, addr))))
                    }
                }
            } else {
                let (len, addr) = match udp.recv_from(buf).await {
                    Ok(rs) => rs,
                    Err(e) => {
                        if should_ignore_error(&e) {
                            continue;
                        }
                        return Some(Err(e));
                    }
                };
                return Some(Ok((len, RouteKey::new(self.index, addr))));
            }
        }
    }
    /// Portable batch receive for targets without `recvmmsg`.
    ///
    /// Awaits one datagram into `bufs[0]` via [`UdpTunnel::recv_from`]. It then drains
    /// whatever is immediately available with `try_recv_from`, stopping when a call fails
    /// or every buffer is used.
    ///
    /// For each `i` below the returned count, `sizes[i]` and `addrs[i]` describe `bufs[i]`.
    /// Returns `None` under the same conditions as `recv_from`.
    ///
    /// # Errors
    /// - `"bufs error"` if `bufs` is empty or the three slices differ in length.
    /// - Otherwise, the error from the initial `recv_from`.
    ///
    /// Errors from the follow-up `try_recv_from` calls, including `WouldBlock`, only end
    /// the batch early.
    #[cfg(not(any(target_os = "linux", target_os = "android")))]
    pub async fn batch_recv_from<B: AsMut<[u8]>>(
        &mut self,
        bufs: &mut [B],
        sizes: &mut [usize],
        addrs: &mut [RouteKey],
    ) -> Option<io::Result<usize>> {
        if bufs.is_empty() || bufs.len() != sizes.len() || bufs.len() != addrs.len() {
            return Some(Err(io::Error::new(io::ErrorKind::Other, "bufs error")));
        }
        let rs = self.recv_from(bufs[0].as_mut()).await?;
        match rs {
            Ok((len, addr)) => {
                let udp = self.udp.as_ref()?;
                sizes[0] = len;
                addrs[0] = addr;
                let mut num = 1;
                while num < bufs.len() {
                    match udp.try_recv_from(bufs[num].as_mut()) {
                        Ok((len, addr)) => {
                            sizes[num] = len;
                            addrs[num] = RouteKey::new(self.index, addr);
                            num += 1;
                        }
                        // Nothing more ready, or any other error: return what was collected.
                        Err(_) => break,
                    }
                }
                Some(Ok(num))
            }
            Err(e) => Some(Err(e)),
        }
    }
    /// Linux/Android batch receive using `recvmmsg(2)`. Each call fills at most
    /// `MAX_MESSAGES` of the given buffers.
    ///
    /// For each `i` below the returned count, `sizes[i]` and `addrs[i]` describe `bufs[i]`.
    ///
    /// Returns `None` in two cases:
    /// - the tunnel has no socket;
    /// - the close notification fires first, in which case `self.done()` is called.
    ///
    /// Errors for which `should_ignore_error` returns `true` are skipped and the receive
    /// is retried.
    ///
    /// # Errors
    /// - `"bufs/sizes/addrs error"` if `bufs` is empty or the three slices differ in
    ///   length.
    /// - Otherwise, any error surfaced by `read_with` around `recvmmsg`.
    ///
    /// `read_with` uses tokio's `async_io`, which waits for readiness again on
    /// `WouldBlock` instead of returning it.
    #[cfg(any(target_os = "linux", target_os = "android"))]
    pub async fn batch_recv_from<B: AsMut<[u8]>>(
        &mut self,
        bufs: &mut [B],
        sizes: &mut [usize],
        addrs: &mut [RouteKey],
    ) -> Option<io::Result<usize>> {
        if bufs.is_empty() || bufs.len() != sizes.len() || bufs.len() != addrs.len() {
            return Some(Err(io::Error::new(
                io::ErrorKind::Other,
                "bufs/sizes/addrs error",
            )));
        }
        let udp = self.udp.as_ref()?;
        let fd = udp.as_raw_fd();
        loop {
            // With a close channel, race it against the readiness-driven `recvmmsg`.
            // Without one, just await the read.
            let rs = if let Some(close_notify) = &mut self.close_notify {
                tokio::select! {
                    _rs=close_notify.recv()=>{
                        self.done();
                        return None
                    }
                    rs=read_with(udp,|| recvmmsg(self.index, fd, bufs, sizes, addrs))=>{
                        rs
                    }
                }
            } else {
                read_with(udp, || recvmmsg(self.index, fd, bufs, sizes, addrs)).await
            };
            // Every path returns, except an ignorable error, which loops for another read.
            return match rs {
                Ok(size) => Some(Ok(size)),
                Err(e) => {
                    if should_ignore_error(&e) {
                        continue;
                    }
                    Some(Err(e))
                }
            };
        }
    }
}
950
#[cfg(any(target_os = "linux", target_os = "android"))]
/// Receives up to `min(bufs.len(), MAX_MESSAGES)` datagrams from `fd` with one
/// non-blocking `recvmmsg(2)` call.
///
/// For each received datagram `i`:
/// - `bufs[i]` holds the payload;
/// - `sizes[i]` is its length;
/// - `route_keys[i]` is the sender address tagged with `index`.
///
/// Entries at or past the returned count are left untouched.
///
/// # Errors
/// Returns `io::Error::last_os_error()` when the call fails. Because `MSG_DONTWAIT` is
/// passed, an empty socket yields `EAGAIN`, which std maps to `ErrorKind::WouldBlock`.
///
/// # Panics
/// `sizes` and `route_keys` are indexed by message number, so each must be at least as
/// long as the number of datagrams received. The visible caller passes three slices of
/// equal length. `sockaddr_to_socket_addr` panics on a non-IPv4/IPv6 address family.
fn recvmmsg<B: AsMut<[u8]>>(
    index: Index,
    fd: std::os::fd::RawFd,
    bufs: &mut [B],
    sizes: &mut [usize],
    route_keys: &mut [RouteKey],
) -> io::Result<usize> {
    // SAFETY: `iovec`, `mmsghdr` and `sockaddr_storage` are plain C structs. For them,
    // all-zero bytes (null pointers, zero lengths and counts) is a valid value.
    let mut iov: [iovec; MAX_MESSAGES] = unsafe { std::mem::zeroed() };
    let mut msgs: [mmsghdr; MAX_MESSAGES] = unsafe { std::mem::zeroed() };
    let mut addrs: [sockaddr_storage; MAX_MESSAGES] = unsafe { std::mem::zeroed() };
    let max_num = bufs.len().min(MAX_MESSAGES);
    for i in 0..max_num {
        // Message i scatters into exactly one iovec covering all of bufs[i]. It records the
        // sender in addrs[i]. `msg_namelen` is a value-result field: it starts at the full
        // storage size and the kernel overwrites it with the actual address length.
        iov[i].iov_base = bufs[i].as_mut().as_mut_ptr() as *mut libc::c_void;
        iov[i].iov_len = bufs[i].as_mut().len();
        msgs[i].msg_hdr.msg_iov = &mut iov[i];
        msgs[i].msg_hdr.msg_iovlen = 1;
        msgs[i].msg_hdr.msg_name = &mut addrs[i] as *const _ as *mut libc::c_void;
        msgs[i].msg_hdr.msg_namelen = std::mem::size_of::<sockaddr_storage>() as socklen_t;
    }
    // SAFETY: the first `max_num` headers point only into `iov`, `addrs` and `bufs`. All
    // three outlive this call and are not moved while it runs. The timeout pointer may
    // be null.
    let res = unsafe {
        libc::recvmmsg(
            fd,
            msgs.as_mut_ptr(),
            max_num as c_uint,
            libc::MSG_DONTWAIT as _,
            std::ptr::null_mut(),
        )
    };
    if res == -1 {
        return Err(io::Error::last_os_error());
    }
    // On success `res` is the number of headers the kernel filled (at most `max_num`).
    let nmsgs = res as usize;
    for i in 0..nmsgs {
        let addr = sockaddr_to_socket_addr(&addrs[i], msgs[i].msg_hdr.msg_namelen);
        sizes[i] = msgs[i].msg_len as usize;
        route_keys[i] = RouteKey::new(index, addr);
    }
    Ok(nmsgs)
}
991#[cfg(any(target_os = "linux", target_os = "android"))]
992fn sockaddr_to_socket_addr(addr: &sockaddr_storage, _len: socklen_t) -> SocketAddr {
993    match addr.ss_family as libc::c_int {
994        libc::AF_INET => {
995            let addr_in = unsafe { *(addr as *const _ as *const libc::sockaddr_in) };
996            let ip = u32::from_be(addr_in.sin_addr.s_addr);
997            let port = u16::from_be(addr_in.sin_port);
998            SocketAddr::V4(std::net::SocketAddrV4::new(
999                std::net::Ipv4Addr::from(ip),
1000                port,
1001            ))
1002        }
1003        libc::AF_INET6 => {
1004            let addr_in6 = unsafe { *(addr as *const _ as *const libc::sockaddr_in6) };
1005            let ip = std::net::Ipv6Addr::from(addr_in6.sin6_addr.s6_addr);
1006            let port = u16::from_be(addr_in6.sin6_port);
1007            SocketAddr::V6(std::net::SocketAddrV6::new(ip, port, 0, 0))
1008        }
1009        _ => panic!("Unsupported address family"),
1010    }
1011}
1012
/// Reports whether a receive error is transient, meaning the caller should simply retry.
///
/// On Windows, a UDP receive can fail with `WSAECONNRESET` when an earlier send provoked
/// an ICMP "port unreachable" reply. That error is ignored here. Every other error, and
/// every error on non-Windows targets, is reported to the caller.
fn should_ignore_error(e: &io::Error) -> bool {
    #[cfg(windows)]
    {
        e.raw_os_error() == Some(windows_sys::Win32::Networking::WinSock::WSAECONNRESET)
    }
    #[cfg(not(windows))]
    {
        let _ = e;
        false
    }
}
1024
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use crate::tunnel::udp::{Model, UdpTunnel};

    /// Upper bound on each wait for a dispatched tunnel or for a receive task to finish.
    const WAIT: Duration = Duration::from_secs(1);

    /// With two main and ten sub sockets in `Model::Low`, the dispatcher is expected to
    /// yield exactly two tunnels.
    #[tokio::test]
    pub async fn create_udp_tunnel() {
        let config = crate::tunnel::config::UdpTunnelConfig::default()
            .set_main_udp_count(2)
            .set_sub_udp_count(10)
            .set_model(Model::Low)
            .set_use_v6(false);
        let mut dispatcher = crate::tunnel::udp::create_tunnel_dispatcher(config).unwrap();
        let mut handles = Vec::new();
        while let Ok(dispatched) = tokio::time::timeout(WAIT, dispatcher.dispatch()).await {
            handles.push(tokio::spawn(tunnel_recv(dispatched.unwrap())));
        }
        assert_eq!(handles.len(), 2);
    }

    /// In `Model::High`, all twelve sockets are dispatched. After `switch_low()`, ten of
    /// the receive tasks are expected to observe their tunnel being closed.
    #[tokio::test]
    pub async fn create_sub_udp_tunnel() {
        let config = crate::tunnel::config::UdpTunnelConfig::default()
            .set_main_udp_count(2)
            .set_sub_udp_count(10)
            .set_use_v6(false)
            .set_model(Model::High);
        let mut dispatcher = crate::tunnel::udp::create_tunnel_dispatcher(config).unwrap();
        let mut handles = Vec::new();
        while let Ok(dispatched) = tokio::time::timeout(WAIT, dispatcher.dispatch()).await {
            handles.push(tokio::spawn(tunnel_recv(dispatched.unwrap())));
        }
        let dispatched_count = handles.len();
        dispatcher.manager().switch_low();

        let mut closed = 0;
        for handle in handles {
            // A timeout means the task is still blocked in `recv_from`, so it is not counted.
            if let Ok(joined) = tokio::time::timeout(WAIT, handle).await {
                if joined.unwrap() {
                    closed += 1;
                }
            }
        }
        assert_eq!(dispatched_count, 12);
        assert_eq!(closed, 10);
    }

    /// Waits for one `recv_from` on `tunnel`; yields `true` when it reports closure (`None`).
    async fn tunnel_recv(mut tunnel: UdpTunnel) -> bool {
        let mut buf = [0u8; 1400];
        let outcome = tunnel.recv_from(&mut buf).await;
        outcome.is_none()
    }
}