Skip to main content

rust_p2p_core/tunnel/udp.rs

1use std::io;
2use std::io::IoSlice;
3use std::net::SocketAddr;
4use std::sync::Arc;
5
6#[cfg(any(target_os = "linux", target_os = "android"))]
7use tokio::io::Interest;
8
9#[cfg(any(target_os = "linux", target_os = "android"))]
10pub async fn read_with<R>(udp: &UdpSocket, op: impl FnMut() -> io::Result<R>) -> io::Result<R> {
11    udp.async_io(Interest::READABLE, op).await
12}
13#[cfg(any(target_os = "linux", target_os = "android"))]
14pub async fn write_with<R>(udp: &UdpSocket, op: impl FnMut() -> io::Result<R>) -> io::Result<R> {
15    udp.async_io(Interest::WRITABLE, op).await
16}
17
18use bytes::Bytes;
19use dashmap::DashMap;
20use parking_lot::{Mutex, RwLock};
21use tachyonix::{Receiver, Sender, TrySendError};
22use tokio::net::UdpSocket;
23
24use crate::route::{Index, RouteKey};
25use crate::socket::{bind_udp, LocalInterface};
26use crate::tunnel::config::UdpTunnelConfig;
27use crate::tunnel::{DEFAULT_ADDRESS_V4, DEFAULT_ADDRESS_V6};
28
/// Upper bound on the number of datagrams handed to one batched `sendmmsg` call: the sender
/// task spawned by `UdpTunnelDispatcher::dispatch` stops batching at this size, and `sendmmsg`
/// asserts its input does not exceed it.
#[cfg(any(target_os = "linux", target_os = "android"))]
const MAX_MESSAGES: usize = 16;
31#[cfg(any(target_os = "linux", target_os = "android"))]
32use libc::{c_uint, iovec, mmsghdr, sockaddr_storage, socklen_t};
33#[cfg(any(target_os = "linux", target_os = "android"))]
34use std::os::fd::AsRawFd;
35
/// Operating model of the UDP tunnel layer.
///
/// `Low` (the default) uses only the main sockets. `High` additionally binds the sub sockets
/// (see `UdpSocketManager::switch_model`).
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum Model {
    /// Main sockets plus sub sockets.
    High,
    /// Main sockets only.
    #[default]
    Low,
}

impl Model {
    /// `true` exactly for [`Model::Low`].
    pub fn is_low(&self) -> bool {
        matches!(self, Model::Low)
    }
    /// `true` exactly for [`Model::High`].
    pub fn is_high(&self) -> bool {
        matches!(self, Model::High)
    }
}
51
/// Identifies one of the manager's UDP sockets by its position in the corresponding list.
///
/// The lists are the main IPv4 sockets, the main IPv6 sockets, and the sub IPv4 sockets
/// (which exist only in the high model).
#[derive(Copy, Clone, Ord, PartialOrd, Eq, PartialEq, Hash, Debug)]
pub enum UDPIndex {
    MainV4(usize),
    MainV6(usize),
    SubV4(usize),
}

