1use std::io;
2use std::io::IoSlice;
3use std::net::SocketAddr;
4use std::sync::Arc;
5
6#[cfg(any(target_os = "linux", target_os = "android"))]
7use tokio::io::Interest;
8
9#[cfg(any(target_os = "linux", target_os = "android"))]
10pub async fn read_with<R>(udp: &UdpSocket, op: impl FnMut() -> io::Result<R>) -> io::Result<R> {
11 udp.async_io(Interest::READABLE, op).await
12}
13#[cfg(any(target_os = "linux", target_os = "android"))]
14pub async fn write_with<R>(udp: &UdpSocket, op: impl FnMut() -> io::Result<R>) -> io::Result<R> {
15 udp.async_io(Interest::WRITABLE, op).await
16}
17
18use bytes::Bytes;
19use dashmap::DashMap;
20use parking_lot::{Mutex, RwLock};
21use tachyonix::{Receiver, Sender, TrySendError};
22use tokio::net::UdpSocket;
23
24use crate::route::{Index, RouteKey};
25use crate::socket::{bind_udp, LocalInterface};
26use crate::tunnel::config::UdpTunnelConfig;
27use crate::tunnel::{DEFAULT_ADDRESS_V4, DEFAULT_ADDRESS_V6};
28
/// Maximum number of datagrams handled by one `sendmmsg`/`recvmmsg` call;
/// also the size of the on-stack `iovec`/`mmsghdr`/`sockaddr_storage` arrays.
#[cfg(any(target_os = "linux", target_os = "android"))]
const MAX_MESSAGES: usize = 16;
31#[cfg(any(target_os = "linux", target_os = "android"))]
32use libc::{c_uint, iovec, mmsghdr, sockaddr_storage, socklen_t};
33#[cfg(any(target_os = "linux", target_os = "android"))]
34use std::os::fd::AsRawFd;
35
/// Socket-usage mode of a `UdpSocketManager`.
///
/// `Low` (the default) uses only the main sockets; `High` additionally binds
/// the sub sockets created by `UdpSocketManager::switch_high`.
#[derive(Debug, PartialEq, Eq, Clone, Copy, Default)]
pub enum Model {
    High,
    #[default]
    Low,
}

impl Model {
    /// Returns `true` when this is [`Model::Low`].
    pub fn is_low(&self) -> bool {
        matches!(self, Model::Low)
    }
    /// Returns `true` when this is [`Model::High`].
    pub fn is_high(&self) -> bool {
        matches!(self, Model::High)
    }
}
51
/// Identifies one socket owned by `UdpSocketManager`: a main IPv4 socket,
/// a main IPv6 socket, or an IPv4 sub socket, each by position in its list.
#[derive(Copy, Clone, Ord, PartialOrd, Eq, PartialEq, Hash, Debug)]
pub enum UDPIndex {
    MainV4(usize),
    MainV6(usize),
    SubV4(usize),
}

