1use std::io;
2use std::io::IoSlice;
3use std::net::SocketAddr;
4use std::sync::Arc;
5
6#[cfg(any(target_os = "linux", target_os = "android"))]
7use tokio::io::Interest;
8
/// Runs the operation `op` against `udp` via [`UdpSocket::async_io`] with
/// [`Interest::READABLE`] and returns whatever `op` finally yields.
///
/// Only compiled on Linux and Android.
#[cfg(any(target_os = "linux", target_os = "android"))]
pub async fn read_with<R>(udp: &UdpSocket, op: impl FnMut() -> io::Result<R>) -> io::Result<R> {
    udp.async_io(Interest::READABLE, op).await
}
/// Runs the operation `op` against `udp` via [`UdpSocket::async_io`] with
/// [`Interest::WRITABLE`] and returns whatever `op` finally yields.
///
/// Only compiled on Linux and Android.
#[cfg(any(target_os = "linux", target_os = "android"))]
pub async fn write_with<R>(udp: &UdpSocket, op: impl FnMut() -> io::Result<R>) -> io::Result<R> {
    udp.async_io(Interest::WRITABLE, op).await
}
17
18use bytes::BytesMut;
19use dashmap::DashMap;
20use parking_lot::{Mutex, RwLock};
21use tachyonix::{Receiver, Sender, TrySendError};
22use tokio::net::UdpSocket;
23
24use crate::route::{Index, RouteKey};
25use crate::socket::{bind_udp, LocalInterface};
26use crate::tunnel::config::UdpTunnelConfig;
27use crate::tunnel::recycle::RecycleBuf;
28use crate::tunnel::{DEFAULT_ADDRESS_V4, DEFAULT_ADDRESS_V6};
29
/// Length of the fixed per-call arrays built by `sendmmsg` and `recvmmsg` in
/// this file, and therefore the largest batch a single call can carry.
#[cfg(any(target_os = "linux", target_os = "android"))]
const MAX_MESSAGES: usize = 16;
32#[cfg(any(target_os = "linux", target_os = "android"))]
33use libc::{c_uint, iovec, mmsghdr, sockaddr_storage, socklen_t};
34#[cfg(any(target_os = "linux", target_os = "android"))]
35use std::os::fd::AsRawFd;
36
/// Operating mode selector for the UDP socket manager.
///
/// [`Model::Low`] is the default variant.
#[derive(Debug, PartialEq, Eq, Clone, Copy, Default)]
pub enum Model {
    High,
    #[default]
    Low,
}

impl Model {
    /// Returns `true` when the value is [`Model::Low`].
    pub fn is_low(&self) -> bool {
        matches!(self, Model::Low)
    }

    /// Returns `true` when the value is [`Model::High`].
    pub fn is_high(&self) -> bool {
        matches!(self, Model::High)
    }
}
52
/// Identifies one UDP socket by the group it belongs to (main IPv4, main IPv6
/// or sub IPv4) and its position inside that group.
#[derive(Copy, Clone, Ord, PartialOrd, Eq, PartialEq, Hash, Debug)]
pub enum UDPIndex {
    MainV4(usize),
    MainV6(usize),
    SubV4(usize),
}