impl UDPIndex {
    /// Position of the socket within its own list, whichever list that is.
    pub(crate) fn index(&self) -> usize {
        match *self {
            UDPIndex::MainV4(position) | UDPIndex::MainV6(position) | UDPIndex::SubV4(position) => {
                position
            }
        }
    }
}
68
69pub trait ToRouteKeyForUdp<T> {
70    fn route_key(socket_manager: &UdpSocketManager, dest: Self) -> io::Result<RouteKey>;
71}
72
73impl ToRouteKeyForUdp<()> for RouteKey {
74    fn route_key(_: &UdpSocketManager, dest: Self) -> io::Result<RouteKey> {
75        Ok(dest)
76    }
77}
78
79impl ToRouteKeyForUdp<()> for &RouteKey {
80    fn route_key(_: &UdpSocketManager, dest: Self) -> io::Result<RouteKey> {
81        Ok(*dest)
82    }
83}
84
85impl ToRouteKeyForUdp<()> for &mut RouteKey {
86    fn route_key(_: &UdpSocketManager, dest: Self) -> io::Result<RouteKey> {
87        Ok(*dest)
88    }
89}
90
91impl<S: Into<SocketAddr>> ToRouteKeyForUdp<()> for S {
92    fn route_key(socket_manager: &UdpSocketManager, dest: Self) -> io::Result<RouteKey> {
93        let addr = dest.into();
94        socket_manager.generate_route_key_from_addr(0, addr)
95    }
96}
97
98impl<S: Into<SocketAddr>> ToRouteKeyForUdp<usize> for (usize, S) {
99    fn route_key(socket_manager: &UdpSocketManager, dest: Self) -> io::Result<RouteKey> {
100        let (index, addr) = dest;
101        socket_manager.generate_route_key_from_addr(index, addr.into())
102    }
103}
104
105/// initialize udp tunnel by config
106pub(crate) fn create_tunnel_dispatcher(config: UdpTunnelConfig) -> io::Result<UdpTunnelDispatcher> {
107    config.check()?;
108    let mut udp_ports = config.udp_ports;
109    udp_ports.resize(config.main_udp_count, 0);
110    let mut main_udp_v4: Vec<Arc<UdpSocket>> = Vec::with_capacity(config.main_udp_count);
111    let mut main_udp_v6: Vec<Arc<UdpSocket>> = Vec::with_capacity(config.main_udp_count);
112    // 因为在mac上v4和v6的对绑定网卡的处理不同,所以这里分开监听,并且分开监听更容易处理发送目标为v4的情况,因为双协议栈下发送v4目标需要转换成v6
113    for port in &udp_ports {
114        loop {
115            let mut addr_v4 = DEFAULT_ADDRESS_V4;
116            addr_v4.set_port(*port);
117            let socket_v4 = bind_udp(addr_v4, config.default_interface.as_ref())?;
118            let udp_v4: std::net::UdpSocket = socket_v4.into();
119            if config.use_v6 {
120                let mut addr_v6 = DEFAULT_ADDRESS_V6;
121                let socket_v6 = if *port == 0 {
122                    let port = udp_v4.local_addr()?.port();
123                    addr_v6.set_port(port);
124                    match bind_udp(addr_v6, config.default_interface.as_ref()) {
125                        Ok(socket_v6) => socket_v6,
126                        Err(_) => continue,
127                    }
128                } else {
129                    addr_v6.set_port(*port);
130                    bind_udp(addr_v6, config.default_interface.as_ref())?
131                };
132                let udp_v6: std::net::UdpSocket = socket_v6.into();
133                main_udp_v6.push(Arc::new(UdpSocket::from_std(udp_v6)?))
134            }
135            main_udp_v4.push(Arc::new(UdpSocket::from_std(udp_v4)?));
136            break;
137        }
138    }
139    let (tunnel_sender, tunnel_receiver) =
140        tachyonix::channel(config.main_udp_count * 2 + config.sub_udp_count * 2);
141    let socket_manager = Arc::new(UdpSocketManager {
142        main_udp_v4,
143        main_udp_v6,
144        sub_udp: RwLock::new(Vec::with_capacity(config.sub_udp_count)),
145        sub_close_notify: Default::default(),
146        tunnel_dispatcher: tunnel_sender,
147        sub_udp_num: config.sub_udp_count,
148        default_interface: config.default_interface,
149        sender_map: Default::default(),
150    });
151    let tunnel_factory = UdpTunnelDispatcher {
152        tunnel_receiver,
153        socket_manager,
154    };
155    tunnel_factory.init()?;
156    tunnel_factory.socket_manager.switch_model(config.model)?;
157    Ok(tunnel_factory)
158}
159
/// Owns every UDP socket of the tunnel layer and provides the send-side API.
///
/// - Main sockets are bound once by `create_tunnel_dispatcher`.
/// - Sub sockets exist only in [`Model::High`]; they are bound by `switch_high` and dropped by
///   `switch_low`.
pub struct UdpSocketManager {
    /// Main IPv4 sockets, one per configured port (addressed by `UDPIndex::MainV4`).
    main_udp_v4: Vec<Arc<UdpSocket>>,
    /// Main IPv6 sockets, each bound on the same port as the IPv4 socket at the same
    /// position. Empty when IPv6 is disabled.
    main_udp_v6: Vec<Arc<UdpSocket>>,
    /// Sub IPv4 sockets (addressed by `UDPIndex::SubV4`); empty in the low model.
    sub_udp: RwLock<Vec<Arc<UdpSocket>>>,
    /// Broadcast sender shared with the sub-socket tunnels. Closing it makes their
    /// `recv_from` stop.
    sub_close_notify: Mutex<Option<async_broadcast::Sender<()>>>,
    /// Queue of inactive tunnels consumed by `UdpTunnelDispatcher::dispatch`.
    tunnel_dispatcher: Sender<InactiveUdpTunnel>,
    /// Number of sub sockets to bind when switching to the high model.
    sub_udp_num: usize,
    /// Interface passed to `bind_udp` for every socket, if any.
    default_interface: Option<LocalInterface>,
    /// Outbound byte queue per socket index. Each queue is drained by a sender task spawned in
    /// `UdpTunnelDispatcher::dispatch`.
    sender_map: DashMap<Index, Sender<(Bytes, SocketAddr)>>,
}
170
171impl UdpSocketManager {
172    pub(crate) fn try_sub_batch_send_to(&self, buf: &[u8], addr: SocketAddr) {
173        for (i, udp) in self.sub_udp.read().iter().enumerate() {
174            if let Err(e) = udp.try_send_to(buf, addr) {
175                log::info!("try_sub_send_to_addr_v4: {e:?},{i},{addr}")
176            }
177        }
178    }
179    pub(crate) fn try_main_v4_batch_send_to(&self, buf: &[u8], addr: &[SocketAddr]) {
180        let len = self.main_udp_v4_count();
181        self.try_main_batch_send_to_impl(buf, addr, len);
182    }
183    pub(crate) fn try_main_v6_batch_send_to(&self, buf: &[u8], addr: &[SocketAddr]) {
184        let len = self.main_udp_v6_count();
185        self.try_main_batch_send_to_impl(buf, addr, len);
186    }
187
188    pub(crate) fn try_main_batch_send_to_impl(&self, buf: &[u8], addr: &[SocketAddr], len: usize) {
189        for (i, addr) in addr.iter().enumerate() {
190            if let Err(e) = self.try_send_to(buf, (i % len, *addr)) {
191                log::info!("try_main_send_to_addr: {e:?},{},{addr}", i % len);
192            }
193        }
194    }
195    pub(crate) fn generate_route_key_from_addr(
196        &self,
197        index: usize,
198        addr: SocketAddr,
199    ) -> io::Result<RouteKey> {
200        let route_key = if addr.is_ipv4() {
201            let len = self.main_udp_v4.len();
202            if index >= len {
203                return Err(io::Error::other("index out of bounds"));
204            }
205            RouteKey::new(Index::Udp(UDPIndex::MainV4(index)), addr)
206        } else {
207            let len = self.main_udp_v6.len();
208            if len == 0 {
209                return Err(io::Error::other("Not support IPV6"));
210            }
211            if index >= len {
212                return Err(io::Error::other("index out of bounds"));
213            }
214            RouteKey::new(Index::Udp(UDPIndex::MainV6(index)), addr)
215        };
216        Ok(route_key)
217    }
218    pub(crate) fn switch_low(&self) {
219        let mut guard = self.sub_udp.write();
220        if guard.is_empty() {
221            return;
222        }
223        guard.clear();
224        if let Some(sub_close_notify) = self.sub_close_notify.lock().take() {
225            let _ = sub_close_notify.close();
226        }
227    }
228    pub(crate) fn switch_high(&self) -> io::Result<()> {
229        let mut guard = self.sub_udp.write();
230        if !guard.is_empty() {
231            return Ok(());
232        }
233        let mut sub_close_notify_guard = self.sub_close_notify.lock();
234        if let Some(sender) = sub_close_notify_guard.take() {
235            let _ = sender.close();
236        }
237        let (sub_close_notify_sender, sub_close_notify_receiver) = async_broadcast::broadcast(2);
238        let mut sub_udp_list = Vec::with_capacity(self.sub_udp_num);
239        for _ in 0..self.sub_udp_num {
240            let udp = bind_udp(DEFAULT_ADDRESS_V4, self.default_interface.as_ref())?;
241            let udp: std::net::UdpSocket = udp.into();
242            sub_udp_list.push(Arc::new(UdpSocket::from_std(udp)?));
243        }
244        for (index, udp) in sub_udp_list.iter().enumerate() {
245            let udp = udp.clone();
246            let udp_tunnel = InactiveUdpTunnel::new(
247                false,
248                Index::Udp(UDPIndex::SubV4(index)),
249                udp,
250                Some(sub_close_notify_receiver.clone()),
251            );
252            if self.tunnel_dispatcher.try_send(udp_tunnel).is_err() {
253                Err(io::Error::other("tunnel channel error"))?
254            }
255        }
256        sub_close_notify_guard.replace(sub_close_notify_sender);
257        *guard = sub_udp_list;
258        Ok(())
259    }
260
261    #[inline]
262    fn get_udp(&self, udp_index: UDPIndex) -> io::Result<Arc<UdpSocket>> {
263        Ok(match udp_index {
264            UDPIndex::MainV4(index) => self
265                .main_udp_v4
266                .get(index)
267                .ok_or(io::Error::other("index out of bounds"))?
268                .clone(),
269            UDPIndex::MainV6(index) => self
270                .main_udp_v6
271                .get(index)
272                .ok_or(io::Error::other("index out of bounds"))?
273                .clone(),
274            UDPIndex::SubV4(index) => {
275                let guard = self.sub_udp.read();
276                let len = guard.len();
277                if len <= index {
278                    return Err(io::Error::other("index out of bounds"));
279                } else {
280                    guard[index].clone()
281                }
282            }
283        })
284    }
285
286    #[inline]
287    fn get_udp_from_route(&self, route_key: &RouteKey) -> io::Result<Arc<UdpSocket>> {
288        Ok(match route_key.index() {
289            Index::Udp(index) => self.get_udp(index)?,
290            _ => return Err(io::Error::from(io::ErrorKind::InvalidInput)),
291        })
292    }
293}
294
295impl UdpSocketManager {
296    pub fn model(&self) -> Model {
297        if self.sub_udp.read().is_empty() {
298            Model::Low
299        } else {
300            Model::High
301        }
302    }
303
304    #[inline]
305    pub fn main_udp_v4_count(&self) -> usize {
306        self.main_udp_v4.len()
307    }
308    #[inline]
309    pub fn main_udp_v6_count(&self) -> usize {
310        self.main_udp_v6.len()
311    }
312
313    pub fn switch_model(&self, model: Model) -> io::Result<()> {
314        match model {
315            Model::High => self.switch_high(),
316            Model::Low => {
317                self.switch_low();
318                Ok(())
319            }
320        }
321    }
322    /// Acquire the local ports `UDP` sockets bind on
323    pub fn local_ports(&self) -> io::Result<Vec<u16>> {
324        let mut ports = Vec::with_capacity(self.main_udp_v4_count());
325        for udp in &self.main_udp_v4 {
326            ports.push(udp.local_addr()?.port());
327        }
328        Ok(ports)
329    }
330    /// Writing `buf` to the target denoted by `route_key`
331    pub async fn send_to<T, D: ToRouteKeyForUdp<T>>(&self, buf: &[u8], dest: D) -> io::Result<()> {
332        let route_key = ToRouteKeyForUdp::route_key(self, dest)?;
333        let len = self
334            .get_udp_from_route(&route_key)?
335            .send_to(buf, route_key.addr())
336            .await?;
337        if len == 0 {
338            return Err(std::io::Error::from(io::ErrorKind::WriteZero));
339        }
340        Ok(())
341    }
342
343    /// Try to write `buf` to the target denoted by `route_key`
344    pub fn try_send_to<T, D: ToRouteKeyForUdp<T>>(&self, buf: &[u8], dest: D) -> io::Result<()> {
345        let route_key = ToRouteKeyForUdp::route_key(self, dest)?;
346        let len = self
347            .get_udp_from_route(&route_key)?
348            .try_send_to(buf, route_key.addr())?;
349        if len == 0 {
350            return Err(std::io::Error::from(io::ErrorKind::WriteZero));
351        }
352        Ok(())
353    }
354
355    pub async fn batch_send_to<T, D: ToRouteKeyForUdp<T>>(
356        &self,
357        bufs: &[IoSlice<'_>],
358        dest: D,
359    ) -> io::Result<()> {
360        let route_key = ToRouteKeyForUdp::route_key(self, dest)?;
361        let udp = self.get_udp_from_route(&route_key)?;
362        for buf in bufs {
363            let len = udp.send_to(buf, route_key.addr()).await?;
364            if len == 0 {
365                return Err(std::io::Error::from(io::ErrorKind::WriteZero));
366            }
367        }
368
369        Ok(())
370    }
371    fn get_sender(&self, route_key: &RouteKey) -> io::Result<Sender<(Bytes, SocketAddr)>> {
372        if let Some(sender) = self.sender_map.get(&route_key.index()) {
373            Ok(sender.value().clone())
374        } else {
375            Err(io::Error::new(io::ErrorKind::NotFound, "route not found"))
376        }
377    }
378    pub async fn send_bytes_to<T, D: ToRouteKeyForUdp<T>>(
379        &self,
380        buf: Bytes,
381        dest: D,
382    ) -> io::Result<()> {
383        let route_key = ToRouteKeyForUdp::route_key(self, dest)?;
384        let sender = self.get_sender(&route_key)?;
385        if let Err(_e) = sender.send((buf, route_key.addr())).await {
386            Err(io::Error::from(io::ErrorKind::WriteZero))
387        } else {
388            Ok(())
389        }
390    }
391    pub fn try_send_bytes_to<T, D: ToRouteKeyForUdp<T>>(
392        &self,
393        buf: Bytes,
394        dest: D,
395    ) -> io::Result<()> {
396        let route_key = ToRouteKeyForUdp::route_key(self, dest)?;
397        let sender = self.get_sender(&route_key)?;
398        if let Err(e) = sender.try_send((buf, route_key.addr())) {
399            match e {
400                TrySendError::Full(_) => Err(io::Error::from(io::ErrorKind::WouldBlock)),
401                TrySendError::Closed(_) => Err(io::Error::from(io::ErrorKind::WriteZero)),
402            }
403        } else {
404            Ok(())
405        }
406    }
407
408    /// Send bytes to the target denoted by SocketAddr with every main underlying socket
409    pub async fn detect_pub_addrs<A: Into<SocketAddr>>(
410        &self,
411        buf: &[u8],
412        addr: A,
413    ) -> io::Result<()> {
414        let addr: SocketAddr = addr.into();
415        for index in 0..self.main_udp_v4_count() {
416            self.send_to(buf, (index, addr)).await?
417        }
418        Ok(())
419    }
420}
421
/// Hands out [`UdpTunnel`]s for the sockets owned by a [`UdpSocketManager`].
pub struct UdpTunnelDispatcher {
    /// Receiving end of the manager's `tunnel_dispatcher` queue of inactive tunnels.
    tunnel_receiver: Receiver<InactiveUdpTunnel>,
    /// Shared socket manager; also reachable through `manager()`.
    pub(crate) socket_manager: Arc<UdpSocketManager>,
}
426
427impl UdpTunnelDispatcher {
428    pub(crate) fn init(&self) -> io::Result<()> {
429        for (index, udp) in self.socket_manager.main_udp_v4.iter().enumerate() {
430            let udp = udp.clone();
431            let tunnel =
432                InactiveUdpTunnel::new(true, Index::Udp(UDPIndex::MainV4(index)), udp, None);
433            if self
434                .socket_manager
435                .tunnel_dispatcher
436                .try_send(tunnel)
437                .is_err()
438            {
439                Err(io::Error::other("tunnel channel error"))?
440            }
441        }
442        for (index, udp) in self.socket_manager.main_udp_v6.iter().enumerate() {
443            let udp = udp.clone();
444            let tunnel =
445                InactiveUdpTunnel::new(true, Index::Udp(UDPIndex::MainV6(index)), udp, None);
446            if self
447                .socket_manager
448                .tunnel_dispatcher
449                .try_send(tunnel)
450                .is_err()
451            {
452                Err(io::Error::other("tunnel channel error"))?
453            }
454        }
455        Ok(())
456    }
457}
458
459impl UdpTunnelDispatcher {
460    /// Construct a `UDP` tunnel with the specified configuration
461    pub fn new(config: UdpTunnelConfig) -> io::Result<UdpTunnelDispatcher> {
462        create_tunnel_dispatcher(config)
463    }
464    /// Dispatch `UDP` tunnel from this kind dispatcher
465    pub async fn dispatch(&mut self) -> io::Result<UdpTunnel> {
466        let mut udp_tunnel = self
467            .tunnel_receiver
468            .recv()
469            .await
470            .map_err(|_| io::Error::other("Udp tunnel close"))?;
471        let option = self
472            .socket_manager
473            .sender_map
474            .get(&udp_tunnel.index)
475            .map(|v| v.value().clone());
476        let sender = if let Some(v) = option {
477            v
478        } else {
479            let (s, mut r) = tachyonix::channel(128);
480            let index = udp_tunnel.index;
481            let sender = s.clone();
482            self.socket_manager.sender_map.insert(index, s);
483
484            let socket_manager = self.socket_manager.clone();
485            let udp = udp_tunnel.udp.clone();
486            tokio::spawn(async move {
487                #[cfg(all(feature = "sendmmsg", any(target_os = "linux", target_os = "android")))]
488                let mut vec_buf = Vec::with_capacity(16);
489
490                while let Ok((buf, addr)) = r.recv().await {
491                    #[cfg(all(
492                        feature = "sendmmsg",
493                        any(target_os = "linux", target_os = "android")
494                    ))]
495                    {
496                        vec_buf.push((buf, addr));
497                        while let Ok(tup) = r.try_recv() {
498                            vec_buf.push(tup);
499                            if vec_buf.len() == MAX_MESSAGES {
500                                break;
501                            }
502                        }
503                        let mut bufs = &mut vec_buf[..];
504                        let fd = udp.as_raw_fd();
505                        loop {
506                            if bufs.len() == 1 {
507                                let (buf, addr) = unsafe { bufs.get_unchecked(0) };
508                                if let Err(e) = udp.send_to(buf, *addr).await {
509                                    log::warn!("send_to {addr:?},{e:?}")
510                                }
511                                break;
512                            } else {
513                                let rs = write_with(&udp, || sendmmsg(fd, bufs)).await;
514                                match rs {
515                                    Ok(size) => {
516                                        if size == 0 {
517                                            break;
518                                        }
519                                        if size < bufs.len() {
520                                            bufs = &mut bufs[size..];
521                                            continue;
522                                        }
523                                        break;
524                                    }
525                                    Err(e) => {
526                                        log::warn!("sendmmsg {e:?}");
527                                    }
528                                }
529                            }
530                        }
531                        vec_buf.clear();
532                    }
533                    #[cfg(any(
534                        not(any(target_os = "linux", target_os = "android")),
535                        not(feature = "sendmmsg")
536                    ))]
537                    {
538                        let rs = udp.send_to(&buf, addr).await;
539                        if let Err(e) = rs {
540                            log::debug!("{addr:?},{e:?}")
541                        }
542                    }
543                }
544                socket_manager.sender_map.remove(&index);
545            });
546            sender
547        };
548        if udp_tunnel.sender.is_none() {
549            udp_tunnel.sender.replace(OwnedUdpTunnelSender { sender });
550        }
551        if udp_tunnel.reusable {
552            UdpTunnel::with_main(udp_tunnel, self.manager().tunnel_dispatcher.clone())
553        } else {
554            UdpTunnel::with_sub(udp_tunnel)
555        }
556    }
557    pub fn manager(&self) -> &Arc<UdpSocketManager> {
558        &self.socket_manager
559    }
560}
561
/// Send every `(payload, destination)` pair in `bufs` with a single non-blocking
/// `sendmmsg(2)` call on `fd`.
///
/// Returns how many messages, counted from the front of `bufs`, the kernel accepted. This may
/// be fewer than `bufs.len()`.
///
/// # Errors
/// The OS error when the call fails. The call is made with `MSG_DONTWAIT`, so an unwritable
/// socket reports `WouldBlock`.
///
/// # Panics
/// If `bufs` holds more than `MAX_MESSAGES` entries.
#[cfg(all(feature = "sendmmsg", any(target_os = "linux", target_os = "android")))]
fn sendmmsg(fd: std::os::fd::RawFd, bufs: &mut [(Bytes, SocketAddr)]) -> io::Result<usize> {
    assert!(bufs.len() <= MAX_MESSAGES);
    // SAFETY: `iovec`, `mmsghdr` and `sockaddr_storage` are plain C structs for which the
    // all-zero bit pattern is a valid value.
    let mut iov: [iovec; MAX_MESSAGES] = unsafe { std::mem::zeroed() };
    let mut msgs: [mmsghdr; MAX_MESSAGES] = unsafe { std::mem::zeroed() };
    let mut addrs: [sockaddr_storage; MAX_MESSAGES] = unsafe { std::mem::zeroed() };

    // Fill the payload and address arrays first. Raw pointers into them are taken afterwards,
    // so no later write through the arrays can invalidate a pointer stored in `msgs`.
    for (i, (buf, addr)) in bufs.iter().enumerate() {
        addrs[i] = socket_addr_to_sockaddr(addr);
        // The kernel only reads send buffers; `iovec` just happens to use a `*mut` field.
        // `Bytes` is not `DerefMut`, so the pointer is taken with `as_ptr`.
        iov[i].iov_base = buf.as_ptr() as *mut libc::c_void;
        iov[i].iov_len = buf.len();
    }
    let iov_ptr = iov.as_mut_ptr();
    let addrs_ptr = addrs.as_mut_ptr();
    for (i, (_, addr)) in bufs.iter().enumerate() {
        let hdr = &mut msgs[i].msg_hdr;
        // `i < bufs.len() <= MAX_MESSAGES`, so both pointers stay inside their arrays.
        hdr.msg_iov = iov_ptr.wrapping_add(i);
        hdr.msg_iovlen = 1;
        hdr.msg_name = addrs_ptr.wrapping_add(i) as *mut libc::c_void;
        // Pass the exact address length for the family rather than the whole storage size.
        hdr.msg_namelen = match addr {
            SocketAddr::V4(_) => std::mem::size_of::<libc::sockaddr_in>(),
            SocketAddr::V6(_) => std::mem::size_of::<libc::sockaddr_in6>(),
        } as socklen_t;
    }

    // SAFETY: the first `bufs.len()` headers point into `iov` and `addrs`, which outlive the
    // call. Each iovec covers a `Bytes` payload borrowed from `bufs` for the same duration.
    let res = unsafe {
        libc::sendmmsg(
            fd,
            msgs.as_mut_ptr(),
            bufs.len() as _,
            libc::MSG_DONTWAIT as _,
        )
    };
    if res == -1 {
        return Err(io::Error::last_os_error());
    }
    Ok(res as usize)
}
592
/// Encode `addr` as a C `sockaddr_storage` holding a `sockaddr_in` (IPv4) or a `sockaddr_in6`
/// (IPv6), with the port and address in network byte order. The remaining bytes are zero.
#[cfg(all(feature = "sendmmsg", any(target_os = "linux", target_os = "android")))]
fn socket_addr_to_sockaddr(addr: &SocketAddr) -> sockaddr_storage {
    // SAFETY: all-zero bytes are a valid `sockaddr_storage`.
    let mut storage: sockaddr_storage = unsafe { std::mem::zeroed() };
    let storage_ptr = &mut storage as *mut sockaddr_storage as *mut u8;

    match addr {
        SocketAddr::V4(v4_addr) => {
            let sin = libc::sockaddr_in {
                sin_family: libc::AF_INET as _,
                sin_port: v4_addr.port().to_be(),
                sin_addr: libc::in_addr {
                    // `octets()` is already in network order; keep that byte layout as-is.
                    s_addr: u32::from_ne_bytes(v4_addr.ip().octets()),
                },
                sin_zero: [0; 8],
            };
            // SAFETY: `sockaddr_storage` is at least as large as `sockaddr_in`, the two values
            // do not overlap, and a byte-wise copy has no alignment requirement.
            unsafe {
                std::ptr::copy_nonoverlapping(
                    &sin as *const libc::sockaddr_in as *const u8,
                    storage_ptr,
                    std::mem::size_of::<libc::sockaddr_in>(),
                );
            }
        }
        SocketAddr::V6(v6_addr) => {
            let sin6 = libc::sockaddr_in6 {
                sin6_family: libc::AF_INET6 as _,
                sin6_port: v6_addr.port().to_be(),
                sin6_flowinfo: v6_addr.flowinfo(),
                sin6_addr: libc::in6_addr {
                    s6_addr: v6_addr.ip().octets(),
                },
                sin6_scope_id: v6_addr.scope_id(),
            };
            // Copy the whole `sockaddr_in6` (28 bytes). Copying only `size_of::<sockaddr>()`
            // (16 bytes) would cut off most of the IPv6 address and the scope id.
            // SAFETY: `sockaddr_storage` is at least as large as `sockaddr_in6`, the two values
            // do not overlap, and a byte-wise copy has no alignment requirement.
            unsafe {
                std::ptr::copy_nonoverlapping(
                    &sin6 as *const libc::sockaddr_in6 as *const u8,
                    storage_ptr,
                    std::mem::size_of::<libc::sockaddr_in6>(),
                );
            }
        }
    }
    storage
}
642
/// Receive/send handle for one UDP socket, produced by `UdpTunnelDispatcher::dispatch`.
///
/// - Main-socket tunnels hold a re-dispatch handle and are re-queued when dropped (see the
///   `Drop` impl), so the socket can be dispatched again.
/// - Sub-socket tunnels hold a close notification instead.
pub struct UdpTunnel {
    /// Socket index stamped on every `RouteKey` this tunnel receives.
    index: Index,
    /// Local address of the socket, read when the tunnel was created.
    local_addr: SocketAddr,
    /// The socket; `None` after `done` or once dropped.
    udp: Option<Arc<UdpSocket>>,
    /// Close signal for sub-socket tunnels; `None` for main sockets.
    close_notify: Option<async_broadcast::Receiver<()>>,
    /// Queue used to re-queue a main-socket tunnel on drop; `None` for sub sockets.
    re_dispatcher: Option<Sender<InactiveUdpTunnel>>,
    /// Owned handle to the socket's outbound byte queue; dropping it closes the queue.
    sender: Option<OwnedUdpTunnelSender>,
}
/// Owning handle to a socket's outbound byte queue. Its `Drop` impl closes the queue, which
/// ends the sender task spawned by `dispatch` for that queue.
struct OwnedUdpTunnelSender {
    sender: Sender<(Bytes, SocketAddr)>,
}
/// Cloneable send handle to a tunnel's outbound byte queue, obtained from `UdpTunnel::sender`.
///
/// Unlike `OwnedUdpTunnelSender`, it has no `Drop` impl here, so dropping it does not close
/// the queue.
#[derive(Clone)]
pub struct WeakUdpTunnelSender {
    sender: Sender<(Bytes, SocketAddr)>,
}
/// A socket waiting in the dispatcher queue to be turned into a [`UdpTunnel`] by
/// `UdpTunnelDispatcher::dispatch`.
struct InactiveUdpTunnel {
    /// `true` for main sockets (re-queued when their tunnel is dropped), `false` for sub
    /// sockets.
    reusable: bool,
    /// Which socket this is.
    index: Index,
    udp: Arc<UdpSocket>,
    /// Close signal for sub-socket tunnels; `None` for main sockets.
    close_notify: Option<async_broadcast::Receiver<()>>,
    /// Outbound queue carried over when a main tunnel is re-queued; `None` when freshly
    /// created.
    sender: Option<OwnedUdpTunnelSender>,
}
665impl InactiveUdpTunnel {
666    fn new(
667        reusable: bool,
668        index: Index,
669        udp: Arc<UdpSocket>,
670        close_notify: Option<async_broadcast::Receiver<()>>,
671    ) -> Self {
672        Self {
673            reusable,
674            index,
675            udp,
676            close_notify,
677            sender: None,
678        }
679    }
680    fn redistribute(index: Index, udp: Arc<UdpSocket>, sender: OwnedUdpTunnelSender) -> Self {
681        Self {
682            reusable: true,
683            index,
684            udp,
685            close_notify: None,
686            sender: Some(sender),
687        }
688    }
689}
690impl OwnedUdpTunnelSender {
691    async fn send_to<A: Into<SocketAddr>>(&self, buf: Bytes, dest: A) -> io::Result<()> {
692        if buf.is_empty() {
693            return Ok(());
694        }
695        self.sender
696            .send((buf, dest.into()))
697            .await
698            .map_err(|_| io::Error::from(io::ErrorKind::WriteZero))
699    }
700    fn is_closed(&self) -> bool {
701        self.sender.is_closed()
702    }
703}
704impl WeakUdpTunnelSender {
705    pub async fn send_to<A: Into<SocketAddr>>(&self, buf: Bytes, dest: A) -> io::Result<()> {
706        if buf.is_empty() {
707            return Ok(());
708        }
709        self.sender
710            .send((buf, dest.into()))
711            .await
712            .map_err(|_| io::Error::from(io::ErrorKind::WriteZero))
713    }
714    pub fn try_send_to<A: Into<SocketAddr>>(&self, buf: Bytes, dest: A) -> io::Result<()> {
715        if buf.is_empty() {
716            return Ok(());
717        }
718        self.sender
719            .try_send((buf, dest.into()))
720            .map_err(|e| match e {
721                TrySendError::Full(_) => io::Error::from(io::ErrorKind::WouldBlock),
722                TrySendError::Closed(_) => io::Error::from(io::ErrorKind::WriteZero),
723            })
724    }
725}
/// Closing the queue on drop lets the sender task spawned in `dispatch` leave its receive loop
/// and unregister the queue from `sender_map`.
impl Drop for OwnedUdpTunnelSender {
    fn drop(&mut self) {
        self.sender.close();
    }
}
/// Re-queue a still-live main-socket tunnel so the dispatcher can hand the socket out again.
///
/// Nothing is re-queued in these cases:
/// - The tunnel was released with `done`.
/// - Its outbound queue is already closed.
/// - It has no re-dispatch handle (sub-socket tunnels).
///
/// If the dispatcher queue is full, the tunnel is dropped and a warning is logged.
impl Drop for UdpTunnel {
    fn drop(&mut self) {
        // The sender is taken first: every early return below drops it, which closes the
        // socket's outbound queue.
        let Some(sender) = self.sender.take() else {
            return;
        };
        if sender.is_closed() {
            return;
        }
        let Some(udp) = self.udp.take() else {
            return;
        };
        let Some(re_dispatcher) = self.re_dispatcher.take() else {
            return;
        };
        // On failure the rejected entry (and with it the sender) is dropped; only a full
        // queue is worth a warning.
        let rs = re_dispatcher.try_send(InactiveUdpTunnel::redistribute(self.index, udp, sender));
        if let Err(TrySendError::Full(_)) = rs {
            log::warn!("Udp Tunnel TrySendError full");
        }
    }
}
751
752impl UdpTunnel {
753    fn with_sub(inactive_udp_tunnel: InactiveUdpTunnel) -> io::Result<Self> {
754        let local_addr = inactive_udp_tunnel.udp.local_addr()?;
755        Ok(Self {
756            index: inactive_udp_tunnel.index,
757            local_addr,
758            udp: Some(inactive_udp_tunnel.udp),
759            close_notify: inactive_udp_tunnel.close_notify,
760            re_dispatcher: None,
761            sender: inactive_udp_tunnel.sender,
762        })
763    }
764    fn with_main(
765        inactive_udp_tunnel: InactiveUdpTunnel,
766        re_sender: Sender<InactiveUdpTunnel>,
767    ) -> io::Result<Self> {
768        let local_addr = inactive_udp_tunnel.udp.local_addr()?;
769        Ok(Self {
770            local_addr,
771            index: inactive_udp_tunnel.index,
772            udp: Some(inactive_udp_tunnel.udp),
773            close_notify: None,
774            re_dispatcher: Some(re_sender),
775            sender: inactive_udp_tunnel.sender,
776        })
777    }
778    pub fn done(&mut self) {
779        _ = self.udp.take();
780        _ = self.close_notify.take();
781        _ = self.re_dispatcher.take();
782        _ = self.re_dispatcher.take();
783        _ = self.sender.take();
784    }
785    pub fn local_addr(&self) -> SocketAddr {
786        self.local_addr
787    }
788    pub fn sender(&self) -> io::Result<WeakUdpTunnelSender> {
789        if let Some(v) = &self.sender {
790            Ok(WeakUdpTunnelSender {
791                sender: v.sender.clone(),
792            })
793        } else {
794            Err(io::Error::other("closed"))
795        }
796    }
797}
798
799impl UdpTunnel {
800    /// Writing `buf` to the target denoted by SocketAddr via this tunnel
801    pub async fn send_to<A: Into<SocketAddr>>(&self, buf: &[u8], addr: A) -> io::Result<()> {
802        if let Some(udp) = &self.udp {
803            udp.send_to(buf, addr.into()).await?;
804            Ok(())
805        } else {
806            Err(io::Error::other("closed"))
807        }
808    }
809    /// Try to write `buf` to the target denoted by SocketAddr via this tunnel
810    pub fn try_send_to<A: Into<SocketAddr>>(&self, buf: &[u8], addr: A) -> io::Result<()> {
811        if let Some(udp) = &self.udp {
812            udp.try_send_to(buf, addr.into())?;
813            Ok(())
814        } else {
815            Err(io::Error::other("closed"))
816        }
817    }
818    pub async fn send_bytes_to<A: Into<SocketAddr>>(&self, buf: Bytes, addr: A) -> io::Result<()> {
819        if let Some(sender) = &self.sender {
820            sender.send_to(buf, addr).await
821        } else {
822            Err(io::Error::other("closed"))
823        }
824    }
825
    /// Receive one datagram from this tunnel into `buf`.
    ///
    /// Return values:
    /// - `None` once the tunnel is closed. This happens when it no longer holds a socket
    ///   (e.g. after `done`), or when a tunnel with a close notification (sub sockets) sees
    ///   that notification fire; `done` is called first in that case.
    /// - `Some(Ok((len, route_key)))` otherwise, where `len` is the number of bytes received
    ///   and `route_key` pairs the source address with this tunnel's socket index.
    /// - `Some(Err(_))` for receive errors. Errors for which `should_ignore_error` returns
    ///   `true` are skipped and the receive is retried.
    pub async fn recv_from(&mut self, buf: &mut [u8]) -> Option<io::Result<(usize, RouteKey)>> {
        let udp = if let Some(udp) = &self.udp {
            udp
        } else {
            return None;
        };
        loop {
            if let Some(close_notify) = &mut self.close_notify {
                // Race the receive against the close signal. Any completion of
                // `close_notify.recv()`, whatever its result, ends the tunnel.
                tokio::select! {
                    _rs=close_notify.recv()=>{
                         self.done();
                         return None
                    }
                    result=udp.recv_from(buf)=>{
                         let (len, addr) = match result {
                            Ok(rs) => rs,
                            Err(e) => {
                                if should_ignore_error(&e) {
                                    continue;
                                }
                                return Some(Err(e))
                            }
                         };
                         return Some(Ok((len, RouteKey::new(self.index, addr))))
                    }
                }
            } else {
                // No close signal (main-socket tunnels): plain receive.
                let (len, addr) = match udp.recv_from(buf).await {
                    Ok(rs) => rs,
                    Err(e) => {
                        if should_ignore_error(&e) {
                            continue;
                        }
                        return Some(Err(e));
                    }
                };
                return Some(Ok((len, RouteKey::new(self.index, addr))));
            }
        }
    }
869    #[cfg(not(any(target_os = "linux", target_os = "android")))]
870    pub async fn batch_recv_from<B: AsMut<[u8]>>(
871        &mut self,
872        bufs: &mut [B],
873        sizes: &mut [usize],
874        addrs: &mut [RouteKey],
875    ) -> Option<io::Result<usize>> {
876        if bufs.is_empty() || bufs.len() != sizes.len() || bufs.len() != addrs.len() {
877            return Some(Err(io::Error::other("bufs error")));
878        }
879        let rs = self.recv_from(bufs[0].as_mut()).await?;
880        match rs {
881            Ok((len, addr)) => {
882                let udp = self.udp.as_ref()?;
883                sizes[0] = len;
884                addrs[0] = addr;
885                let mut num = 1;
886                while num < bufs.len() {
887                    match udp.try_recv_from(bufs[num].as_mut()) {
888                        Ok((len, addr)) => {
889                            sizes[num] = len;
890                            addrs[num] = RouteKey::new(self.index, addr);
891                            num += 1;
892                        }
893                        Err(_) => break,
894                    }
895                }
896                Some(Ok(num))
897            }
898            Err(e) => Some(Err(e)),
899        }
900    }
901    #[cfg(any(target_os = "linux", target_os = "android"))]
902    pub async fn batch_recv_from<B: AsMut<[u8]>>(
903        &mut self,
904        bufs: &mut [B],
905        sizes: &mut [usize],
906        addrs: &mut [RouteKey],
907    ) -> Option<io::Result<usize>> {
908        if bufs.is_empty() || bufs.len() != sizes.len() || bufs.len() != addrs.len() {
909            return Some(Err(io::Error::other("bufs/sizes/addrs error")));
910        }
911        let udp = self.udp.as_ref()?;
912        let fd = udp.as_raw_fd();
913        loop {
914            let rs = if let Some(close_notify) = &mut self.close_notify {
915                tokio::select! {
916                    _rs=close_notify.recv()=>{
917                        self.done();
918                        return None
919                    }
920                    rs=read_with(udp,|| recvmmsg(self.index, fd, bufs, sizes, addrs))=>{
921                        rs
922                    }
923                }
924            } else {
925                read_with(udp, || recvmmsg(self.index, fd, bufs, sizes, addrs)).await
926            };
927            return match rs {
928                Ok(size) => Some(Ok(size)),
929                Err(e) => {
930                    if should_ignore_error(&e) {
931                        continue;
932                    }
933                    Some(Err(e))
934                }
935            };
936        }
937    }
938}
939
940#[cfg(any(target_os = "linux", target_os = "android"))]
941fn recvmmsg<B: AsMut<[u8]>>(
942    index: Index,
943    fd: std::os::fd::RawFd,
944    bufs: &mut [B],
945    sizes: &mut [usize],
946    route_keys: &mut [RouteKey],
947) -> io::Result<usize> {
948    let mut iov: [iovec; MAX_MESSAGES] = unsafe { std::mem::zeroed() };
949    let mut msgs: [mmsghdr; MAX_MESSAGES] = unsafe { std::mem::zeroed() };
950    let mut addrs: [sockaddr_storage; MAX_MESSAGES] = unsafe { std::mem::zeroed() };
951    let max_num = bufs.len().min(MAX_MESSAGES);
952    for i in 0..max_num {
953        iov[i].iov_base = bufs[i].as_mut().as_mut_ptr() as *mut libc::c_void;
954        iov[i].iov_len = bufs[i].as_mut().len();
955        msgs[i].msg_hdr.msg_iov = &mut iov[i];
956        msgs[i].msg_hdr.msg_iovlen = 1;
957        msgs[i].msg_hdr.msg_name = &mut addrs[i] as *const _ as *mut libc::c_void;
958        msgs[i].msg_hdr.msg_namelen = std::mem::size_of::<sockaddr_storage>() as socklen_t;
959    }
960    let res = unsafe {
961        libc::recvmmsg(
962            fd,
963            msgs.as_mut_ptr(),
964            max_num as c_uint,
965            libc::MSG_DONTWAIT as _,
966            std::ptr::null_mut(),
967        )
968    };
969    if res == -1 {
970        return Err(io::Error::last_os_error());
971    }
972    let nmsgs = res as usize;
973    if nmsgs == 0 {
974        return Err(io::Error::from(io::ErrorKind::UnexpectedEof));
975    }
976    for i in 0..nmsgs {
977        let addr = sockaddr_to_socket_addr(&addrs[i], msgs[i].msg_hdr.msg_namelen);
978        sizes[i] = msgs[i].msg_len as usize;
979        route_keys[i] = RouteKey::new(index, addr);
980    }
981    Ok(nmsgs)
982}
983#[cfg(any(target_os = "linux", target_os = "android"))]
984fn sockaddr_to_socket_addr(addr: &sockaddr_storage, _len: socklen_t) -> SocketAddr {
985    match addr.ss_family as libc::c_int {
986        libc::AF_INET => {
987            let addr_in = unsafe { *(addr as *const _ as *const libc::sockaddr_in) };
988            let ip = u32::from_be(addr_in.sin_addr.s_addr);
989            let port = u16::from_be(addr_in.sin_port);
990            SocketAddr::V4(std::net::SocketAddrV4::new(
991                std::net::Ipv4Addr::from(ip),
992                port,
993            ))
994        }
995        libc::AF_INET6 => {
996            let addr_in6 = unsafe { *(addr as *const _ as *const libc::sockaddr_in6) };
997            let ip = std::net::Ipv6Addr::from(addr_in6.sin6_addr.s6_addr);
998            let port = u16::from_be(addr_in6.sin6_port);
999            SocketAddr::V6(std::net::SocketAddrV6::new(ip, port, 0, 0))
1000        }
1001        _ => panic!("Unsupported address family"),
1002    }
1003}
1004
/// Reports whether a receive error should be skipped (the caller retries) rather
/// than surfaced.
///
/// Only Windows has an ignorable error: `WSAECONNRESET`. On every other platform
/// this always returns `false`.
fn should_ignore_error(e: &io::Error) -> bool {
    #[cfg(windows)]
    {
        // Check whether the OS error code is WSAECONNRESET.
        if e.raw_os_error() == Some(windows_sys::Win32::Networking::WinSock::WSAECONNRESET) {
            return true;
        }
    }
    // Keeps `e` "used" on platforms where the block above is compiled out.
    let _ = e;
    false
}
1016
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use crate::tunnel::udp::{Model, UdpTunnel};

    /// Waits on a single `recv_from` and reports whether it yielded `None`,
    /// i.e. whether the tunnel was closed rather than delivering data or an error.
    async fn tunnel_recv(mut tunnel: UdpTunnel) -> bool {
        let mut buf = [0; 1400];
        tunnel.recv_from(&mut buf).await.is_none()
    }

    #[tokio::test]
    pub async fn create_udp_tunnel() {
        let config = crate::tunnel::config::UdpTunnelConfig::default()
            .set_main_udp_count(2)
            .set_sub_udp_count(10)
            .set_model(Model::Low)
            .set_use_v6(false);
        let mut dispatcher = crate::tunnel::udp::create_tunnel_dispatcher(config).unwrap();
        let mut handles = Vec::new();
        // Keep dispatching until no new tunnel shows up within one second.
        while let Ok(tunnel) =
            tokio::time::timeout(Duration::from_secs(1), dispatcher.dispatch()).await
        {
            handles.push(tokio::spawn(tunnel_recv(tunnel.unwrap())));
        }
        // In the Low model the test expects only the two main tunnels.
        assert_eq!(handles.len(), 2)
    }

    #[tokio::test]
    pub async fn create_sub_udp_tunnel() {
        let config = crate::tunnel::config::UdpTunnelConfig::default()
            .set_main_udp_count(2)
            .set_sub_udp_count(10)
            .set_use_v6(false)
            .set_model(Model::High);
        let mut dispatcher = crate::tunnel::udp::create_tunnel_dispatcher(config).unwrap();
        let mut handles = Vec::new();
        while let Ok(tunnel) =
            tokio::time::timeout(Duration::from_secs(1), dispatcher.dispatch()).await
        {
            handles.push(tokio::spawn(tunnel_recv(tunnel.unwrap())));
        }
        let count = handles.len();
        dispatcher.manager().switch_low();

        // A task still running after one second is counted as "not closed".
        let mut close_tunnel_count = 0;
        for handle in handles {
            if let Ok(joined) = tokio::time::timeout(Duration::from_secs(1), handle).await {
                if joined.unwrap() {
                    close_tunnel_count += 1;
                }
            }
        }
        // 2 main + 10 sub tunnels in High; switching to Low should close the 10 sub ones.
        assert_eq!(count, 12);
        assert_eq!(close_tunnel_count, 10);
    }
}