impl UDPIndex {
    /// Position of the socket inside its own list, regardless of the variant.
    pub(crate) fn index(&self) -> usize {
        match *self {
            UDPIndex::MainV4(position) | UDPIndex::MainV6(position) | UDPIndex::SubV4(position) => {
                position
            }
        }
    }
}
68
69pub trait ToRouteKeyForUdp<T> {
70 fn route_key(socket_manager: &UdpSocketManager, dest: Self) -> io::Result<RouteKey>;
71}
72
73impl ToRouteKeyForUdp<()> for RouteKey {
74 fn route_key(_: &UdpSocketManager, dest: Self) -> io::Result<RouteKey> {
75 Ok(dest)
76 }
77}
78
79impl ToRouteKeyForUdp<()> for &RouteKey {
80 fn route_key(_: &UdpSocketManager, dest: Self) -> io::Result<RouteKey> {
81 Ok(*dest)
82 }
83}
84
85impl ToRouteKeyForUdp<()> for &mut RouteKey {
86 fn route_key(_: &UdpSocketManager, dest: Self) -> io::Result<RouteKey> {
87 Ok(*dest)
88 }
89}
90
91impl<S: Into<SocketAddr>> ToRouteKeyForUdp<()> for S {
92 fn route_key(socket_manager: &UdpSocketManager, dest: Self) -> io::Result<RouteKey> {
93 let addr = dest.into();
94 socket_manager.generate_route_key_from_addr(0, addr)
95 }
96}
97
98impl<S: Into<SocketAddr>> ToRouteKeyForUdp<usize> for (usize, S) {
99 fn route_key(socket_manager: &UdpSocketManager, dest: Self) -> io::Result<RouteKey> {
100 let (index, addr) = dest;
101 socket_manager.generate_route_key_from_addr(index, addr.into())
102 }
103}
104
105pub(crate) fn create_tunnel_dispatcher(config: UdpTunnelConfig) -> io::Result<UdpTunnelDispatcher> {
107 config.check()?;
108 let mut udp_ports = config.udp_ports;
109 udp_ports.resize(config.main_udp_count, 0);
110 let mut main_udp_v4: Vec<Arc<UdpSocket>> = Vec::with_capacity(config.main_udp_count);
111 let mut main_udp_v6: Vec<Arc<UdpSocket>> = Vec::with_capacity(config.main_udp_count);
112 for port in &udp_ports {
114 loop {
115 let mut addr_v4 = DEFAULT_ADDRESS_V4;
116 addr_v4.set_port(*port);
117 let socket_v4 = bind_udp(addr_v4, config.default_interface.as_ref())?;
118 let udp_v4: std::net::UdpSocket = socket_v4.into();
119 if config.use_v6 {
120 let mut addr_v6 = DEFAULT_ADDRESS_V6;
121 let socket_v6 = if *port == 0 {
122 let port = udp_v4.local_addr()?.port();
123 addr_v6.set_port(port);
124 match bind_udp(addr_v6, config.default_interface.as_ref()) {
125 Ok(socket_v6) => socket_v6,
126 Err(_) => continue,
127 }
128 } else {
129 addr_v6.set_port(*port);
130 bind_udp(addr_v6, config.default_interface.as_ref())?
131 };
132 let udp_v6: std::net::UdpSocket = socket_v6.into();
133 main_udp_v6.push(Arc::new(UdpSocket::from_std(udp_v6)?))
134 }
135 main_udp_v4.push(Arc::new(UdpSocket::from_std(udp_v4)?));
136 break;
137 }
138 }
139 let (tunnel_sender, tunnel_receiver) =
140 tachyonix::channel(config.main_udp_count * 2 + config.sub_udp_count * 2);
141 let socket_manager = Arc::new(UdpSocketManager {
142 main_udp_v4,
143 main_udp_v6,
144 sub_udp: RwLock::new(Vec::with_capacity(config.sub_udp_count)),
145 sub_close_notify: Default::default(),
146 tunnel_dispatcher: tunnel_sender,
147 sub_udp_num: config.sub_udp_count,
148 default_interface: config.default_interface,
149 sender_map: Default::default(),
150 });
151 let tunnel_factory = UdpTunnelDispatcher {
152 tunnel_receiver,
153 socket_manager,
154 };
155 tunnel_factory.init()?;
156 tunnel_factory.socket_manager.switch_model(config.model)?;
157 Ok(tunnel_factory)
158}
159
/// Owns the UDP sockets of the tunnel layer and routes outgoing datagrams to them.
pub struct UdpSocketManager {
    /// Main IPv4 sockets, one per configured main port (bound in `create_tunnel_dispatcher`).
    main_udp_v4: Vec<Arc<UdpSocket>>,
    /// Main IPv6 sockets; empty unless IPv6 is enabled, in which case entry `i` is
    /// bound in the same iteration as `main_udp_v4[i]`.
    main_udp_v6: Vec<Arc<UdpSocket>>,
    /// IPv4 sub sockets; non-empty only in `Model::High` (filled by `switch_high`,
    /// cleared by `switch_low`).
    sub_udp: RwLock<Vec<Arc<UdpSocket>>>,
    /// Broadcast sender whose receivers are handed to sub tunnels; `switch_low` and
    /// `switch_high` close it to stop the current generation of sub tunnels.
    sub_close_notify: Mutex<Option<async_broadcast::Sender<()>>>,
    /// Queue of inactive tunnels consumed by `UdpTunnelDispatcher::dispatch`.
    tunnel_dispatcher: Sender<InactiveUdpTunnel>,
    /// Number of sub sockets bound by `switch_high`.
    sub_udp_num: usize,
    /// Interface passed to `bind_udp` when binding sub sockets.
    default_interface: Option<LocalInterface>,
    /// Per-socket outgoing queues keyed by route index; each is drained by a task
    /// spawned in `UdpTunnelDispatcher::dispatch`.
    sender_map: DashMap<Index, Sender<(Bytes, SocketAddr)>>,
}
170
171impl UdpSocketManager {
172 pub(crate) fn try_sub_batch_send_to(&self, buf: &[u8], addr: SocketAddr) {
173 for (i, udp) in self.sub_udp.read().iter().enumerate() {
174 if let Err(e) = udp.try_send_to(buf, addr) {
175 log::info!("try_sub_send_to_addr_v4: {e:?},{i},{addr}")
176 }
177 }
178 }
179 pub(crate) fn try_main_v4_batch_send_to(&self, buf: &[u8], addr: &[SocketAddr]) {
180 let len = self.main_udp_v4_count();
181 self.try_main_batch_send_to_impl(buf, addr, len);
182 }
183 pub(crate) fn try_main_v6_batch_send_to(&self, buf: &[u8], addr: &[SocketAddr]) {
184 let len = self.main_udp_v6_count();
185 self.try_main_batch_send_to_impl(buf, addr, len);
186 }
187
188 pub(crate) fn try_main_batch_send_to_impl(&self, buf: &[u8], addr: &[SocketAddr], len: usize) {
189 for (i, addr) in addr.iter().enumerate() {
190 if let Err(e) = self.try_send_to(buf, (i % len, *addr)) {
191 log::info!("try_main_send_to_addr: {e:?},{},{addr}", i % len);
192 }
193 }
194 }
195 pub(crate) fn generate_route_key_from_addr(
196 &self,
197 index: usize,
198 addr: SocketAddr,
199 ) -> io::Result<RouteKey> {
200 let route_key = if addr.is_ipv4() {
201 let len = self.main_udp_v4.len();
202 if index >= len {
203 return Err(io::Error::other("index out of bounds"));
204 }
205 RouteKey::new(Index::Udp(UDPIndex::MainV4(index)), addr)
206 } else {
207 let len = self.main_udp_v6.len();
208 if len == 0 {
209 return Err(io::Error::other("Not support IPV6"));
210 }
211 if index >= len {
212 return Err(io::Error::other("index out of bounds"));
213 }
214 RouteKey::new(Index::Udp(UDPIndex::MainV6(index)), addr)
215 };
216 Ok(route_key)
217 }
218 pub(crate) fn switch_low(&self) {
219 let mut guard = self.sub_udp.write();
220 if guard.is_empty() {
221 return;
222 }
223 guard.clear();
224 if let Some(sub_close_notify) = self.sub_close_notify.lock().take() {
225 let _ = sub_close_notify.close();
226 }
227 }
228 pub(crate) fn switch_high(&self) -> io::Result<()> {
229 let mut guard = self.sub_udp.write();
230 if !guard.is_empty() {
231 return Ok(());
232 }
233 let mut sub_close_notify_guard = self.sub_close_notify.lock();
234 if let Some(sender) = sub_close_notify_guard.take() {
235 let _ = sender.close();
236 }
237 let (sub_close_notify_sender, sub_close_notify_receiver) = async_broadcast::broadcast(2);
238 let mut sub_udp_list = Vec::with_capacity(self.sub_udp_num);
239 for _ in 0..self.sub_udp_num {
240 let udp = bind_udp(DEFAULT_ADDRESS_V4, self.default_interface.as_ref())?;
241 let udp: std::net::UdpSocket = udp.into();
242 sub_udp_list.push(Arc::new(UdpSocket::from_std(udp)?));
243 }
244 for (index, udp) in sub_udp_list.iter().enumerate() {
245 let udp = udp.clone();
246 let udp_tunnel = InactiveUdpTunnel::new(
247 false,
248 Index::Udp(UDPIndex::SubV4(index)),
249 udp,
250 Some(sub_close_notify_receiver.clone()),
251 );
252 if self.tunnel_dispatcher.try_send(udp_tunnel).is_err() {
253 Err(io::Error::other("tunnel channel error"))?
254 }
255 }
256 sub_close_notify_guard.replace(sub_close_notify_sender);
257 *guard = sub_udp_list;
258 Ok(())
259 }
260
261 #[inline]
262 fn get_udp(&self, udp_index: UDPIndex) -> io::Result<Arc<UdpSocket>> {
263 Ok(match udp_index {
264 UDPIndex::MainV4(index) => self
265 .main_udp_v4
266 .get(index)
267 .ok_or(io::Error::other("index out of bounds"))?
268 .clone(),
269 UDPIndex::MainV6(index) => self
270 .main_udp_v6
271 .get(index)
272 .ok_or(io::Error::other("index out of bounds"))?
273 .clone(),
274 UDPIndex::SubV4(index) => {
275 let guard = self.sub_udp.read();
276 let len = guard.len();
277 if len <= index {
278 return Err(io::Error::other("index out of bounds"));
279 } else {
280 guard[index].clone()
281 }
282 }
283 })
284 }
285
286 #[inline]
287 fn get_udp_from_route(&self, route_key: &RouteKey) -> io::Result<Arc<UdpSocket>> {
288 Ok(match route_key.index() {
289 Index::Udp(index) => self.get_udp(index)?,
290 _ => return Err(io::Error::from(io::ErrorKind::InvalidInput)),
291 })
292 }
293}
294
295impl UdpSocketManager {
296 pub fn model(&self) -> Model {
297 if self.sub_udp.read().is_empty() {
298 Model::Low
299 } else {
300 Model::High
301 }
302 }
303
304 #[inline]
305 pub fn main_udp_v4_count(&self) -> usize {
306 self.main_udp_v4.len()
307 }
308 #[inline]
309 pub fn main_udp_v6_count(&self) -> usize {
310 self.main_udp_v6.len()
311 }
312
313 pub fn switch_model(&self, model: Model) -> io::Result<()> {
314 match model {
315 Model::High => self.switch_high(),
316 Model::Low => {
317 self.switch_low();
318 Ok(())
319 }
320 }
321 }
322 pub fn local_ports(&self) -> io::Result<Vec<u16>> {
324 let mut ports = Vec::with_capacity(self.main_udp_v4_count());
325 for udp in &self.main_udp_v4 {
326 ports.push(udp.local_addr()?.port());
327 }
328 Ok(ports)
329 }
330 pub async fn send_to<T, D: ToRouteKeyForUdp<T>>(&self, buf: &[u8], dest: D) -> io::Result<()> {
332 let route_key = ToRouteKeyForUdp::route_key(self, dest)?;
333 let len = self
334 .get_udp_from_route(&route_key)?
335 .send_to(buf, route_key.addr())
336 .await?;
337 if len == 0 {
338 return Err(std::io::Error::from(io::ErrorKind::WriteZero));
339 }
340 Ok(())
341 }
342
343 pub fn try_send_to<T, D: ToRouteKeyForUdp<T>>(&self, buf: &[u8], dest: D) -> io::Result<()> {
345 let route_key = ToRouteKeyForUdp::route_key(self, dest)?;
346 let len = self
347 .get_udp_from_route(&route_key)?
348 .try_send_to(buf, route_key.addr())?;
349 if len == 0 {
350 return Err(std::io::Error::from(io::ErrorKind::WriteZero));
351 }
352 Ok(())
353 }
354
355 pub async fn batch_send_to<T, D: ToRouteKeyForUdp<T>>(
356 &self,
357 bufs: &[IoSlice<'_>],
358 dest: D,
359 ) -> io::Result<()> {
360 let route_key = ToRouteKeyForUdp::route_key(self, dest)?;
361 let udp = self.get_udp_from_route(&route_key)?;
362 for buf in bufs {
363 let len = udp.send_to(buf, route_key.addr()).await?;
364 if len == 0 {
365 return Err(std::io::Error::from(io::ErrorKind::WriteZero));
366 }
367 }
368
369 Ok(())
370 }
371 fn get_sender(&self, route_key: &RouteKey) -> io::Result<Sender<(Bytes, SocketAddr)>> {
372 if let Some(sender) = self.sender_map.get(&route_key.index()) {
373 Ok(sender.value().clone())
374 } else {
375 Err(io::Error::new(io::ErrorKind::NotFound, "route not found"))
376 }
377 }
378 pub async fn send_bytes_to<T, D: ToRouteKeyForUdp<T>>(
379 &self,
380 buf: Bytes,
381 dest: D,
382 ) -> io::Result<()> {
383 let route_key = ToRouteKeyForUdp::route_key(self, dest)?;
384 let sender = self.get_sender(&route_key)?;
385 if let Err(_e) = sender.send((buf, route_key.addr())).await {
386 Err(io::Error::from(io::ErrorKind::WriteZero))
387 } else {
388 Ok(())
389 }
390 }
391 pub fn try_send_bytes_to<T, D: ToRouteKeyForUdp<T>>(
392 &self,
393 buf: Bytes,
394 dest: D,
395 ) -> io::Result<()> {
396 let route_key = ToRouteKeyForUdp::route_key(self, dest)?;
397 let sender = self.get_sender(&route_key)?;
398 if let Err(e) = sender.try_send((buf, route_key.addr())) {
399 match e {
400 TrySendError::Full(_) => Err(io::Error::from(io::ErrorKind::WouldBlock)),
401 TrySendError::Closed(_) => Err(io::Error::from(io::ErrorKind::WriteZero)),
402 }
403 } else {
404 Ok(())
405 }
406 }
407
408 pub async fn detect_pub_addrs<A: Into<SocketAddr>>(
410 &self,
411 buf: &[u8],
412 addr: A,
413 ) -> io::Result<()> {
414 let addr: SocketAddr = addr.into();
415 for index in 0..self.main_udp_v4_count() {
416 self.send_to(buf, (index, addr)).await?
417 }
418 Ok(())
419 }
420}
421
/// Turns queued sockets into `UdpTunnel`s (see `dispatch`).
pub struct UdpTunnelDispatcher {
    /// Inactive tunnels queued by `init`, `UdpSocketManager::switch_high`, and by
    /// dropped main tunnels that are re-dispatched.
    tunnel_receiver: Receiver<InactiveUdpTunnel>,
    /// Shared socket manager; also owns the sending half of `tunnel_receiver`.
    pub(crate) socket_manager: Arc<UdpSocketManager>,
}
426
427impl UdpTunnelDispatcher {
428 pub(crate) fn init(&self) -> io::Result<()> {
429 for (index, udp) in self.socket_manager.main_udp_v4.iter().enumerate() {
430 let udp = udp.clone();
431 let tunnel =
432 InactiveUdpTunnel::new(true, Index::Udp(UDPIndex::MainV4(index)), udp, None);
433 if self
434 .socket_manager
435 .tunnel_dispatcher
436 .try_send(tunnel)
437 .is_err()
438 {
439 Err(io::Error::other("tunnel channel error"))?
440 }
441 }
442 for (index, udp) in self.socket_manager.main_udp_v6.iter().enumerate() {
443 let udp = udp.clone();
444 let tunnel =
445 InactiveUdpTunnel::new(true, Index::Udp(UDPIndex::MainV6(index)), udp, None);
446 if self
447 .socket_manager
448 .tunnel_dispatcher
449 .try_send(tunnel)
450 .is_err()
451 {
452 Err(io::Error::other("tunnel channel error"))?
453 }
454 }
455 Ok(())
456 }
457}
458
459impl UdpTunnelDispatcher {
460 pub fn new(config: UdpTunnelConfig) -> io::Result<UdpTunnelDispatcher> {
462 create_tunnel_dispatcher(config)
463 }
464 pub async fn dispatch(&mut self) -> io::Result<UdpTunnel> {
466 let mut udp_tunnel = self
467 .tunnel_receiver
468 .recv()
469 .await
470 .map_err(|_| io::Error::other("Udp tunnel close"))?;
471 let option = self
472 .socket_manager
473 .sender_map
474 .get(&udp_tunnel.index)
475 .map(|v| v.value().clone());
476 let sender = if let Some(v) = option {
477 v
478 } else {
479 let (s, mut r) = tachyonix::channel(128);
480 let index = udp_tunnel.index;
481 let sender = s.clone();
482 self.socket_manager.sender_map.insert(index, s);
483
484 let socket_manager = self.socket_manager.clone();
485 let udp = udp_tunnel.udp.clone();
486 tokio::spawn(async move {
487 #[cfg(all(feature = "sendmmsg", any(target_os = "linux", target_os = "android")))]
488 let mut vec_buf = Vec::with_capacity(16);
489
490 while let Ok((buf, addr)) = r.recv().await {
491 #[cfg(all(
492 feature = "sendmmsg",
493 any(target_os = "linux", target_os = "android")
494 ))]
495 {
496 vec_buf.push((buf, addr));
497 while let Ok(tup) = r.try_recv() {
498 vec_buf.push(tup);
499 if vec_buf.len() == MAX_MESSAGES {
500 break;
501 }
502 }
503 let mut bufs = &mut vec_buf[..];
504 let fd = udp.as_raw_fd();
505 loop {
506 if bufs.len() == 1 {
507 let (buf, addr) = unsafe { bufs.get_unchecked(0) };
508 if let Err(e) = udp.send_to(buf, *addr).await {
509 log::warn!("send_to {addr:?},{e:?}")
510 }
511 break;
512 } else {
513 let rs = write_with(&udp, || sendmmsg(fd, bufs)).await;
514 match rs {
515 Ok(size) => {
516 if size == 0 {
517 break;
518 }
519 if size < bufs.len() {
520 bufs = &mut bufs[size..];
521 continue;
522 }
523 break;
524 }
525 Err(e) => {
526 log::warn!("sendmmsg {e:?}");
527 }
528 }
529 }
530 }
531 vec_buf.clear();
532 }
533 #[cfg(any(
534 not(any(target_os = "linux", target_os = "android")),
535 not(feature = "sendmmsg")
536 ))]
537 {
538 let rs = udp.send_to(&buf, addr).await;
539 if let Err(e) = rs {
540 log::debug!("{addr:?},{e:?}")
541 }
542 }
543 }
544 socket_manager.sender_map.remove(&index);
545 });
546 sender
547 };
548 if udp_tunnel.sender.is_none() {
549 udp_tunnel.sender.replace(OwnedUdpTunnelSender { sender });
550 }
551 if udp_tunnel.reusable {
552 UdpTunnel::with_main(udp_tunnel, self.manager().tunnel_dispatcher.clone())
553 } else {
554 UdpTunnel::with_sub(udp_tunnel)
555 }
556 }
557 pub fn manager(&self) -> &Arc<UdpSocketManager> {
558 &self.socket_manager
559 }
560}
561
/// Sends every `(payload, destination)` pair in `bufs` with one non-blocking
/// `sendmmsg(2)` call.
///
/// Returns how many datagrams the kernel accepted; these are always a prefix of
/// `bufs`. Returns the OS error when the call fails outright.
///
/// # Panics
/// Panics if `bufs.len() > MAX_MESSAGES`.
#[cfg(all(feature = "sendmmsg", any(target_os = "linux", target_os = "android")))]
fn sendmmsg(fd: std::os::fd::RawFd, bufs: &mut [(Bytes, SocketAddr)]) -> io::Result<usize> {
    assert!(bufs.len() <= MAX_MESSAGES);
    // SAFETY: `iovec`, `mmsghdr` and `sockaddr_storage` are plain C structs for which
    // the all-zero bit pattern is valid (null pointers, zero lengths, AF_UNSPEC).
    let mut iov: [iovec; MAX_MESSAGES] = unsafe { std::mem::zeroed() };
    let mut msgs: [mmsghdr; MAX_MESSAGES] = unsafe { std::mem::zeroed() };
    let mut addrs: [sockaddr_storage; MAX_MESSAGES] = unsafe { std::mem::zeroed() };
    for (i, (buf, addr)) in bufs.iter().enumerate() {
        addrs[i] = socket_addr_to_sockaddr(addr);
        // `Bytes` is immutable (it has no `DerefMut`, so `as_mut_ptr` is unavailable).
        // `sendmmsg` only reads from the iovec, so casting away const is sound.
        iov[i].iov_base = buf.as_ptr() as *mut libc::c_void;
        iov[i].iov_len = buf.len();
        msgs[i].msg_hdr.msg_iov = &mut iov[i];
        msgs[i].msg_hdr.msg_iovlen = 1;
        msgs[i].msg_hdr.msg_name = &mut addrs[i] as *mut _ as *mut libc::c_void;
        msgs[i].msg_hdr.msg_namelen = std::mem::size_of::<sockaddr_storage>() as socklen_t;
    }

    // SAFETY: the first `bufs.len()` entries of `msgs` point into `iov`/`addrs` (live
    // on this stack frame) and into payloads borrowed from `bufs`, all of which
    // outlive the call.
    let res = unsafe {
        libc::sendmmsg(
            fd,
            msgs.as_mut_ptr(),
            bufs.len() as _,
            libc::MSG_DONTWAIT as _,
        )
    };
    if res == -1 {
        return Err(io::Error::last_os_error());
    }
    Ok(res as usize)
}
592
/// Encodes `addr` as a C `sockaddr_storage` holding a `sockaddr_in` or `sockaddr_in6`.
///
/// Ports are written in network byte order. The IPv4 address octets are stored
/// as-is, which is already network order. The IPv6 flow info and scope id are copied
/// unchanged, matching what the standard library does.
#[cfg(all(feature = "sendmmsg", any(target_os = "linux", target_os = "android")))]
fn socket_addr_to_sockaddr(addr: &SocketAddr) -> sockaddr_storage {
    // SAFETY: an all-zero `sockaddr_storage` is valid (family AF_UNSPEC).
    let mut storage: sockaddr_storage = unsafe { std::mem::zeroed() };

    match addr {
        SocketAddr::V4(v4_addr) => {
            let sin = libc::sockaddr_in {
                sin_family: libc::AF_INET as _,
                sin_port: v4_addr.port().to_be(),
                sin_addr: libc::in_addr {
                    s_addr: u32::from_ne_bytes(v4_addr.ip().octets()),
                },
                sin_zero: [0; 8],
            };
            // SAFETY: `sockaddr_storage` is large enough and sufficiently aligned to
            // hold any `sockaddr_*` variant; the whole `sockaddr_in` is written.
            unsafe {
                std::ptr::write(
                    &mut storage as *mut sockaddr_storage as *mut libc::sockaddr_in,
                    sin,
                );
            }
        }
        SocketAddr::V6(v6_addr) => {
            let sin6 = libc::sockaddr_in6 {
                sin6_family: libc::AF_INET6 as _,
                sin6_port: v6_addr.port().to_be(),
                sin6_flowinfo: v6_addr.flowinfo(),
                sin6_addr: libc::in6_addr {
                    s6_addr: v6_addr.ip().octets(),
                },
                sin6_scope_id: v6_addr.scope_id(),
            };
            // SAFETY: as above. Writing the full `sockaddr_in6` (28 bytes) matters:
            // copying only `size_of::<sockaddr>()` (16 bytes) truncated the address
            // and dropped the scope id.
            unsafe {
                std::ptr::write(
                    &mut storage as *mut sockaddr_storage as *mut libc::sockaddr_in6,
                    sin6,
                );
            }
        }
    }
    storage
}
642
/// A socket handed out by `UdpTunnelDispatcher::dispatch`.
///
/// Main tunnels carry `re_dispatcher` and are re-queued to the dispatcher when
/// dropped (see `Drop for UdpTunnel`). Sub tunnels carry `close_notify` and stop
/// receiving once it fires or is closed.
pub struct UdpTunnel {
    /// Route index of the socket; attached to every received `RouteKey`.
    index: Index,
    /// Local address read when the tunnel was built.
    local_addr: SocketAddr,
    /// The socket; `None` after `done`.
    udp: Option<Arc<UdpSocket>>,
    /// Shutdown notification for sub tunnels; `None` for main tunnels.
    close_notify: Option<async_broadcast::Receiver<()>>,
    /// Queue back to the dispatcher, used to re-queue main tunnels on drop.
    re_dispatcher: Option<Sender<InactiveUdpTunnel>>,
    /// Owning handle to this socket's send queue.
    sender: Option<OwnedUdpTunnelSender>,
}
/// Owning handle to a socket's send queue; its `Drop` closes the queue.
struct OwnedUdpTunnelSender {
    sender: Sender<(Bytes, SocketAddr)>,
}
/// Cloneable handle to a socket's send queue. It has no `Drop` impl, so dropping it
/// does not explicitly close the queue.
#[derive(Clone)]
pub struct WeakUdpTunnelSender {
    sender: Sender<(Bytes, SocketAddr)>,
}
/// A socket waiting in the dispatcher queue to become a `UdpTunnel`.
struct InactiveUdpTunnel {
    /// `true` for main sockets (re-dispatched on drop), `false` for sub sockets.
    reusable: bool,
    index: Index,
    udp: Arc<UdpSocket>,
    close_notify: Option<async_broadcast::Receiver<()>>,
    /// Send queue carried over by a re-dispatched main tunnel; `None` for fresh sockets.
    sender: Option<OwnedUdpTunnelSender>,
}
665impl InactiveUdpTunnel {
666 fn new(
667 reusable: bool,
668 index: Index,
669 udp: Arc<UdpSocket>,
670 close_notify: Option<async_broadcast::Receiver<()>>,
671 ) -> Self {
672 Self {
673 reusable,
674 index,
675 udp,
676 close_notify,
677 sender: None,
678 }
679 }
680 fn redistribute(index: Index, udp: Arc<UdpSocket>, sender: OwnedUdpTunnelSender) -> Self {
681 Self {
682 reusable: true,
683 index,
684 udp,
685 close_notify: None,
686 sender: Some(sender),
687 }
688 }
689}
690impl OwnedUdpTunnelSender {
691 async fn send_to<A: Into<SocketAddr>>(&self, buf: Bytes, dest: A) -> io::Result<()> {
692 if buf.is_empty() {
693 return Ok(());
694 }
695 self.sender
696 .send((buf, dest.into()))
697 .await
698 .map_err(|_| io::Error::from(io::ErrorKind::WriteZero))
699 }
700 fn is_closed(&self) -> bool {
701 self.sender.is_closed()
702 }
703}
704impl WeakUdpTunnelSender {
705 pub async fn send_to<A: Into<SocketAddr>>(&self, buf: Bytes, dest: A) -> io::Result<()> {
706 if buf.is_empty() {
707 return Ok(());
708 }
709 self.sender
710 .send((buf, dest.into()))
711 .await
712 .map_err(|_| io::Error::from(io::ErrorKind::WriteZero))
713 }
714 pub fn try_send_to<A: Into<SocketAddr>>(&self, buf: Bytes, dest: A) -> io::Result<()> {
715 if buf.is_empty() {
716 return Ok(());
717 }
718 self.sender
719 .try_send((buf, dest.into()))
720 .map_err(|e| match e {
721 TrySendError::Full(_) => io::Error::from(io::ErrorKind::WouldBlock),
722 TrySendError::Closed(_) => io::Error::from(io::ErrorKind::WriteZero),
723 })
724 }
725}
726impl Drop for OwnedUdpTunnelSender {
727 fn drop(&mut self) {
728 self.sender.close();
729 }
730}
731impl Drop for UdpTunnel {
732 fn drop(&mut self) {
733 let Some(sender) = self.sender.take() else {
734 return;
735 };
736 if sender.is_closed() {
737 return;
738 }
739 let Some(udp) = self.udp.take() else {
740 return;
741 };
742 let Some(re_dispatcher) = self.re_dispatcher.take() else {
743 return;
744 };
745 let rs = re_dispatcher.try_send(InactiveUdpTunnel::redistribute(self.index, udp, sender));
746 if let Err(TrySendError::Full(_)) = rs {
747 log::warn!("Udp Tunnel TrySendError full");
748 }
749 }
750}
751
752impl UdpTunnel {
753 fn with_sub(inactive_udp_tunnel: InactiveUdpTunnel) -> io::Result<Self> {
754 let local_addr = inactive_udp_tunnel.udp.local_addr()?;
755 Ok(Self {
756 index: inactive_udp_tunnel.index,
757 local_addr,
758 udp: Some(inactive_udp_tunnel.udp),
759 close_notify: inactive_udp_tunnel.close_notify,
760 re_dispatcher: None,
761 sender: inactive_udp_tunnel.sender,
762 })
763 }
764 fn with_main(
765 inactive_udp_tunnel: InactiveUdpTunnel,
766 re_sender: Sender<InactiveUdpTunnel>,
767 ) -> io::Result<Self> {
768 let local_addr = inactive_udp_tunnel.udp.local_addr()?;
769 Ok(Self {
770 local_addr,
771 index: inactive_udp_tunnel.index,
772 udp: Some(inactive_udp_tunnel.udp),
773 close_notify: None,
774 re_dispatcher: Some(re_sender),
775 sender: inactive_udp_tunnel.sender,
776 })
777 }
778 pub fn done(&mut self) {
779 _ = self.udp.take();
780 _ = self.close_notify.take();
781 _ = self.re_dispatcher.take();
782 _ = self.re_dispatcher.take();
783 _ = self.sender.take();
784 }
785 pub fn local_addr(&self) -> SocketAddr {
786 self.local_addr
787 }
788 pub fn sender(&self) -> io::Result<WeakUdpTunnelSender> {
789 if let Some(v) = &self.sender {
790 Ok(WeakUdpTunnelSender {
791 sender: v.sender.clone(),
792 })
793 } else {
794 Err(io::Error::other("closed"))
795 }
796 }
797}
798
impl UdpTunnel {
    /// Sends `buf` to `addr` directly on this tunnel's socket, bypassing the send queue.
    /// The byte count is discarded. Returns `"closed"` once the socket was released.
    pub async fn send_to<A: Into<SocketAddr>>(&self, buf: &[u8], addr: A) -> io::Result<()> {
        if let Some(udp) = &self.udp {
            udp.send_to(buf, addr.into()).await?;
            Ok(())
        } else {
            Err(io::Error::other("closed"))
        }
    }
    /// Non-blocking variant of `send_to`.
    pub fn try_send_to<A: Into<SocketAddr>>(&self, buf: &[u8], addr: A) -> io::Result<()> {
        if let Some(udp) = &self.udp {
            udp.try_send_to(buf, addr.into())?;
            Ok(())
        } else {
            Err(io::Error::other("closed"))
        }
    }
    /// Queues `buf` for `addr` on this tunnel's send queue (empty payloads are
    /// dropped by `OwnedUdpTunnelSender::send_to`). Returns `"closed"` after `done`.
    pub async fn send_bytes_to<A: Into<SocketAddr>>(&self, buf: Bytes, addr: A) -> io::Result<()> {
        if let Some(sender) = &self.sender {
            sender.send_to(buf, addr).await
        } else {
            Err(io::Error::other("closed"))
        }
    }