impl UDPIndex {
    /// Returns the position carried by the variant, regardless of which group
    /// it refers to.
    pub(crate) fn index(&self) -> usize {
        match *self {
            UDPIndex::MainV4(slot) | UDPIndex::MainV6(slot) | UDPIndex::SubV4(slot) => slot,
        }
    }
}
69
/// Conversion of a send destination into a [`RouteKey`].
///
/// The type parameter `T` does not appear in the method signature; the impls
/// in this file instantiate it with `()` or `usize`.
pub trait ToRouteKeyForUdp<T> {
    /// Produces the [`RouteKey`] for `dest`, consulting `socket_manager` where
    /// the implementation needs it.
    fn route_key(socket_manager: &UdpSocketManager, dest: Self) -> io::Result<RouteKey>;
}
73
/// An owned [`RouteKey`] is already a route key and is returned unchanged.
impl ToRouteKeyForUdp<()> for RouteKey {
    fn route_key(_: &UdpSocketManager, dest: Self) -> io::Result<RouteKey> {
        Ok(dest)
    }
}
79
/// A shared reference to a [`RouteKey`] is dereferenced and copied out.
impl ToRouteKeyForUdp<()> for &RouteKey {
    fn route_key(_: &UdpSocketManager, dest: Self) -> io::Result<RouteKey> {
        Ok(*dest)
    }
}
85
/// A mutable reference to a [`RouteKey`] is dereferenced and copied out; the
/// referenced value is not modified.
impl ToRouteKeyForUdp<()> for &mut RouteKey {
    fn route_key(_: &UdpSocketManager, dest: Self) -> io::Result<RouteKey> {
        Ok(*dest)
    }
}
91
92impl<S: Into<SocketAddr>> ToRouteKeyForUdp<()> for S {
93 fn route_key(socket_manager: &UdpSocketManager, dest: Self) -> io::Result<RouteKey> {
94 let addr = dest.into();
95 socket_manager.generate_route_key_from_addr(0, addr)
96 }
97}
98
99impl<S: Into<SocketAddr>> ToRouteKeyForUdp<usize> for (usize, S) {
100 fn route_key(socket_manager: &UdpSocketManager, dest: Self) -> io::Result<RouteKey> {
101 let (index, addr) = dest;
102 socket_manager.generate_route_key_from_addr(index, addr.into())
103 }
104}
105
106pub(crate) fn create_tunnel_dispatcher(config: UdpTunnelConfig) -> io::Result<UdpTunnelDispatcher> {
108 config.check()?;
109 let mut udp_ports = config.udp_ports;
110 udp_ports.resize(config.main_udp_count, 0);
111 let mut main_udp_v4: Vec<Arc<UdpSocket>> = Vec::with_capacity(config.main_udp_count);
112 let mut main_udp_v6: Vec<Arc<UdpSocket>> = Vec::with_capacity(config.main_udp_count);
113 for port in &udp_ports {
115 loop {
116 let mut addr_v4 = DEFAULT_ADDRESS_V4;
117 addr_v4.set_port(*port);
118 let socket_v4 = bind_udp(addr_v4, config.default_interface.as_ref())?;
119 let udp_v4: std::net::UdpSocket = socket_v4.into();
120 if config.use_v6 {
121 let mut addr_v6 = DEFAULT_ADDRESS_V6;
122 let socket_v6 = if *port == 0 {
123 let port = udp_v4.local_addr()?.port();
124 addr_v6.set_port(port);
125 match bind_udp(addr_v6, config.default_interface.as_ref()) {
126 Ok(socket_v6) => socket_v6,
127 Err(_) => continue,
128 }
129 } else {
130 addr_v6.set_port(*port);
131 bind_udp(addr_v6, config.default_interface.as_ref())?
132 };
133 let udp_v6: std::net::UdpSocket = socket_v6.into();
134 main_udp_v6.push(Arc::new(UdpSocket::from_std(udp_v6)?))
135 }
136 main_udp_v4.push(Arc::new(UdpSocket::from_std(udp_v4)?));
137 break;
138 }
139 }
140 let (tunnel_sender, tunnel_receiver) =
141 tachyonix::channel(config.main_udp_count * 2 + config.sub_udp_count * 2);
142 let socket_manager = Arc::new(UdpSocketManager {
143 main_udp_v4,
144 main_udp_v6,
145 sub_udp: RwLock::new(Vec::with_capacity(config.sub_udp_count)),
146 sub_close_notify: Default::default(),
147 tunnel_dispatcher: tunnel_sender,
148 sub_udp_num: config.sub_udp_count,
149 default_interface: config.default_interface,
150 sender_map: Default::default(),
151 });
152 let tunnel_factory = UdpTunnelDispatcher {
153 tunnel_receiver,
154 socket_manager,
155 recycle_buf: config.recycle_buf,
156 };
157 tunnel_factory.init()?;
158 tunnel_factory.socket_manager.switch_model(config.model)?;
159 Ok(tunnel_factory)
160}
161
/// Shared state for all UDP sockets of a tunnel dispatcher.
pub struct UdpSocketManager {
    /// Main IPv4 sockets, addressed by `UDPIndex::MainV4(i)`.
    main_udp_v4: Vec<Arc<UdpSocket>>,
    /// Main IPv6 sockets, addressed by `UDPIndex::MainV6(i)`; empty when IPv6
    /// was not enabled at construction.
    main_udp_v6: Vec<Arc<UdpSocket>>,
    /// Sub sockets, addressed by `UDPIndex::SubV4(i)`; filled by
    /// `switch_high` and emptied by `switch_low`.
    sub_udp: RwLock<Vec<Arc<UdpSocket>>>,
    /// Broadcast sender whose closing tells the current sub tunnels to stop.
    sub_close_notify: Mutex<Option<async_broadcast::Sender<()>>>,
    /// Queue through which inactive tunnels are handed to the dispatcher.
    tunnel_dispatcher: Sender<InactiveUdpTunnel>,
    /// Number of sub sockets to bind when switching to the high model.
    sub_udp_num: usize,
    /// Interface passed to `bind_udp` when binding sub sockets.
    default_interface: Option<LocalInterface>,
    /// Per-socket outgoing queues used by the `*_bytes_to` send methods.
    sender_map: DashMap<Index, Sender<(BytesMut, SocketAddr)>>,
}
172
173impl UdpSocketManager {
174 pub(crate) fn try_sub_batch_send_to(&self, buf: &[u8], addr: SocketAddr) {
175 for (i, udp) in self.sub_udp.read().iter().enumerate() {
176 if let Err(e) = udp.try_send_to(buf, addr) {
177 log::info!("try_sub_send_to_addr_v4: {e:?},{i},{addr}")
178 }
179 }
180 }
181 pub(crate) fn try_main_v4_batch_send_to(&self, buf: &[u8], addr: &[SocketAddr]) {
182 let len = self.main_udp_v4_count();
183 self.try_main_batch_send_to_impl(buf, addr, len);
184 }
185 pub(crate) fn try_main_v6_batch_send_to(&self, buf: &[u8], addr: &[SocketAddr]) {
186 let len = self.main_udp_v6_count();
187 self.try_main_batch_send_to_impl(buf, addr, len);
188 }
189
190 pub(crate) fn try_main_batch_send_to_impl(&self, buf: &[u8], addr: &[SocketAddr], len: usize) {
191 for (i, addr) in addr.iter().enumerate() {
192 if let Err(e) = self.try_send_to(buf, (i % len, *addr)) {
193 log::info!("try_main_send_to_addr: {e:?},{},{addr}", i % len);
194 }
195 }
196 }
197 pub(crate) fn generate_route_key_from_addr(
198 &self,
199 index: usize,
200 addr: SocketAddr,
201 ) -> io::Result<RouteKey> {
202 let route_key = if addr.is_ipv4() {
203 let len = self.main_udp_v4.len();
204 if index >= len {
205 return Err(io::Error::other("index out of bounds"));
206 }
207 RouteKey::new(Index::Udp(UDPIndex::MainV4(index)), addr)
208 } else {
209 let len = self.main_udp_v6.len();
210 if len == 0 {
211 return Err(io::Error::other("Not support IPV6"));
212 }
213 if index >= len {
214 return Err(io::Error::other("index out of bounds"));
215 }
216 RouteKey::new(Index::Udp(UDPIndex::MainV6(index)), addr)
217 };
218 Ok(route_key)
219 }
220 pub(crate) fn switch_low(&self) {
221 let mut guard = self.sub_udp.write();
222 if guard.is_empty() {
223 return;
224 }
225 guard.clear();
226 if let Some(sub_close_notify) = self.sub_close_notify.lock().take() {
227 let _ = sub_close_notify.close();
228 }
229 }
230 pub(crate) fn switch_high(&self) -> io::Result<()> {
231 let mut guard = self.sub_udp.write();
232 if !guard.is_empty() {
233 return Ok(());
234 }
235 let mut sub_close_notify_guard = self.sub_close_notify.lock();
236 if let Some(sender) = sub_close_notify_guard.take() {
237 let _ = sender.close();
238 }
239 let (sub_close_notify_sender, sub_close_notify_receiver) = async_broadcast::broadcast(2);
240 let mut sub_udp_list = Vec::with_capacity(self.sub_udp_num);
241 for _ in 0..self.sub_udp_num {
242 let udp = bind_udp(DEFAULT_ADDRESS_V4, self.default_interface.as_ref())?;
243 let udp: std::net::UdpSocket = udp.into();
244 sub_udp_list.push(Arc::new(UdpSocket::from_std(udp)?));
245 }
246 for (index, udp) in sub_udp_list.iter().enumerate() {
247 let udp = udp.clone();
248 let udp_tunnel = InactiveUdpTunnel::new(
249 false,
250 Index::Udp(UDPIndex::SubV4(index)),
251 udp,
252 Some(sub_close_notify_receiver.clone()),
253 );
254 if self.tunnel_dispatcher.try_send(udp_tunnel).is_err() {
255 Err(io::Error::other("tunnel channel error"))?
256 }
257 }
258 sub_close_notify_guard.replace(sub_close_notify_sender);
259 *guard = sub_udp_list;
260 Ok(())
261 }
262
263 #[inline]
264 fn get_udp(&self, udp_index: UDPIndex) -> io::Result<Arc<UdpSocket>> {
265 Ok(match udp_index {
266 UDPIndex::MainV4(index) => self
267 .main_udp_v4
268 .get(index)
269 .ok_or(io::Error::other("index out of bounds"))?
270 .clone(),
271 UDPIndex::MainV6(index) => self
272 .main_udp_v6
273 .get(index)
274 .ok_or(io::Error::other("index out of bounds"))?
275 .clone(),
276 UDPIndex::SubV4(index) => {
277 let guard = self.sub_udp.read();
278 let len = guard.len();
279 if len <= index {
280 return Err(io::Error::other("index out of bounds"));
281 } else {
282 guard[index].clone()
283 }
284 }
285 })
286 }
287
288 #[inline]
289 fn get_udp_from_route(&self, route_key: &RouteKey) -> io::Result<Arc<UdpSocket>> {
290 Ok(match route_key.index() {
291 Index::Udp(index) => self.get_udp(index)?,
292 _ => return Err(io::Error::from(io::ErrorKind::InvalidInput)),
293 })
294 }
295}
296
297impl UdpSocketManager {
298 pub fn model(&self) -> Model {
299 if self.sub_udp.read().is_empty() {
300 Model::Low
301 } else {
302 Model::High
303 }
304 }
305
306 #[inline]
307 pub fn main_udp_v4_count(&self) -> usize {
308 self.main_udp_v4.len()
309 }
310 #[inline]
311 pub fn main_udp_v6_count(&self) -> usize {
312 self.main_udp_v6.len()
313 }
314
315 pub fn switch_model(&self, model: Model) -> io::Result<()> {
316 match model {
317 Model::High => self.switch_high(),
318 Model::Low => {
319 self.switch_low();
320 Ok(())
321 }
322 }
323 }
324 pub fn local_ports(&self) -> io::Result<Vec<u16>> {
326 let mut ports = Vec::with_capacity(self.main_udp_v4_count());
327 for udp in &self.main_udp_v4 {
328 ports.push(udp.local_addr()?.port());
329 }
330 Ok(ports)
331 }
332 pub async fn send_to<T, D: ToRouteKeyForUdp<T>>(&self, buf: &[u8], dest: D) -> io::Result<()> {
334 let route_key = ToRouteKeyForUdp::route_key(self, dest)?;
335 let len = self
336 .get_udp_from_route(&route_key)?
337 .send_to(buf, route_key.addr())
338 .await?;
339 if len == 0 {
340 return Err(std::io::Error::from(io::ErrorKind::WriteZero));
341 }
342 Ok(())
343 }
344
345 pub fn try_send_to<T, D: ToRouteKeyForUdp<T>>(&self, buf: &[u8], dest: D) -> io::Result<()> {
347 let route_key = ToRouteKeyForUdp::route_key(self, dest)?;
348 let len = self
349 .get_udp_from_route(&route_key)?
350 .try_send_to(buf, route_key.addr())?;
351 if len == 0 {
352 return Err(std::io::Error::from(io::ErrorKind::WriteZero));
353 }
354 Ok(())
355 }
356
357 pub async fn batch_send_to<T, D: ToRouteKeyForUdp<T>>(
358 &self,
359 bufs: &[IoSlice<'_>],
360 dest: D,
361 ) -> io::Result<()> {
362 let route_key = ToRouteKeyForUdp::route_key(self, dest)?;
363 let udp = self.get_udp_from_route(&route_key)?;
364 for buf in bufs {
365 let len = udp.send_to(buf, route_key.addr()).await?;
366 if len == 0 {
367 return Err(std::io::Error::from(io::ErrorKind::WriteZero));
368 }
369 }
370
371 Ok(())
372 }
373 fn get_sender(&self, route_key: &RouteKey) -> io::Result<Sender<(BytesMut, SocketAddr)>> {
374 if let Some(sender) = self.sender_map.get(&route_key.index()) {
375 Ok(sender.value().clone())
376 } else {
377 Err(io::Error::new(io::ErrorKind::NotFound, "route not found"))
378 }
379 }
380 pub async fn send_bytes_to<T, D: ToRouteKeyForUdp<T>>(
381 &self,
382 buf: BytesMut,
383 dest: D,
384 ) -> io::Result<()> {
385 let route_key = ToRouteKeyForUdp::route_key(self, dest)?;
386 let sender = self.get_sender(&route_key)?;
387 if let Err(_e) = sender.send((buf, route_key.addr())).await {
388 Err(io::Error::from(io::ErrorKind::WriteZero))
389 } else {
390 Ok(())
391 }
392 }
393 pub fn try_send_bytes_to<T, D: ToRouteKeyForUdp<T>>(
394 &self,
395 buf: BytesMut,
396 dest: D,
397 ) -> io::Result<()> {
398 let route_key = ToRouteKeyForUdp::route_key(self, dest)?;
399 let sender = self.get_sender(&route_key)?;
400 if let Err(e) = sender.try_send((buf, route_key.addr())) {
401 match e {
402 TrySendError::Full(_) => Err(io::Error::from(io::ErrorKind::WouldBlock)),
403 TrySendError::Closed(_) => Err(io::Error::from(io::ErrorKind::WriteZero)),
404 }
405 } else {
406 Ok(())
407 }
408 }
409
410 pub async fn detect_pub_addrs<A: Into<SocketAddr>>(
412 &self,
413 buf: &[u8],
414 addr: A,
415 ) -> io::Result<()> {
416 let addr: SocketAddr = addr.into();
417 for index in 0..self.main_udp_v4_count() {
418 self.send_to(buf, (index, addr)).await?
419 }
420 Ok(())
421 }
422}
423
/// Hands out [`UdpTunnel`]s for the sockets owned by a [`UdpSocketManager`].
pub struct UdpTunnelDispatcher {
    /// Receiving end of the queue of tunnels waiting to be dispatched.
    tunnel_receiver: Receiver<InactiveUdpTunnel>,
    /// Socket manager shared with the tunnels and their sender tasks.
    pub(crate) socket_manager: Arc<UdpSocketManager>,
    /// Optional pool that sent buffers are pushed back into.
    recycle_buf: Option<RecycleBuf>,
}
429
430impl UdpTunnelDispatcher {
431 pub(crate) fn init(&self) -> io::Result<()> {
432 for (index, udp) in self.socket_manager.main_udp_v4.iter().enumerate() {
433 let udp = udp.clone();
434 let tunnel =
435 InactiveUdpTunnel::new(true, Index::Udp(UDPIndex::MainV4(index)), udp, None);
436 if self
437 .socket_manager
438 .tunnel_dispatcher
439 .try_send(tunnel)
440 .is_err()
441 {
442 Err(io::Error::other("tunnel channel error"))?
443 }
444 }
445 for (index, udp) in self.socket_manager.main_udp_v6.iter().enumerate() {
446 let udp = udp.clone();
447 let tunnel =
448 InactiveUdpTunnel::new(true, Index::Udp(UDPIndex::MainV6(index)), udp, None);
449 if self
450 .socket_manager
451 .tunnel_dispatcher
452 .try_send(tunnel)
453 .is_err()
454 {
455 Err(io::Error::other("tunnel channel error"))?
456 }
457 }
458 Ok(())
459 }
460}
461
462impl UdpTunnelDispatcher {
463 pub fn new(config: UdpTunnelConfig) -> io::Result<UdpTunnelDispatcher> {
465 create_tunnel_dispatcher(config)
466 }
467 pub async fn dispatch(&mut self) -> io::Result<UdpTunnel> {
469 let mut udp_tunnel = self
470 .tunnel_receiver
471 .recv()
472 .await
473 .map_err(|_| io::Error::other("Udp tunnel close"))?;
474 let option = self
475 .socket_manager
476 .sender_map
477 .get(&udp_tunnel.index)
478 .map(|v| v.value().clone());
479 let sender = if let Some(v) = option {
480 v
481 } else {
482 let (s, mut r) = tachyonix::channel(128);
483 let index = udp_tunnel.index;
484 let sender = s.clone();
485 self.socket_manager.sender_map.insert(index, s);
486
487 let socket_manager = self.socket_manager.clone();
488 let udp = udp_tunnel.udp.clone();
489 let recycle_buf = self.recycle_buf.clone();
490 tokio::spawn(async move {
491 #[cfg(any(target_os = "linux", target_os = "android"))]
492 let mut vec_buf = Vec::with_capacity(16);
493
494 while let Ok((buf, addr)) = r.recv().await {
495 #[cfg(any(target_os = "linux", target_os = "android"))]
496 {
497 vec_buf.push((buf, addr));
498 while let Ok(tup) = r.try_recv() {
499 vec_buf.push(tup);
500 if vec_buf.len() == MAX_MESSAGES {
501 break;
502 }
503 }
504 let mut bufs = &mut vec_buf[..];
505 let fd = udp.as_raw_fd();
506 loop {
507 if bufs.len() == 1 {
508 let (buf, addr) = unsafe { bufs.get_unchecked(0) };
509 if let Err(e) = udp.send_to(buf, *addr).await {
510 log::warn!("send_to {addr:?},{e:?}")
511 }
512 break;
513 } else {
514 let rs = write_with(&udp, || sendmmsg(fd, bufs)).await;
515 match rs {
516 Ok(size) => {
517 if size == 0 {
518 break;
519 }
520 if size < bufs.len() {
521 bufs = &mut bufs[size..];
522 continue;
523 }
524 break;
525 }
526 Err(e) => {
527 log::warn!("sendmmsg {e:?}");
528 }
529 }
530 }
531 }
532 if let Some(recycle_buf) = recycle_buf.as_ref() {
533 while let Some((buf, _)) = vec_buf.pop() {
534 recycle_buf.push(buf);
535 }
536 } else {
537 vec_buf.clear();
538 }
539 }
540 #[cfg(not(any(target_os = "linux", target_os = "android")))]
541 {
542 let rs = udp.send_to(&buf, addr).await;
543 if let Some(recycle_buf) = recycle_buf.as_ref() {
544 recycle_buf.push(buf);
545 }
546 if let Err(e) = rs {
547 log::debug!("{addr:?},{e:?}")
548 }
549 }
550 }
551 socket_manager.sender_map.remove(&index);
552 });
553 sender
554 };
555 if udp_tunnel.sender.is_none() {
556 udp_tunnel.sender.replace(OwnedUdpTunnelSender { sender });
557 }
558 if udp_tunnel.reusable {
559 UdpTunnel::with_main(udp_tunnel, self.manager().tunnel_dispatcher.clone())
560 } else {
561 UdpTunnel::with_sub(udp_tunnel)
562 }
563 }
564 pub fn manager(&self) -> &Arc<UdpSocketManager> {
565 &self.socket_manager
566 }
567}
568
/// Sends every `(payload, destination)` pair in `bufs` on `fd` with a single
/// non-blocking `libc::sendmmsg` call.
///
/// Returns the value reported by the call (the number of messages sent), or
/// the last OS error when the call returns `-1`.
///
/// # Panics
///
/// Panics when `bufs` holds more than `MAX_MESSAGES` entries.
#[cfg(any(target_os = "linux", target_os = "android"))]
fn sendmmsg(fd: std::os::fd::RawFd, bufs: &mut [(BytesMut, SocketAddr)]) -> io::Result<usize> {
    assert!(bufs.len() <= MAX_MESSAGES);
    // The three arrays below start zeroed and are filled in per message; they
    // live on this stack frame for the whole duration of the call.
    let mut iov: [iovec; MAX_MESSAGES] = unsafe { std::mem::zeroed() };
    let mut msgs: [mmsghdr; MAX_MESSAGES] = unsafe { std::mem::zeroed() };
    let mut addrs: [sockaddr_storage; MAX_MESSAGES] = unsafe { std::mem::zeroed() };
    for (i, (buf, addr)) in bufs.iter_mut().enumerate() {
        addrs[i] = socket_addr_to_sockaddr(addr);
        // One iovec per message, pointing straight at the payload bytes.
        iov[i].iov_base = buf.as_mut_ptr() as *mut libc::c_void;
        iov[i].iov_len = buf.len();
        msgs[i].msg_hdr.msg_iov = &mut iov[i];
        msgs[i].msg_hdr.msg_iovlen = 1;

        // Destination address for this message.
        msgs[i].msg_hdr.msg_name = &mut addrs[i] as *mut _ as *mut libc::c_void;
        msgs[i].msg_hdr.msg_namelen = std::mem::size_of::<sockaddr_storage>() as socklen_t;
    }

    unsafe {
        // Only the first `bufs.len()` headers are handed to the kernel.
        let res = libc::sendmmsg(
            fd,
            msgs.as_mut_ptr(),
            bufs.len() as _,
            libc::MSG_DONTWAIT as _,
        );
        if res == -1 {
            return Err(io::Error::last_os_error());
        }
        Ok(res as usize)
    }
}
599
600#[cfg(any(target_os = "linux", target_os = "android"))]
601fn socket_addr_to_sockaddr(addr: &SocketAddr) -> sockaddr_storage {
602 let mut storage: sockaddr_storage = unsafe { std::mem::zeroed() };
603
604 match addr {
605 SocketAddr::V4(v4_addr) => {
606 let sin = libc::sockaddr_in {
607 sin_family: libc::AF_INET as _,
608 sin_port: v4_addr.port().to_be(),
609 sin_addr: libc::in_addr {
610 s_addr: u32::from_ne_bytes(v4_addr.ip().octets()), },
612 sin_zero: [0; 8],
613 };
614
615 unsafe {
616 let sin_ptr = &sin as *const libc::sockaddr_in as *const u8;
617 let storage_ptr = &mut storage as *mut sockaddr_storage as *mut u8;
618 std::ptr::copy_nonoverlapping(
619 sin_ptr,
620 storage_ptr,
621 std::mem::size_of::<libc::sockaddr>(),
622 );
623 }
624 }
625 SocketAddr::V6(v6_addr) => {
626 let sin6 = libc::sockaddr_in6 {
627 sin6_family: libc::AF_INET6 as _,
628 sin6_port: v6_addr.port().to_be(),
629 sin6_flowinfo: v6_addr.flowinfo(),
630 sin6_addr: libc::in6_addr {
631 s6_addr: v6_addr.ip().octets(),
632 },
633 sin6_scope_id: v6_addr.scope_id(),
634 };
635
636 unsafe {
637 let sin6_ptr = &sin6 as *const libc::sockaddr_in6 as *const u8;
638 let storage_ptr = &mut storage as *mut sockaddr_storage as *mut u8;
639 std::ptr::copy_nonoverlapping(
640 sin6_ptr,
641 storage_ptr,
642 std::mem::size_of::<libc::sockaddr>(),
643 );
644 }
645 }
646 }
647 storage
648}
649
/// An active tunnel over one UDP socket.
///
/// Every optional field is cleared by `done`, after which the tunnel is
/// closed.
pub struct UdpTunnel {
    /// Identity of the underlying socket; stamped on every received route key.
    index: Index,
    /// Local address of the socket, captured when the tunnel was created.
    local_addr: SocketAddr,
    /// The socket itself; `None` once the tunnel is closed.
    udp: Option<Arc<UdpSocket>>,
    /// Close notification for sub tunnels; always `None` for main tunnels.
    close_notify: Option<async_broadcast::Receiver<()>>,
    /// Queue used to hand a main tunnel back to the dispatcher on drop;
    /// always `None` for sub tunnels.
    re_dispatcher: Option<Sender<InactiveUdpTunnel>>,
    /// Owning handle of the socket's outgoing queue.
    sender: Option<OwnedUdpTunnelSender>,
}
/// Owning handle of a socket's outgoing queue; dropping it closes the queue.
struct OwnedUdpTunnelSender {
    /// Sending end of the queue of `(payload, destination)` pairs.
    sender: Sender<(BytesMut, SocketAddr)>,
}
/// Cloneable, non-owning handle of a socket's outgoing queue, obtained from
/// `UdpTunnel::sender`. It has no `Drop` impl in this file, so dropping it
/// does not close the queue.
#[derive(Clone)]
pub struct WeakUdpTunnelSender {
    /// Sending end of the queue of `(payload, destination)` pairs.
    sender: Sender<(BytesMut, SocketAddr)>,
}
/// A tunnel waiting in the dispatcher queue to be turned into a [`UdpTunnel`].
struct InactiveUdpTunnel {
    /// `true` for main tunnels, which are re-queued when dropped; `false` for
    /// sub tunnels.
    reusable: bool,
    /// Identity of the underlying socket.
    index: Index,
    /// The socket the tunnel operates on.
    udp: Arc<UdpSocket>,
    /// Close notification, set for sub tunnels only.
    close_notify: Option<async_broadcast::Receiver<()>>,
    /// Owning queue handle; `None` for a fresh tunnel, `Some` for a
    /// redistributed one.
    sender: Option<OwnedUdpTunnelSender>,
}
672impl InactiveUdpTunnel {
673 fn new(
674 reusable: bool,
675 index: Index,
676 udp: Arc<UdpSocket>,
677 close_notify: Option<async_broadcast::Receiver<()>>,
678 ) -> Self {
679 Self {
680 reusable,
681 index,
682 udp,
683 close_notify,
684 sender: None,
685 }
686 }
687 fn redistribute(index: Index, udp: Arc<UdpSocket>, sender: OwnedUdpTunnelSender) -> Self {
688 Self {
689 reusable: true,
690 index,
691 udp,
692 close_notify: None,
693 sender: Some(sender),
694 }
695 }
696}
697impl OwnedUdpTunnelSender {
698 async fn send_to<A: Into<SocketAddr>>(&self, buf: BytesMut, dest: A) -> io::Result<()> {
699 if buf.is_empty() {
700 return Ok(());
701 }
702 self.sender
703 .send((buf, dest.into()))
704 .await
705 .map_err(|_| io::Error::from(io::ErrorKind::WriteZero))
706 }
707 fn is_closed(&self) -> bool {
708 self.sender.is_closed()
709 }
710}
711impl WeakUdpTunnelSender {
712 pub async fn send_to<A: Into<SocketAddr>>(&self, buf: BytesMut, dest: A) -> io::Result<()> {
713 if buf.is_empty() {
714 return Ok(());
715 }
716 self.sender
717 .send((buf, dest.into()))
718 .await
719 .map_err(|_| io::Error::from(io::ErrorKind::WriteZero))
720 }
721 pub fn try_send_to<A: Into<SocketAddr>>(&self, buf: BytesMut, dest: A) -> io::Result<()> {
722 if buf.is_empty() {
723 return Ok(());
724 }
725 self.sender
726 .try_send((buf, dest.into()))
727 .map_err(|e| match e {
728 TrySendError::Full(_) => io::Error::from(io::ErrorKind::WouldBlock),
729 TrySendError::Closed(_) => io::Error::from(io::ErrorKind::WriteZero),
730 })
731 }
732}
/// Dropping the owning handle closes the outgoing queue, so the task draining
/// it (spawned in `UdpTunnelDispatcher::dispatch`) sees `recv` fail and ends.
impl Drop for OwnedUdpTunnelSender {
    fn drop(&mut self) {
        self.sender.close();
    }
}
/// On drop, a tunnel that is still open and has a re-dispatch queue is handed
/// back to the dispatcher so it can be dispatched again.
///
/// In every other case the owning sender taken here is simply dropped, which
/// closes the outgoing queue.
impl Drop for UdpTunnel {
    fn drop(&mut self) {
        // Already closed via `done`: nothing to give back.
        let Some(sender) = self.sender.take() else {
            return;
        };
        if sender.is_closed() {
            return;
        }
        let Some(udp) = self.udp.take() else {
            return;
        };
        // Only tunnels built by `with_main` carry a re-dispatch queue.
        let Some(re_dispatcher) = self.re_dispatcher.take() else {
            return;
        };
        let rs = re_dispatcher.try_send(InactiveUdpTunnel::redistribute(self.index, udp, sender));
        // A closed dispatcher queue is ignored; only a full one is logged.
        if let Err(TrySendError::Full(_)) = rs {
            log::warn!("Udp Tunnel TrySendError full");
        }
    }
}
758
759impl UdpTunnel {
760 fn with_sub(inactive_udp_tunnel: InactiveUdpTunnel) -> io::Result<Self> {
761 let local_addr = inactive_udp_tunnel.udp.local_addr()?;
762 Ok(Self {
763 index: inactive_udp_tunnel.index,
764 local_addr,
765 udp: Some(inactive_udp_tunnel.udp),
766 close_notify: inactive_udp_tunnel.close_notify,
767 re_dispatcher: None,
768 sender: inactive_udp_tunnel.sender,
769 })
770 }
771 fn with_main(
772 inactive_udp_tunnel: InactiveUdpTunnel,
773 re_sender: Sender<InactiveUdpTunnel>,
774 ) -> io::Result<Self> {
775 let local_addr = inactive_udp_tunnel.udp.local_addr()?;
776 Ok(Self {
777 local_addr,
778 index: inactive_udp_tunnel.index,
779 udp: Some(inactive_udp_tunnel.udp),
780 close_notify: None,
781 re_dispatcher: Some(re_sender),
782 sender: inactive_udp_tunnel.sender,
783 })
784 }
785 pub fn done(&mut self) {
786 _ = self.udp.take();
787 _ = self.close_notify.take();
788 _ = self.re_dispatcher.take();
789 _ = self.re_dispatcher.take();
790 _ = self.sender.take();
791 }
792 pub fn local_addr(&self) -> SocketAddr {
793 self.local_addr
794 }
795 pub fn sender(&self) -> io::Result<WeakUdpTunnelSender> {
796 if let Some(v) = &self.sender {
797 Ok(WeakUdpTunnelSender {
798 sender: v.sender.clone(),
799 })
800 } else {
801 Err(io::Error::other("closed"))
802 }
803 }
804}
805
impl UdpTunnel {
    /// Sends `buf` to `addr` directly on the tunnel's socket, waiting until
    /// the socket accepts it.
    ///
    /// # Errors
    ///
    /// Fails with "closed" once the tunnel has been closed, or with the
    /// socket's error.
    pub async fn send_to<A: Into<SocketAddr>>(&self, buf: &[u8], addr: A) -> io::Result<()> {
        if let Some(udp) = &self.udp {
            udp.send_to(buf, addr.into()).await?;
            Ok(())
        } else {
            Err(io::Error::other("closed"))
        }
    }
    /// Non-waiting counterpart of [`Self::send_to`].
    pub fn try_send_to<A: Into<SocketAddr>>(&self, buf: &[u8], addr: A) -> io::Result<()> {
        if let Some(udp) = &self.udp {
            udp.try_send_to(buf, addr.into())?;
            Ok(())
        } else {
            Err(io::Error::other("closed"))
        }
    }
    /// Queues `buf` on the tunnel's outgoing queue for sending to `addr`.
    ///
    /// # Errors
    ///
    /// Fails with "closed" once the tunnel has been closed, or with the
    /// queue's error.
    pub async fn send_bytes_to<A: Into<SocketAddr>>(
        &self,
        buf: BytesMut,
        addr: A,
    ) -> io::Result<()> {
        if let Some(sender) = &self.sender {
            sender.send_to(buf, addr).await
        } else {
            Err(io::Error::other("closed"))
        }
    }

    /// Receives one datagram into `buf`.
    ///
    /// Returns `None` when the tunnel is closed, either already or because
    /// the close notification fired while waiting (in which case the tunnel
    /// is closed via `done`). Otherwise returns the datagram length and a
    /// route key made of this tunnel's index and the sender's address, or
    /// the receive error. Errors accepted by `should_ignore_error` are
    /// retried instead of being returned.
    pub async fn recv_from(&mut self, buf: &mut [u8]) -> Option<io::Result<(usize, RouteKey)>> {
        let udp = if let Some(udp) = &self.udp {
            udp
        } else {
            return None;
        };
        loop {
            if let Some(close_notify) = &mut self.close_notify {
                // Race the receive against the close notification.
                tokio::select! {
                    _rs=close_notify.recv()=>{
                        self.done();
                        return None
                    }
                    result=udp.recv_from(buf)=>{
                        let (len, addr) = match result {
                            Ok(rs) => rs,
                            Err(e) => {
                                if should_ignore_error(&e) {
                                    continue;
                                }
                                return Some(Err(e))
                            }
                        };
                        return Some(Ok((len, RouteKey::new(self.index, addr))))
                    }
                }
            } else {
                let (len, addr) = match udp.recv_from(buf).await {
                    Ok(rs) => rs,
                    Err(e) => {
                        if should_ignore_error(&e) {
                            continue;
                        }
                        return Some(Err(e));
                    }
                };
                return Some(Ok((len, RouteKey::new(self.index, addr))));
            }
        }
    }
    /// Receives up to `bufs.len()` datagrams: waits for the first one, then
    /// takes whatever further datagrams are immediately available.
    ///
    /// `bufs`, `sizes` and `addrs` must be non-empty and of equal length.
    /// Entry `i` of `sizes` and `addrs` describes the datagram stored in
    /// `bufs[i]`. Returns the number of datagrams received, or `None` when
    /// the tunnel is closed.
    #[cfg(not(any(target_os = "linux", target_os = "android")))]
    pub async fn batch_recv_from<B: AsMut<[u8]>>(
        &mut self,
        bufs: &mut [B],
        sizes: &mut [usize],
        addrs: &mut [RouteKey],
    ) -> Option<io::Result<usize>> {
        if bufs.is_empty() || bufs.len() != sizes.len() || bufs.len() != addrs.len() {
            return Some(Err(io::Error::other("bufs error")));
        }
        let rs = self.recv_from(bufs[0].as_mut()).await?;
        match rs {
            Ok((len, addr)) => {
                let udp = self.udp.as_ref()?;
                sizes[0] = len;
                addrs[0] = addr;
                let mut num = 1;
                while num < bufs.len() {
                    match udp.try_recv_from(bufs[num].as_mut()) {
                        Ok((len, addr)) => {
                            sizes[num] = len;
                            addrs[num] = RouteKey::new(self.index, addr);
                            num += 1;
                        }
                        // Any error (including "would block") ends the batch.
                        Err(_) => break,
                    }
                }
                Some(Ok(num))
            }
            Err(e) => Some(Err(e)),
        }
    }
    /// Receives a batch of datagrams through the `recvmmsg` helper once the
    /// socket is readable.
    ///
    /// `bufs`, `sizes` and `addrs` must be non-empty and of equal length.
    /// Entry `i` of `sizes` and `addrs` describes the datagram stored in
    /// `bufs[i]`. Returns the number of datagrams received, or `None` when
    /// the tunnel is closed, either already or because the close
    /// notification fired while waiting. Errors accepted by
    /// `should_ignore_error` are retried instead of being returned.
    #[cfg(any(target_os = "linux", target_os = "android"))]
    pub async fn batch_recv_from<B: AsMut<[u8]>>(
        &mut self,
        bufs: &mut [B],
        sizes: &mut [usize],
        addrs: &mut [RouteKey],
    ) -> Option<io::Result<usize>> {
        if bufs.is_empty() || bufs.len() != sizes.len() || bufs.len() != addrs.len() {
            return Some(Err(io::Error::other("bufs/sizes/addrs error")));
        }
        let udp = self.udp.as_ref()?;
        let fd = udp.as_raw_fd();
        loop {
            let rs = if let Some(close_notify) = &mut self.close_notify {
                // Race the batched receive against the close notification.
                tokio::select! {
                    _rs=close_notify.recv()=>{
                        self.done();
                        return None
                    }
                    rs=read_with(udp,|| recvmmsg(self.index, fd, bufs, sizes, addrs))=>{
                        rs
                    }
                }
            } else {
                read_with(udp, || recvmmsg(self.index, fd, bufs, sizes, addrs)).await
            };
            return match rs {
                Ok(size) => Some(Ok(size)),
                Err(e) => {
                    if should_ignore_error(&e) {
                        continue;
                    }
                    Some(Err(e))
                }
            };
        }
    }
}
950
/// Receives up to `min(bufs.len(), MAX_MESSAGES)` datagrams from `fd` with a
/// single non-blocking `libc::recvmmsg` call.
///
/// For each received datagram `i`, its length is written to `sizes[i]` and a
/// route key built from `index` and the sender's address to `route_keys[i]`.
/// Returns the number of datagrams received.
///
/// # Errors
///
/// Returns the last OS error when the call returns `-1`, and `UnexpectedEof`
/// when it reports zero messages.
///
/// # Panics
///
/// Panics when `sizes` or `route_keys` is shorter than the number of
/// datagrams received, or when a sender address has an unsupported family.
#[cfg(any(target_os = "linux", target_os = "android"))]
fn recvmmsg<B: AsMut<[u8]>>(
    index: Index,
    fd: std::os::fd::RawFd,
    bufs: &mut [B],
    sizes: &mut [usize],
    route_keys: &mut [RouteKey],
) -> io::Result<usize> {
    // The three arrays below start zeroed and are filled in per message; they
    // live on this stack frame for the whole duration of the call.
    let mut iov: [iovec; MAX_MESSAGES] = unsafe { std::mem::zeroed() };
    let mut msgs: [mmsghdr; MAX_MESSAGES] = unsafe { std::mem::zeroed() };
    let mut addrs: [sockaddr_storage; MAX_MESSAGES] = unsafe { std::mem::zeroed() };
    let max_num = bufs.len().min(MAX_MESSAGES);
    for i in 0..max_num {
        // One iovec per message, covering the whole caller buffer.
        iov[i].iov_base = bufs[i].as_mut().as_mut_ptr() as *mut libc::c_void;
        iov[i].iov_len = bufs[i].as_mut().len();
        msgs[i].msg_hdr.msg_iov = &mut iov[i];
        msgs[i].msg_hdr.msg_iovlen = 1;
        // Storage the sender's address is written into.
        msgs[i].msg_hdr.msg_name = &mut addrs[i] as *const _ as *mut libc::c_void;
        msgs[i].msg_hdr.msg_namelen = std::mem::size_of::<sockaddr_storage>() as socklen_t;
    }
    let res = unsafe {
        // No timeout is passed (null pointer).
        libc::recvmmsg(
            fd,
            msgs.as_mut_ptr(),
            max_num as c_uint,
            libc::MSG_DONTWAIT as _,
            std::ptr::null_mut(),
        )
    };
    if res == -1 {
        return Err(io::Error::last_os_error());
    }
    let nmsgs = res as usize;
    if nmsgs == 0 {
        return Err(io::Error::from(io::ErrorKind::UnexpectedEof));
    }
    for i in 0..nmsgs {
        let addr = sockaddr_to_socket_addr(&addrs[i], msgs[i].msg_hdr.msg_namelen);
        sizes[i] = msgs[i].msg_len as usize;
        route_keys[i] = RouteKey::new(index, addr);
    }
    Ok(nmsgs)
}
994#[cfg(any(target_os = "linux", target_os = "android"))]
995fn sockaddr_to_socket_addr(addr: &sockaddr_storage, _len: socklen_t) -> SocketAddr {
996 match addr.ss_family as libc::c_int {
997 libc::AF_INET => {
998 let addr_in = unsafe { *(addr as *const _ as *const libc::sockaddr_in) };
999 let ip = u32::from_be(addr_in.sin_addr.s_addr);
1000 let port = u16::from_be(addr_in.sin_port);
1001 SocketAddr::V4(std::net::SocketAddrV4::new(
1002 std::net::Ipv4Addr::from(ip),
1003 port,
1004 ))
1005 }
1006 libc::AF_INET6 => {
1007 let addr_in6 = unsafe { *(addr as *const _ as *const libc::sockaddr_in6) };
1008 let ip = std::net::Ipv6Addr::from(addr_in6.sin6_addr.s6_addr);
1009 let port = u16::from_be(addr_in6.sin6_port);
1010 SocketAddr::V6(std::net::SocketAddrV6::new(ip, port, 0, 0))
1011 }
1012 _ => panic!("Unsupported address family"),
1013 }
1014}
1015
/// Tells whether a receive error should be skipped and the receive retried.
///
/// On Windows this is the case for `WSAECONNRESET`; on every other platform
/// no error is ignored.
fn should_ignore_error(e: &io::Error) -> bool {
    #[cfg(windows)]
    {
        use windows_sys::Win32::Networking::WinSock::WSAECONNRESET;
        if matches!(e.raw_os_error(), Some(code) if code == WSAECONNRESET) {
            return true;
        }
    }
    // Keeps `e` used on platforms where the block above is compiled out.
    let _ = e;
    false
}
1027
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use crate::tunnel::udp::{Model, UdpTunnel};

    /// With two main sockets, IPv6 off and the low model, exactly two tunnels
    /// are dispatched before the one-second timeout ends the loop.
    #[tokio::test]
    pub async fn create_udp_tunnel() {
        let config = crate::tunnel::config::UdpTunnelConfig::default()
            .set_main_udp_count(2)
            .set_sub_udp_count(10)
            .set_model(Model::Low)
            .set_use_v6(false);
        let mut udp_tunnel_factory = crate::tunnel::udp::create_tunnel_dispatcher(config).unwrap();
        let mut count = 0;
        let mut join = Vec::new();
        while let Ok(rs) =
            tokio::time::timeout(Duration::from_secs(1), udp_tunnel_factory.dispatch()).await
        {
            join.push(tokio::spawn(tunnel_recv(rs.unwrap())));
            count += 1;
        }
        assert_eq!(count, 2)
    }

    /// With two main and ten sub sockets in the high model, twelve tunnels
    /// are dispatched; switching to the low model then makes exactly the ten
    /// sub tunnels report closure.
    #[tokio::test]
    pub async fn create_sub_udp_tunnel() {
        let config = crate::tunnel::config::UdpTunnelConfig::default()
            .set_main_udp_count(2)
            .set_sub_udp_count(10)
            .set_use_v6(false)
            .set_model(Model::High);
        let mut tunnel_factory = crate::tunnel::udp::create_tunnel_dispatcher(config).unwrap();
        let mut count = 0;
        let mut join = Vec::new();
        while let Ok(rs) =
            tokio::time::timeout(Duration::from_secs(1), tunnel_factory.dispatch()).await
        {
            join.push(tokio::spawn(tunnel_recv(rs.unwrap())));
            count += 1;
        }
        tunnel_factory.manager().switch_low();

        let mut close_tunnel_count = 0;
        for x in join {
            let rs = tokio::time::timeout(Duration::from_secs(1), x).await;
            match rs {
                Ok(rs) => {
                    if rs.unwrap() {
                        close_tunnel_count += 1;
                    }
                }
                // Main tunnels keep waiting for data and time out here.
                Err(_e) => {
                    _ = _e;
                }
            }
        }
        assert_eq!(count, 12);
        assert_eq!(close_tunnel_count, 10);
    }

    /// Waits for one receive on `tunnel` and returns `true` when the tunnel
    /// reported closure (`None`) instead of a datagram or an error.
    async fn tunnel_recv(mut tunnel: UdpTunnel) -> bool {
        let mut buf = [0; 1400];
        tunnel.recv_from(&mut buf).await.is_none()
    }
}