    /// Receives one datagram into `buf`, returning its length and a `RouteKey` for
    /// the sender.
    ///
    /// Returns `None` when the tunnel is finished: the socket was already released,
    /// or `close_notify` (set only for sub tunnels) yielded, in which case `done` is
    /// called first. Errors accepted by `should_ignore_error` are skipped and the
    /// receive is retried.
    pub async fn recv_from(&mut self, buf: &mut [u8]) -> Option<io::Result<(usize, RouteKey)>> {
        let udp = if let Some(udp) = &self.udp {
            udp
        } else {
            return None;
        };
        loop {
            if let Some(close_notify) = &mut self.close_notify {
                tokio::select! {
                    _rs=close_notify.recv()=>{
                        self.done();
                        return None
                    }
                    result=udp.recv_from(buf)=>{
                        let (len, addr) = match result {
                            Ok(rs) => rs,
                            Err(e) => {
                                if should_ignore_error(&e) {
                                    continue;
                                }
                                return Some(Err(e))
                            }
                        };
                        return Some(Ok((len, RouteKey::new(self.index, addr))))
                    }
                }
            } else {
                let (len, addr) = match udp.recv_from(buf).await {
                    Ok(rs) => rs,
                    Err(e) => {
                        if should_ignore_error(&e) {
                            continue;
                        }
                        return Some(Err(e));
                    }
                };
                return Some(Ok((len, RouteKey::new(self.index, addr))));
            }
        }
    }
    /// Receives a batch of datagrams (portable fallback).
    ///
    /// Waits for one datagram via `recv_from`, then drains more with non-blocking
    /// `try_recv_from` until `bufs` is full or a call fails (any such error just ends
    /// the batch). Fills `sizes[..n]`/`addrs[..n]` and returns `n`. `bufs`, `sizes`
    /// and `addrs` must be non-empty and of equal length, else `"bufs error"`.
    /// Returns `None` under the same conditions as `recv_from`.
    #[cfg(not(any(target_os = "linux", target_os = "android")))]
    pub async fn batch_recv_from<B: AsMut<[u8]>>(
        &mut self,
        bufs: &mut [B],
        sizes: &mut [usize],
        addrs: &mut [RouteKey],
    ) -> Option<io::Result<usize>> {
        if bufs.is_empty() || bufs.len() != sizes.len() || bufs.len() != addrs.len() {
            return Some(Err(io::Error::other("bufs error")));
        }
        let rs = self.recv_from(bufs[0].as_mut()).await?;
        match rs {
            Ok((len, addr)) => {
                let udp = self.udp.as_ref()?;
                sizes[0] = len;
                addrs[0] = addr;
                let mut num = 1;
                while num < bufs.len() {
                    match udp.try_recv_from(bufs[num].as_mut()) {
                        Ok((len, addr)) => {
                            sizes[num] = len;
                            addrs[num] = RouteKey::new(self.index, addr);
                            num += 1;
                        }
                        Err(_) => break,
                    }
                }
                Some(Ok(num))
            }
            Err(e) => Some(Err(e)),
        }
    }
    /// Receives a batch of datagrams with `recvmmsg` (at most `MAX_MESSAGES` per call).
    ///
    /// Waits for read readiness through `read_with`, which runs the non-blocking
    /// `recvmmsg` helper. Fills `sizes[..n]`/`addrs[..n]` and returns `n`. `bufs`,
    /// `sizes` and `addrs` must be non-empty and of equal length, else
    /// `"bufs/sizes/addrs error"`. Returns `None` if the socket was released or
    /// `close_notify` yielded (after calling `done`).
    #[cfg(any(target_os = "linux", target_os = "android"))]
    pub async fn batch_recv_from<B: AsMut<[u8]>>(
        &mut self,
        bufs: &mut [B],
        sizes: &mut [usize],
        addrs: &mut [RouteKey],
    ) -> Option<io::Result<usize>> {
        if bufs.is_empty() || bufs.len() != sizes.len() || bufs.len() != addrs.len() {
            return Some(Err(io::Error::other("bufs/sizes/addrs error")));
        }
        let udp = self.udp.as_ref()?;
        let fd = udp.as_raw_fd();
        loop {
            let rs = if let Some(close_notify) = &mut self.close_notify {
                tokio::select! {
                    _rs=close_notify.recv()=>{
                        self.done();
                        return None
                    }
                    rs=read_with(udp,|| recvmmsg(self.index, fd, bufs, sizes, addrs))=>{
                        rs
                    }
                }
            } else {
                read_with(udp, || recvmmsg(self.index, fd, bufs, sizes, addrs)).await
            };
            return match rs {
                Ok(size) => Some(Ok(size)),
                Err(e) => {
                    // Retry only for errors `should_ignore_error` accepts.
                    if should_ignore_error(&e) {
                        continue;
                    }
                    Some(Err(e))
                }
            };
        }
    }
}
939
/// Receives up to `min(bufs.len(), MAX_MESSAGES)` datagrams from `fd` with one
/// non-blocking (`MSG_DONTWAIT`) `recvmmsg(2)` call.
///
/// On success returns the message count `n`. `sizes[..n]` gets each datagram's
/// length and `route_keys[..n]` gets `RouteKey::new(index, sender_addr)`.
///
/// # Errors
/// The OS error when the call fails (presumably including `WouldBlock`, which
/// `read_with` relies on to wait for readiness — confirm against tokio `async_io`).
/// Also `UnexpectedEof` if zero messages are reported.
///
/// `sizes` and `route_keys` must be at least as long as the number of messages
/// received; `batch_recv_from` checks that they match `bufs.len()`.
/// NOTE(review): `MSG_TRUNC` in `msg_flags` is not checked, so a datagram larger
/// than its buffer is presumably truncated without notice.
#[cfg(any(target_os = "linux", target_os = "android"))]
fn recvmmsg<B: AsMut<[u8]>>(
    index: Index,
    fd: std::os::fd::RawFd,
    bufs: &mut [B],
    sizes: &mut [usize],
    route_keys: &mut [RouteKey],
) -> io::Result<usize> {
    // SAFETY: all-zero is a valid bit pattern for these plain C structs; the used
    // entries are filled in below.
    let mut iov: [iovec; MAX_MESSAGES] = unsafe { std::mem::zeroed() };
    let mut msgs: [mmsghdr; MAX_MESSAGES] = unsafe { std::mem::zeroed() };
    let mut addrs: [sockaddr_storage; MAX_MESSAGES] = unsafe { std::mem::zeroed() };
    let max_num = bufs.len().min(MAX_MESSAGES);
    for i in 0..max_num {
        // One iovec per datagram, pointing at the caller's buffer.
        iov[i].iov_base = bufs[i].as_mut().as_mut_ptr() as *mut libc::c_void;
        iov[i].iov_len = bufs[i].as_mut().len();
        msgs[i].msg_hdr.msg_iov = &mut iov[i];
        msgs[i].msg_hdr.msg_iovlen = 1;
        // The kernel writes the sender address here and updates msg_namelen.
        msgs[i].msg_hdr.msg_name = &mut addrs[i] as *const _ as *mut libc::c_void;
        msgs[i].msg_hdr.msg_namelen = std::mem::size_of::<sockaddr_storage>() as socklen_t;
    }
    // SAFETY: the first `max_num` headers point into `iov`/`addrs` on this stack frame
    // and into the caller's buffers, all of which outlive the call; no timeout is passed.
    let res = unsafe {
        libc::recvmmsg(
            fd,
            msgs.as_mut_ptr(),
            max_num as c_uint,
            libc::MSG_DONTWAIT as _,
            std::ptr::null_mut(),
        )
    };
    if res == -1 {
        return Err(io::Error::last_os_error());
    }
    let nmsgs = res as usize;
    if nmsgs == 0 {
        return Err(io::Error::from(io::ErrorKind::UnexpectedEof));
    }
    for i in 0..nmsgs {
        let addr = sockaddr_to_socket_addr(&addrs[i], msgs[i].msg_hdr.msg_namelen);
        sizes[i] = msgs[i].msg_len as usize;
        route_keys[i] = RouteKey::new(index, addr);
    }
    Ok(nmsgs)
}
983#[cfg(any(target_os = "linux", target_os = "android"))]
984fn sockaddr_to_socket_addr(addr: &sockaddr_storage, _len: socklen_t) -> SocketAddr {
985 match addr.ss_family as libc::c_int {
986 libc::AF_INET => {
987 let addr_in = unsafe { *(addr as *const _ as *const libc::sockaddr_in) };
988 let ip = u32::from_be(addr_in.sin_addr.s_addr);
989 let port = u16::from_be(addr_in.sin_port);
990 SocketAddr::V4(std::net::SocketAddrV4::new(
991 std::net::Ipv4Addr::from(ip),
992 port,
993 ))
994 }
995 libc::AF_INET6 => {
996 let addr_in6 = unsafe { *(addr as *const _ as *const libc::sockaddr_in6) };
997 let ip = std::net::Ipv6Addr::from(addr_in6.sin6_addr.s6_addr);
998 let port = u16::from_be(addr_in6.sin6_port);
999 SocketAddr::V6(std::net::SocketAddrV6::new(ip, port, 0, 0))
1000 }
1001 _ => panic!("Unsupported address family"),
1002 }
1003}
1004
/// Whether a receive error should be skipped and the receive retried.
///
/// Only Windows' `WSAECONNRESET` qualifies; on every other platform nothing is ignored.
fn should_ignore_error(e: &io::Error) -> bool {
    #[cfg(windows)]
    let ignored = matches!(
        e.raw_os_error(),
        Some(code) if code == windows_sys::Win32::Networking::WinSock::WSAECONNRESET
    );
    #[cfg(not(windows))]
    let ignored = {
        let _ = e;
        false
    };
    ignored
}
1016
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use crate::tunnel::udp::{Model, UdpTunnel};

    /// In low mode only the two main sockets are dispatched.
    #[tokio::test]
    pub async fn create_udp_tunnel() {
        let config = crate::tunnel::config::UdpTunnelConfig::default()
            .set_main_udp_count(2)
            .set_sub_udp_count(10)
            .set_model(Model::Low)
            .set_use_v6(false);
        let mut dispatcher = crate::tunnel::udp::create_tunnel_dispatcher(config).unwrap();
        // Handles are kept so the receive tasks live until the end of the test.
        let mut handles = Vec::new();
        while let Ok(tunnel) =
            tokio::time::timeout(Duration::from_secs(1), dispatcher.dispatch()).await
        {
            handles.push(tokio::spawn(tunnel_recv(tunnel.unwrap())));
        }
        assert_eq!(handles.len(), 2)
    }

    /// In high mode 2 main + 10 sub sockets are dispatched, and switching back to
    /// low mode ends exactly the 10 sub tunnels.
    #[tokio::test]
    pub async fn create_sub_udp_tunnel() {
        let config = crate::tunnel::config::UdpTunnelConfig::default()
            .set_main_udp_count(2)
            .set_sub_udp_count(10)
            .set_use_v6(false)
            .set_model(Model::High);
        let mut dispatcher = crate::tunnel::udp::create_tunnel_dispatcher(config).unwrap();
        let mut handles = Vec::new();
        while let Ok(tunnel) =
            tokio::time::timeout(Duration::from_secs(1), dispatcher.dispatch()).await
        {
            handles.push(tokio::spawn(tunnel_recv(tunnel.unwrap())));
        }
        let dispatched = handles.len();
        dispatcher.manager().switch_low();

        let mut closed = 0;
        for handle in handles {
            // Main tunnels never finish, so their joins simply time out.
            if let Ok(joined) = tokio::time::timeout(Duration::from_secs(1), handle).await {
                if joined.unwrap() {
                    closed += 1;
                }
            }
        }
        assert_eq!(dispatched, 12);
        assert_eq!(closed, 10);
    }

    /// Receives until the tunnel reports it is finished; `true` if it ended with `None`.
    async fn tunnel_recv(mut tunnel: UdpTunnel) -> bool {
        let mut buf = [0; 1400];
        tunnel.recv_from(&mut buf).await.is_none()
    }
}