1use std::io;
2use std::io::IoSlice;
3use std::net::SocketAddr;
4use std::sync::Arc;
5
6#[cfg(any(target_os = "linux", target_os = "android"))]
7use tokio::io::Interest;
8
9#[cfg(any(target_os = "linux", target_os = "android"))]
10pub async fn read_with<R>(udp: &UdpSocket, op: impl FnMut() -> io::Result<R>) -> io::Result<R> {
11 udp.async_io(Interest::READABLE, op).await
12}
13#[cfg(any(target_os = "linux", target_os = "android"))]
14pub async fn write_with<R>(udp: &UdpSocket, op: impl FnMut() -> io::Result<R>) -> io::Result<R> {
15 udp.async_io(Interest::WRITABLE, op).await
16}
17
18use bytes::BytesMut;
19use dashmap::DashMap;
20use parking_lot::{Mutex, RwLock};
21use tachyonix::{Receiver, Sender, TrySendError};
22use tokio::net::UdpSocket;
23
24use crate::route::{Index, RouteKey};
25use crate::socket::{bind_udp, LocalInterface};
26use crate::tunnel::config::UdpTunnelConfig;
27use crate::tunnel::recycle::RecycleBuf;
28use crate::tunnel::{DEFAULT_ADDRESS_V4, DEFAULT_ADDRESS_V6};
29
#[cfg(any(target_os = "linux", target_os = "android"))]
/// Upper bound on datagrams handled by one `sendmmsg`/`recvmmsg` call; it also
/// sizes the fixed stack arrays those helpers build and caps the send batch.
const MAX_MESSAGES: usize = 16;
32#[cfg(any(target_os = "linux", target_os = "android"))]
33use libc::{c_uint, iovec, mmsghdr, sockaddr_storage, socklen_t};
34#[cfg(any(target_os = "linux", target_os = "android"))]
35use std::os::fd::AsRawFd;
36
/// Socket usage model of the UDP socket manager.
///
/// `Low` (the default) uses only the main sockets; `High` additionally binds
/// the extra IPv4 "sub" sockets.
#[derive(Debug, PartialEq, Eq, Clone, Copy, Default)]
pub enum Model {
    High,
    #[default]
    Low,
}

impl Model {
    /// `true` exactly for `Model::Low`.
    pub fn is_low(&self) -> bool {
        matches!(self, Model::Low)
    }
    /// `true` exactly for `Model::High`.
    pub fn is_high(&self) -> bool {
        matches!(self, Model::High)
    }
}
52
/// Identifies one manager-owned UDP socket by its list (main IPv4, main IPv6,
/// or sub IPv4) and its position within that list.
#[derive(Copy, Clone, Ord, PartialOrd, Eq, PartialEq, Hash, Debug)]
pub enum UDPIndex {
    MainV4(usize),
    MainV6(usize),
    SubV4(usize),
}

impl UDPIndex {
    /// Position of the socket within its own list, whichever list that is.
    pub(crate) fn index(&self) -> usize {
        match *self {
            UDPIndex::MainV4(position) | UDPIndex::MainV6(position) | UDPIndex::SubV4(position) => {
                position
            }
        }
    }
}
69
70pub trait ToRouteKeyForUdp<T> {
71 fn route_key(socket_manager: &UdpSocketManager, dest: Self) -> io::Result<RouteKey>;
72}
73
74impl ToRouteKeyForUdp<()> for RouteKey {
75 fn route_key(_: &UdpSocketManager, dest: Self) -> io::Result<RouteKey> {
76 Ok(dest)
77 }
78}
79
80impl ToRouteKeyForUdp<()> for &RouteKey {
81 fn route_key(_: &UdpSocketManager, dest: Self) -> io::Result<RouteKey> {
82 Ok(*dest)
83 }
84}
85
86impl ToRouteKeyForUdp<()> for &mut RouteKey {
87 fn route_key(_: &UdpSocketManager, dest: Self) -> io::Result<RouteKey> {
88 Ok(*dest)
89 }
90}
91
92impl<S: Into<SocketAddr>> ToRouteKeyForUdp<()> for S {
93 fn route_key(socket_manager: &UdpSocketManager, dest: Self) -> io::Result<RouteKey> {
94 let addr = dest.into();
95 socket_manager.generate_route_key_from_addr(0, addr)
96 }
97}
98
99impl<S: Into<SocketAddr>> ToRouteKeyForUdp<usize> for (usize, S) {
100 fn route_key(socket_manager: &UdpSocketManager, dest: Self) -> io::Result<RouteKey> {
101 let (index, addr) = dest;
102 socket_manager.generate_route_key_from_addr(index, addr.into())
103 }
104}
105
106pub(crate) fn create_tunnel_dispatcher(config: UdpTunnelConfig) -> io::Result<UdpTunnelDispatcher> {
108 config.check()?;
109 let mut udp_ports = config.udp_ports;
110 udp_ports.resize(config.main_udp_count, 0);
111 let mut main_udp_v4: Vec<Arc<UdpSocket>> = Vec::with_capacity(config.main_udp_count);
112 let mut main_udp_v6: Vec<Arc<UdpSocket>> = Vec::with_capacity(config.main_udp_count);
113 for port in &udp_ports {
115 loop {
116 let mut addr_v4 = DEFAULT_ADDRESS_V4;
117 addr_v4.set_port(*port);
118 let socket_v4 = bind_udp(addr_v4, config.default_interface.as_ref())?;
119 let udp_v4: std::net::UdpSocket = socket_v4.into();
120 if config.use_v6 {
121 let mut addr_v6 = DEFAULT_ADDRESS_V6;
122 let socket_v6 = if *port == 0 {
123 let port = udp_v4.local_addr()?.port();
124 addr_v6.set_port(port);
125 match bind_udp(addr_v6, config.default_interface.as_ref()) {
126 Ok(socket_v6) => socket_v6,
127 Err(_) => continue,
128 }
129 } else {
130 addr_v6.set_port(*port);
131 bind_udp(addr_v6, config.default_interface.as_ref())?
132 };
133 let udp_v6: std::net::UdpSocket = socket_v6.into();
134 main_udp_v6.push(Arc::new(UdpSocket::from_std(udp_v6)?))
135 }
136 main_udp_v4.push(Arc::new(UdpSocket::from_std(udp_v4)?));
137 break;
138 }
139 }
140 let (tunnel_sender, tunnel_receiver) =
141 tachyonix::channel(config.main_udp_count * 2 + config.sub_udp_count * 2);
142 let socket_manager = Arc::new(UdpSocketManager {
143 main_udp_v4,
144 main_udp_v6,
145 sub_udp: RwLock::new(Vec::with_capacity(config.sub_udp_count)),
146 sub_close_notify: Default::default(),
147 tunnel_dispatcher: tunnel_sender,
148 sub_udp_num: config.sub_udp_count,
149 default_interface: config.default_interface,
150 sender_map: Default::default(),
151 });
152 let tunnel_factory = UdpTunnelDispatcher {
153 tunnel_receiver,
154 socket_manager,
155 recycle_buf: config.recycle_buf,
156 };
157 tunnel_factory.init()?;
158 tunnel_factory.socket_manager.switch_model(config.model)?;
159 Ok(tunnel_factory)
160}
161
/// Owns every UDP socket of the tunnel layer and provides the send APIs.
pub struct UdpSocketManager {
    /// Main IPv4 sockets, one per configured port.
    main_udp_v4: Vec<Arc<UdpSocket>>,
    /// Main IPv6 sockets; empty when IPv6 is disabled (`use_v6 == false`).
    main_udp_v6: Vec<Arc<UdpSocket>>,
    /// Extra IPv4 sockets, bound by `switch_high` and cleared by `switch_low`;
    /// empty means `Model::Low`.
    sub_udp: RwLock<Vec<Arc<UdpSocket>>>,
    /// Broadcast sender whose closing tells the current sub-socket tunnels to stop.
    sub_close_notify: Mutex<Option<async_broadcast::Sender<()>>>,
    /// Queue of inactive tunnels consumed by `UdpTunnelDispatcher::dispatch`.
    tunnel_dispatcher: Sender<InactiveUdpTunnel>,
    /// Number of sub sockets bound when switching to `Model::High`.
    sub_udp_num: usize,
    /// Interface passed to `bind_udp` for every socket created here.
    default_interface: Option<LocalInterface>,
    /// Per-socket send queues drained by the background send tasks spawned in `dispatch`.
    sender_map: DashMap<Index, Sender<(BytesMut, SocketAddr)>>,
}
172
173impl UdpSocketManager {
174 pub(crate) fn try_sub_batch_send_to(&self, buf: &[u8], addr: SocketAddr) {
175 for (i, udp) in self.sub_udp.read().iter().enumerate() {
176 if let Err(e) = udp.try_send_to(buf, addr) {
177 log::info!("try_sub_send_to_addr_v4: {e:?},{i},{addr}")
178 }
179 }
180 }
181 pub(crate) fn try_main_v4_batch_send_to(&self, buf: &[u8], addr: &[SocketAddr]) {
182 let len = self.main_udp_v4_count();
183 self.try_main_batch_send_to_impl(buf, addr, len);
184 }
185 pub(crate) fn try_main_v6_batch_send_to(&self, buf: &[u8], addr: &[SocketAddr]) {
186 let len = self.main_udp_v6_count();
187 self.try_main_batch_send_to_impl(buf, addr, len);
188 }
189
190 pub(crate) fn try_main_batch_send_to_impl(&self, buf: &[u8], addr: &[SocketAddr], len: usize) {
191 for (i, addr) in addr.iter().enumerate() {
192 if let Err(e) = self.try_send_to(buf, (i % len, *addr)) {
193 log::info!("try_main_send_to_addr: {e:?},{},{addr}", i % len);
194 }
195 }
196 }
197 pub(crate) fn generate_route_key_from_addr(
198 &self,
199 index: usize,
200 addr: SocketAddr,
201 ) -> io::Result<RouteKey> {
202 let route_key = if addr.is_ipv4() {
203 let len = self.main_udp_v4.len();
204 if index >= len {
205 return Err(io::Error::new(io::ErrorKind::Other, "index out of bounds"));
206 }
207 RouteKey::new(Index::Udp(UDPIndex::MainV4(index)), addr)
208 } else {
209 let len = self.main_udp_v6.len();
210 if len == 0 {
211 return Err(io::Error::new(io::ErrorKind::Other, "Not support IPV6"));
212 }
213 if index >= len {
214 return Err(io::Error::new(io::ErrorKind::Other, "index out of bounds"));
215 }
216 RouteKey::new(Index::Udp(UDPIndex::MainV6(index)), addr)
217 };
218 Ok(route_key)
219 }
220 pub(crate) fn switch_low(&self) {
221 let mut guard = self.sub_udp.write();
222 if guard.is_empty() {
223 return;
224 }
225 guard.clear();
226 if let Some(sub_close_notify) = self.sub_close_notify.lock().take() {
227 let _ = sub_close_notify.close();
228 }
229 }
230 pub(crate) fn switch_high(&self) -> io::Result<()> {
231 let mut guard = self.sub_udp.write();
232 if !guard.is_empty() {
233 return Ok(());
234 }
235 let mut sub_close_notify_guard = self.sub_close_notify.lock();
236 if let Some(sender) = sub_close_notify_guard.take() {
237 let _ = sender.close();
238 }
239 let (sub_close_notify_sender, sub_close_notify_receiver) = async_broadcast::broadcast(2);
240 let mut sub_udp_list = Vec::with_capacity(self.sub_udp_num);
241 for _ in 0..self.sub_udp_num {
242 let udp = bind_udp(DEFAULT_ADDRESS_V4, self.default_interface.as_ref())?;
243 let udp: std::net::UdpSocket = udp.into();
244 sub_udp_list.push(Arc::new(UdpSocket::from_std(udp)?));
245 }
246 for (index, udp) in sub_udp_list.iter().enumerate() {
247 let udp = udp.clone();
248 let udp_tunnel = InactiveUdpTunnel::new(
249 false,
250 Index::Udp(UDPIndex::SubV4(index)),
251 udp,
252 Some(sub_close_notify_receiver.clone()),
253 );
254 if self.tunnel_dispatcher.try_send(udp_tunnel).is_err() {
255 Err(io::Error::new(io::ErrorKind::Other, "tunnel channel error"))?
256 }
257 }
258 sub_close_notify_guard.replace(sub_close_notify_sender);
259 *guard = sub_udp_list;
260 Ok(())
261 }
262
263 #[inline]
264 fn get_udp(&self, udp_index: UDPIndex) -> io::Result<Arc<UdpSocket>> {
265 Ok(match udp_index {
266 UDPIndex::MainV4(index) => self
267 .main_udp_v4
268 .get(index)
269 .ok_or(io::Error::new(io::ErrorKind::Other, "index out of bounds"))?
270 .clone(),
271 UDPIndex::MainV6(index) => self
272 .main_udp_v6
273 .get(index)
274 .ok_or(io::Error::new(io::ErrorKind::Other, "index out of bounds"))?
275 .clone(),
276 UDPIndex::SubV4(index) => {
277 let guard = self.sub_udp.read();
278 let len = guard.len();
279 if len <= index {
280 return Err(io::Error::new(io::ErrorKind::Other, "index out of bounds"));
281 } else {
282 guard[index].clone()
283 }
284 }
285 })
286 }
287
288 #[inline]
289 fn get_udp_from_route(&self, route_key: &RouteKey) -> io::Result<Arc<UdpSocket>> {
290 Ok(match route_key.index() {
291 Index::Udp(index) => self.get_udp(index)?,
292 _ => return Err(io::Error::from(io::ErrorKind::InvalidInput)),
293 })
294 }
295}
296
297impl UdpSocketManager {
298 pub fn model(&self) -> Model {
299 if self.sub_udp.read().is_empty() {
300 Model::Low
301 } else {
302 Model::High
303 }
304 }
305
306 #[inline]
307 pub fn main_udp_v4_count(&self) -> usize {
308 self.main_udp_v4.len()
309 }
310 #[inline]
311 pub fn main_udp_v6_count(&self) -> usize {
312 self.main_udp_v6.len()
313 }
314
315 pub fn switch_model(&self, model: Model) -> io::Result<()> {
316 match model {
317 Model::High => self.switch_high(),
318 Model::Low => {
319 self.switch_low();
320 Ok(())
321 }
322 }
323 }
324 pub fn local_ports(&self) -> io::Result<Vec<u16>> {
326 let mut ports = Vec::with_capacity(self.main_udp_v4_count());
327 for udp in &self.main_udp_v4 {
328 ports.push(udp.local_addr()?.port());
329 }
330 Ok(ports)
331 }
332 pub async fn send_to<T, D: ToRouteKeyForUdp<T>>(&self, buf: &[u8], dest: D) -> io::Result<()> {
334 let route_key = ToRouteKeyForUdp::route_key(self, dest)?;
335 let len = self
336 .get_udp_from_route(&route_key)?
337 .send_to(buf, route_key.addr())
338 .await?;
339 if len == 0 {
340 return Err(std::io::Error::from(io::ErrorKind::WriteZero));
341 }
342 Ok(())
343 }
344
345 pub fn try_send_to<T, D: ToRouteKeyForUdp<T>>(&self, buf: &[u8], dest: D) -> io::Result<()> {
347 let route_key = ToRouteKeyForUdp::route_key(self, dest)?;
348 let len = self
349 .get_udp_from_route(&route_key)?
350 .try_send_to(buf, route_key.addr())?;
351 if len == 0 {
352 return Err(std::io::Error::from(io::ErrorKind::WriteZero));
353 }
354 Ok(())
355 }
356
357 pub async fn batch_send_to<T, D: ToRouteKeyForUdp<T>>(
358 &self,
359 bufs: &[IoSlice<'_>],
360 dest: D,
361 ) -> io::Result<()> {
362 let route_key = ToRouteKeyForUdp::route_key(self, dest)?;
363 let udp = self.get_udp_from_route(&route_key)?;
364 for buf in bufs {
365 let len = udp.send_to(buf, route_key.addr()).await?;
366 if len == 0 {
367 return Err(std::io::Error::from(io::ErrorKind::WriteZero));
368 }
369 }
370
371 Ok(())
372 }
373 fn get_sender(&self, route_key: &RouteKey) -> io::Result<Sender<(BytesMut, SocketAddr)>> {
374 if let Some(sender) = self.sender_map.get(&route_key.index()) {
375 Ok(sender.value().clone())
376 } else {
377 Err(io::Error::new(io::ErrorKind::NotFound, "route not found"))
378 }
379 }
380 pub async fn send_bytes_to<T, D: ToRouteKeyForUdp<T>>(
381 &self,
382 buf: BytesMut,
383 dest: D,
384 ) -> io::Result<()> {
385 let route_key = ToRouteKeyForUdp::route_key(self, dest)?;
386 let sender = self.get_sender(&route_key)?;
387 if let Err(_e) = sender.send((buf, route_key.addr())).await {
388 Err(io::Error::from(io::ErrorKind::WriteZero))
389 } else {
390 Ok(())
391 }
392 }
393 pub fn try_send_bytes_to<T, D: ToRouteKeyForUdp<T>>(
394 &self,
395 buf: BytesMut,
396 dest: D,
397 ) -> io::Result<()> {
398 let route_key = ToRouteKeyForUdp::route_key(self, dest)?;
399 let sender = self.get_sender(&route_key)?;
400 if let Err(e) = sender.try_send((buf, route_key.addr())) {
401 match e {
402 TrySendError::Full(_) => Err(io::Error::from(io::ErrorKind::WouldBlock)),
403 TrySendError::Closed(_) => Err(io::Error::from(io::ErrorKind::WriteZero)),
404 }
405 } else {
406 Ok(())
407 }
408 }
409
410 pub async fn detect_pub_addrs<A: Into<SocketAddr>>(
412 &self,
413 buf: &[u8],
414 addr: A,
415 ) -> io::Result<()> {
416 let addr: SocketAddr = addr.into();
417 for index in 0..self.main_udp_v4_count() {
418 self.send_to(buf, (index, addr)).await?
419 }
420 Ok(())
421 }
422}
423
/// Hands out `UdpTunnel`s for the sockets of a `UdpSocketManager`.
pub struct UdpTunnelDispatcher {
    /// Tunnels queued by `init`, `switch_high`, and `UdpTunnel`'s drop (re-dispatch).
    tunnel_receiver: Receiver<InactiveUdpTunnel>,
    pub(crate) socket_manager: Arc<UdpSocketManager>,
    /// Optional pool that sent `BytesMut` buffers are pushed back into.
    recycle_buf: Option<RecycleBuf>,
}
429
430impl UdpTunnelDispatcher {
431 pub(crate) fn init(&self) -> io::Result<()> {
432 for (index, udp) in self.socket_manager.main_udp_v4.iter().enumerate() {
433 let udp = udp.clone();
434 let tunnel =
435 InactiveUdpTunnel::new(true, Index::Udp(UDPIndex::MainV4(index)), udp, None);
436 if self
437 .socket_manager
438 .tunnel_dispatcher
439 .try_send(tunnel)
440 .is_err()
441 {
442 Err(io::Error::new(io::ErrorKind::Other, "tunnel channel error"))?
443 }
444 }
445 for (index, udp) in self.socket_manager.main_udp_v6.iter().enumerate() {
446 let udp = udp.clone();
447 let tunnel =
448 InactiveUdpTunnel::new(true, Index::Udp(UDPIndex::MainV6(index)), udp, None);
449 if self
450 .socket_manager
451 .tunnel_dispatcher
452 .try_send(tunnel)
453 .is_err()
454 {
455 Err(io::Error::new(io::ErrorKind::Other, "tunnel channel error"))?
456 }
457 }
458 Ok(())
459 }
460}
461
462impl UdpTunnelDispatcher {
463 pub fn new(config: UdpTunnelConfig) -> io::Result<UdpTunnelDispatcher> {
465 create_tunnel_dispatcher(config)
466 }
467 pub async fn dispatch(&mut self) -> io::Result<UdpTunnel> {
469 let mut udp_tunnel = self
470 .tunnel_receiver
471 .recv()
472 .await
473 .map_err(|_| io::Error::new(io::ErrorKind::Other, "Udp tunnel close"))?;
474 let option = self
475 .socket_manager
476 .sender_map
477 .get(&udp_tunnel.index)
478 .map(|v| v.value().clone());
479 let sender = if let Some(v) = option {
480 v
481 } else {
482 let (s, mut r) = tachyonix::channel(128);
483 let index = udp_tunnel.index;
484 let sender = s.clone();
485 self.socket_manager.sender_map.insert(index, s);
486
487 let socket_manager = self.socket_manager.clone();
488 let udp = udp_tunnel.udp.clone();
489 let recycle_buf = self.recycle_buf.clone();
490 tokio::spawn(async move {
491 #[cfg(any(target_os = "linux", target_os = "android"))]
492 let mut vec_buf = Vec::with_capacity(16);
493
494 while let Ok((buf, addr)) = r.recv().await {
495 #[cfg(any(target_os = "linux", target_os = "android"))]
496 {
497 vec_buf.push((buf, addr));
498 while let Ok(tup) = r.try_recv() {
499 vec_buf.push(tup);
500 if vec_buf.len() == MAX_MESSAGES {
501 break;
502 }
503 }
504 let mut bufs = &mut vec_buf[..];
505 let fd = udp.as_raw_fd();
506 loop {
507 if bufs.len() == 1 {
508 let (buf, addr) = unsafe { bufs.get_unchecked(0) };
509 if let Err(e) = udp.send_to(buf, *addr).await {
510 log::warn!("send_to {addr:?},{e:?}")
511 }
512 break;
513 } else {
514 let rs = write_with(&udp, || sendmmsg(fd, bufs)).await;
515 match rs {
516 Ok(size) => {
517 if size < bufs.len() {
518 bufs = &mut bufs[size..];
519 continue;
520 }
521 break;
522 }
523 Err(e) => {
524 log::warn!("sendmmsg {e:?}");
525 }
526 }
527 }
528 }
529 if let Some(recycle_buf) = recycle_buf.as_ref() {
530 while let Some((buf, _)) = vec_buf.pop() {
531 recycle_buf.push(buf);
532 }
533 } else {
534 vec_buf.clear();
535 }
536 }
537 #[cfg(not(any(target_os = "linux", target_os = "android")))]
538 {
539 let rs = udp.send_to(&buf, addr).await;
540 if let Some(recycle_buf) = recycle_buf.as_ref() {
541 recycle_buf.push(buf);
542 }
543 if let Err(e) = rs {
544 log::debug!("{addr:?},{e:?}")
545 }
546 }
547 }
548 socket_manager.sender_map.remove(&index);
549 });
550 sender
551 };
552 if udp_tunnel.sender.is_none() {
553 udp_tunnel.sender.replace(OwnedUdpTunnelSender { sender });
554 }
555 if udp_tunnel.reusable {
556 UdpTunnel::with_main(udp_tunnel, self.manager().tunnel_dispatcher.clone())
557 } else {
558 UdpTunnel::with_sub(udp_tunnel)
559 }
560 }
561 pub fn manager(&self) -> &Arc<UdpSocketManager> {
562 &self.socket_manager
563 }
564}
565
566#[cfg(any(target_os = "linux", target_os = "android"))]
567fn sendmmsg(fd: std::os::fd::RawFd, bufs: &mut [(BytesMut, SocketAddr)]) -> io::Result<usize> {
568 assert!(bufs.len() <= MAX_MESSAGES);
569 let mut iov: [iovec; MAX_MESSAGES] = unsafe { std::mem::zeroed() };
570 let mut msgs: [mmsghdr; MAX_MESSAGES] = unsafe { std::mem::zeroed() };
571 let mut addrs: [sockaddr_storage; MAX_MESSAGES] = unsafe { std::mem::zeroed() };
572 for (i, (buf, addr)) in bufs.iter_mut().enumerate() {
573 addrs[i] = socket_addr_to_sockaddr(addr);
574 iov[i].iov_base = buf.as_mut_ptr() as *mut libc::c_void;
575 iov[i].iov_len = buf.len();
576 msgs[i].msg_hdr.msg_iov = &mut iov[i];
577 msgs[i].msg_hdr.msg_iovlen = 1;
578
579 msgs[i].msg_hdr.msg_name = &mut addrs[i] as *mut _ as *mut libc::c_void;
580 msgs[i].msg_hdr.msg_namelen = std::mem::size_of::<sockaddr_storage>() as socklen_t;
581 }
582
583 unsafe {
584 let res = libc::sendmmsg(
585 fd,
586 msgs.as_mut_ptr(),
587 bufs.len() as _,
588 libc::MSG_DONTWAIT as _,
589 );
590 if res == -1 {
591 return Err(io::Error::last_os_error());
592 }
593 Ok(res as usize)
594 }
595}
596
597#[cfg(any(target_os = "linux", target_os = "android"))]
598fn socket_addr_to_sockaddr(addr: &SocketAddr) -> sockaddr_storage {
599 let mut storage: sockaddr_storage = unsafe { std::mem::zeroed() };
600
601 match addr {
602 SocketAddr::V4(v4_addr) => {
603 let sin = libc::sockaddr_in {
604 sin_family: libc::AF_INET as _,
605 sin_port: v4_addr.port().to_be(),
606 sin_addr: libc::in_addr {
607 s_addr: u32::from_ne_bytes(v4_addr.ip().octets()), },
609 sin_zero: [0; 8],
610 };
611
612 unsafe {
613 let sin_ptr = &sin as *const libc::sockaddr_in as *const u8;
614 let storage_ptr = &mut storage as *mut sockaddr_storage as *mut u8;
615 std::ptr::copy_nonoverlapping(
616 sin_ptr,
617 storage_ptr,
618 std::mem::size_of::<libc::sockaddr>(),
619 );
620 }
621 }
622 SocketAddr::V6(v6_addr) => {
623 let sin6 = libc::sockaddr_in6 {
624 sin6_family: libc::AF_INET6 as _,
625 sin6_port: v6_addr.port().to_be(),
626 sin6_flowinfo: v6_addr.flowinfo(),
627 sin6_addr: libc::in6_addr {
628 s6_addr: v6_addr.ip().octets(),
629 },
630 sin6_scope_id: v6_addr.scope_id(),
631 };
632
633 unsafe {
634 let sin6_ptr = &sin6 as *const libc::sockaddr_in6 as *const u8;
635 let storage_ptr = &mut storage as *mut sockaddr_storage as *mut u8;
636 std::ptr::copy_nonoverlapping(
637 sin6_ptr,
638 storage_ptr,
639 std::mem::size_of::<libc::sockaddr>(),
640 );
641 }
642 }
643 }
644 storage
645}
646
/// An active tunnel over one UDP socket, produced by `UdpTunnelDispatcher::dispatch`.
pub struct UdpTunnel {
    /// Which manager socket this is; stamped into every received `RouteKey`.
    index: Index,
    /// Local address of `udp`, captured when the tunnel was activated.
    local_addr: SocketAddr,
    /// The socket; `None` after `done()`.
    udp: Option<Arc<UdpSocket>>,
    /// Close signal for sub-socket tunnels; `None` for main tunnels.
    close_notify: Option<async_broadcast::Receiver<()>>,
    /// Queue used on drop to re-submit a main socket; `None` for sub tunnels.
    re_dispatcher: Option<Sender<InactiveUdpTunnel>>,
    /// Handle on the socket's background send queue.
    sender: Option<OwnedUdpTunnelSender>,
}
/// Owning handle on a socket's send queue; its `Drop` impl closes the queue.
struct OwnedUdpTunnelSender {
    sender: Sender<(BytesMut, SocketAddr)>,
}
/// Cloneable handle on a socket's send queue; unlike `OwnedUdpTunnelSender`
/// it has no `Drop` impl that closes the queue.
#[derive(Clone)]
pub struct WeakUdpTunnelSender {
    sender: Sender<(BytesMut, SocketAddr)>,
}
/// A socket waiting in the tunnel queue to be activated by `dispatch`.
struct InactiveUdpTunnel {
    /// `true` for main sockets, whose tunnels re-queue the socket when dropped.
    reusable: bool,
    index: Index,
    udp: Arc<UdpSocket>,
    close_notify: Option<async_broadcast::Receiver<()>>,
    /// Send queue carried over from a previous tunnel on re-dispatch, if any.
    sender: Option<OwnedUdpTunnelSender>,
}
669impl InactiveUdpTunnel {
670 fn new(
671 reusable: bool,
672 index: Index,
673 udp: Arc<UdpSocket>,
674 close_notify: Option<async_broadcast::Receiver<()>>,
675 ) -> Self {
676 Self {
677 reusable,
678 index,
679 udp,
680 close_notify,
681 sender: None,
682 }
683 }
684 fn redistribute(index: Index, udp: Arc<UdpSocket>, sender: OwnedUdpTunnelSender) -> Self {
685 Self {
686 reusable: true,
687 index,
688 udp,
689 close_notify: None,
690 sender: Some(sender),
691 }
692 }
693}
694impl OwnedUdpTunnelSender {
695 async fn send_to<A: Into<SocketAddr>>(&self, buf: BytesMut, dest: A) -> io::Result<()> {
696 if buf.is_empty() {
697 return Ok(());
698 }
699 self.sender
700 .send((buf, dest.into()))
701 .await
702 .map_err(|_| io::Error::from(io::ErrorKind::WriteZero))
703 }
704 fn is_closed(&self) -> bool {
705 self.sender.is_closed()
706 }
707}
708impl WeakUdpTunnelSender {
709 pub async fn send_to<A: Into<SocketAddr>>(&self, buf: BytesMut, dest: A) -> io::Result<()> {
710 if buf.is_empty() {
711 return Ok(());
712 }
713 self.sender
714 .send((buf, dest.into()))
715 .await
716 .map_err(|_| io::Error::from(io::ErrorKind::WriteZero))
717 }
718 pub fn try_send_to<A: Into<SocketAddr>>(&self, buf: BytesMut, dest: A) -> io::Result<()> {
719 if buf.is_empty() {
720 return Ok(());
721 }
722 self.sender
723 .try_send((buf, dest.into()))
724 .map_err(|e| match e {
725 TrySendError::Full(_) => io::Error::from(io::ErrorKind::WouldBlock),
726 TrySendError::Closed(_) => io::Error::from(io::ErrorKind::WriteZero),
727 })
728 }
729}
/// Explicitly closes the send queue on drop, which closes it for every clone
/// of the sender (including the one kept in `sender_map`). The socket's
/// background send loop then ends after draining what was queued.
impl Drop for OwnedUdpTunnelSender {
    fn drop(&mut self) {
        self.sender.close();
    }
}
735impl Drop for UdpTunnel {
736 fn drop(&mut self) {
737 let Some(sender) = self.sender.take() else {
738 return;
739 };
740 if sender.is_closed() {
741 return;
742 }
743 let Some(udp) = self.udp.take() else {
744 return;
745 };
746 let Some(re_dispatcher) = self.re_dispatcher.take() else {
747 return;
748 };
749 let rs = re_dispatcher.try_send(InactiveUdpTunnel::redistribute(self.index, udp, sender));
750 if let Err(TrySendError::Full(_)) = rs {
751 log::warn!("Udp Tunnel TrySendError full");
752 }
753 }
754}
755
756impl UdpTunnel {
757 fn with_sub(inactive_udp_tunnel: InactiveUdpTunnel) -> io::Result<Self> {
758 let local_addr = inactive_udp_tunnel.udp.local_addr()?;
759 Ok(Self {
760 index: inactive_udp_tunnel.index,
761 local_addr,
762 udp: Some(inactive_udp_tunnel.udp),
763 close_notify: inactive_udp_tunnel.close_notify,
764 re_dispatcher: None,
765 sender: inactive_udp_tunnel.sender,
766 })
767 }
768 fn with_main(
769 inactive_udp_tunnel: InactiveUdpTunnel,
770 re_sender: Sender<InactiveUdpTunnel>,
771 ) -> io::Result<Self> {
772 let local_addr = inactive_udp_tunnel.udp.local_addr()?;
773 Ok(Self {
774 local_addr,
775 index: inactive_udp_tunnel.index,
776 udp: Some(inactive_udp_tunnel.udp),
777 close_notify: None,
778 re_dispatcher: Some(re_sender),
779 sender: inactive_udp_tunnel.sender,
780 })
781 }
782 pub fn done(&mut self) {
783 _ = self.udp.take();
784 _ = self.close_notify.take();
785 _ = self.re_dispatcher.take();
786 _ = self.re_dispatcher.take();
787 _ = self.sender.take();
788 }
789 pub fn local_addr(&self) -> SocketAddr {
790 self.local_addr
791 }
792 pub fn sender(&self) -> io::Result<WeakUdpTunnelSender> {
793 if let Some(v) = &self.sender {
794 Ok(WeakUdpTunnelSender {
795 sender: v.sender.clone(),
796 })
797 } else {
798 Err(io::Error::new(io::ErrorKind::Other, "closed"))
799 }
800 }
801}
802
impl UdpTunnel {
    /// Sends `buf` directly on this tunnel's socket, bypassing the send queue.
    ///
    /// # Errors
    /// Returns socket errors, or "closed" after `done()`.
    pub async fn send_to<A: Into<SocketAddr>>(&self, buf: &[u8], addr: A) -> io::Result<()> {
        if let Some(udp) = &self.udp {
            udp.send_to(buf, addr.into()).await?;
            Ok(())
        } else {
            Err(io::Error::new(io::ErrorKind::Other, "closed"))
        }
    }
    /// Non-blocking variant of `send_to`; fails with "closed" after `done()`.
    pub fn try_send_to<A: Into<SocketAddr>>(&self, buf: &[u8], addr: A) -> io::Result<()> {
        if let Some(udp) = &self.udp {
            udp.try_send_to(buf, addr.into())?;
            Ok(())
        } else {
            Err(io::Error::new(io::ErrorKind::Other, "closed"))
        }
    }
    /// Queues `buf` on the socket's background send queue (empty buffers are
    /// skipped by `OwnedUdpTunnelSender::send_to`); fails with "closed" after `done()`.
    pub async fn send_bytes_to<A: Into<SocketAddr>>(
        &self,
        buf: BytesMut,
        addr: A,
    ) -> io::Result<()> {
        if let Some(sender) = &self.sender {
            sender.send_to(buf, addr).await
        } else {
            Err(io::Error::new(io::ErrorKind::Other, "closed"))
        }
    }

    /// Receives one datagram into `buf`, returning its length and a
    /// `RouteKey` that pairs this tunnel's index with the sender address.
    ///
    /// Returns `None` once the tunnel is closed: either after `done()`, or
    /// when the sub-socket close signal fires (which calls `done()`). Errors
    /// accepted by `should_ignore_error` are skipped.
    pub async fn recv_from(&mut self, buf: &mut [u8]) -> Option<io::Result<(usize, RouteKey)>> {
        let udp = if let Some(udp) = &self.udp {
            udp
        } else {
            return None;
        };
        loop {
            if let Some(close_notify) = &mut self.close_notify {
                tokio::select! {
                    // Completes both on a message and when the broadcast
                    // channel is closed; either way the tunnel shuts down.
                    _rs=close_notify.recv()=>{
                        self.done();
                        return None
                    }
                    result=udp.recv_from(buf)=>{
                        let (len, addr) = match result {
                            Ok(rs) => rs,
                            Err(e) => {
                                if should_ignore_error(&e) {
                                    continue;
                                }
                                return Some(Err(e))
                            }
                        };
                        return Some(Ok((len, RouteKey::new(self.index, addr))))
                    }
                }
            } else {
                let (len, addr) = match udp.recv_from(buf).await {
                    Ok(rs) => rs,
                    Err(e) => {
                        if should_ignore_error(&e) {
                            continue;
                        }
                        return Some(Err(e));
                    }
                };
                return Some(Ok((len, RouteKey::new(self.index, addr))));
            }
        }
    }
    /// Receives up to `bufs.len()` datagrams.
    ///
    /// Waits for the first datagram via `recv_from`, then pulls whatever is
    /// already queued with non-blocking `try_recv_from`. For each
    /// `i < returned count`, `sizes[i]` / `addrs[i]` describe `bufs[i]`.
    /// `bufs`, `sizes` and `addrs` must be non-empty and of equal length.
    /// Returns `None` when the tunnel is closed.
    #[cfg(not(any(target_os = "linux", target_os = "android")))]
    pub async fn batch_recv_from<B: AsMut<[u8]>>(
        &mut self,
        bufs: &mut [B],
        sizes: &mut [usize],
        addrs: &mut [RouteKey],
    ) -> Option<io::Result<usize>> {
        if bufs.is_empty() || bufs.len() != sizes.len() || bufs.len() != addrs.len() {
            return Some(Err(io::Error::new(io::ErrorKind::Other, "bufs error")));
        }
        let rs = self.recv_from(bufs[0].as_mut()).await?;
        match rs {
            Ok((len, addr)) => {
                let udp = self.udp.as_ref()?;
                sizes[0] = len;
                addrs[0] = addr;
                let mut num = 1;
                while num < bufs.len() {
                    match udp.try_recv_from(bufs[num].as_mut()) {
                        Ok((len, addr)) => {
                            sizes[num] = len;
                            addrs[num] = RouteKey::new(self.index, addr);
                            num += 1;
                        }
                        // Nothing more queued (or an error): return what we have.
                        Err(_) => break,
                    }
                }
                Some(Ok(num))
            }
            Err(e) => Some(Err(e)),
        }
    }
    /// Linux/Android variant: receives up to `min(bufs.len(), MAX_MESSAGES)`
    /// datagrams with one `recvmmsg` call per readiness event.
    ///
    /// For each `i < returned count`, `sizes[i]` / `addrs[i]` describe
    /// `bufs[i]`. `bufs`, `sizes` and `addrs` must be non-empty and of equal
    /// length. Returns `None` when the tunnel is closed.
    #[cfg(any(target_os = "linux", target_os = "android"))]
    pub async fn batch_recv_from<B: AsMut<[u8]>>(
        &mut self,
        bufs: &mut [B],
        sizes: &mut [usize],
        addrs: &mut [RouteKey],
    ) -> Option<io::Result<usize>> {
        if bufs.is_empty() || bufs.len() != sizes.len() || bufs.len() != addrs.len() {
            return Some(Err(io::Error::new(
                io::ErrorKind::Other,
                "bufs/sizes/addrs error",
            )));
        }
        let udp = self.udp.as_ref()?;
        let fd = udp.as_raw_fd();
        loop {
            let rs = if let Some(close_notify) = &mut self.close_notify {
                tokio::select! {
                    _rs=close_notify.recv()=>{
                        self.done();
                        return None
                    }
                    // `read_with` re-runs recvmmsg after each readiness event
                    // while it reports WouldBlock.
                    rs=read_with(udp,|| recvmmsg(self.index, fd, bufs, sizes, addrs))=>{
                        rs
                    }
                }
            } else {
                read_with(udp, || recvmmsg(self.index, fd, bufs, sizes, addrs)).await
            };
            return match rs {
                Ok(size) => Some(Ok(size)),
                Err(e) => {
                    if should_ignore_error(&e) {
                        continue;
                    }
                    Some(Err(e))
                }
            };
        }
    }
}
950
951#[cfg(any(target_os = "linux", target_os = "android"))]
952fn recvmmsg<B: AsMut<[u8]>>(
953 index: Index,
954 fd: std::os::fd::RawFd,
955 bufs: &mut [B],
956 sizes: &mut [usize],
957 route_keys: &mut [RouteKey],
958) -> io::Result<usize> {
959 let mut iov: [iovec; MAX_MESSAGES] = unsafe { std::mem::zeroed() };
960 let mut msgs: [mmsghdr; MAX_MESSAGES] = unsafe { std::mem::zeroed() };
961 let mut addrs: [sockaddr_storage; MAX_MESSAGES] = unsafe { std::mem::zeroed() };
962 let max_num = bufs.len().min(MAX_MESSAGES);
963 for i in 0..max_num {
964 iov[i].iov_base = bufs[i].as_mut().as_mut_ptr() as *mut libc::c_void;
965 iov[i].iov_len = bufs[i].as_mut().len();
966 msgs[i].msg_hdr.msg_iov = &mut iov[i];
967 msgs[i].msg_hdr.msg_iovlen = 1;
968 msgs[i].msg_hdr.msg_name = &mut addrs[i] as *const _ as *mut libc::c_void;
969 msgs[i].msg_hdr.msg_namelen = std::mem::size_of::<sockaddr_storage>() as socklen_t;
970 }
971 let res = unsafe {
972 libc::recvmmsg(
973 fd,
974 msgs.as_mut_ptr(),
975 max_num as c_uint,
976 libc::MSG_DONTWAIT as _,
977 std::ptr::null_mut(),
978 )
979 };
980 if res == -1 {
981 return Err(io::Error::last_os_error());
982 }
983 let nmsgs = res as usize;
984 for i in 0..nmsgs {
985 let addr = sockaddr_to_socket_addr(&addrs[i], msgs[i].msg_hdr.msg_namelen);
986 sizes[i] = msgs[i].msg_len as usize;
987 route_keys[i] = RouteKey::new(index, addr);
988 }
989 Ok(nmsgs)
990}
991#[cfg(any(target_os = "linux", target_os = "android"))]
992fn sockaddr_to_socket_addr(addr: &sockaddr_storage, _len: socklen_t) -> SocketAddr {
993 match addr.ss_family as libc::c_int {
994 libc::AF_INET => {
995 let addr_in = unsafe { *(addr as *const _ as *const libc::sockaddr_in) };
996 let ip = u32::from_be(addr_in.sin_addr.s_addr);
997 let port = u16::from_be(addr_in.sin_port);
998 SocketAddr::V4(std::net::SocketAddrV4::new(
999 std::net::Ipv4Addr::from(ip),
1000 port,
1001 ))
1002 }
1003 libc::AF_INET6 => {
1004 let addr_in6 = unsafe { *(addr as *const _ as *const libc::sockaddr_in6) };
1005 let ip = std::net::Ipv6Addr::from(addr_in6.sin6_addr.s6_addr);
1006 let port = u16::from_be(addr_in6.sin6_port);
1007 SocketAddr::V6(std::net::SocketAddrV6::new(ip, port, 0, 0))
1008 }
1009 _ => panic!("Unsupported address family"),
1010 }
1011}
1012
/// Whether a receive error should be skipped rather than reported.
///
/// On Windows this is `WSAECONNRESET`. Windows can report it on UDP sockets
/// after an earlier send got an ICMP "port unreachable" reply. Every other
/// platform reports all errors.
fn should_ignore_error(e: &io::Error) -> bool {
    #[cfg(windows)]
    {
        e.raw_os_error() == Some(windows_sys::Win32::Networking::WinSock::WSAECONNRESET)
    }
    #[cfg(not(windows))]
    {
        let _ = e;
        false
    }
}
1024
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use crate::tunnel::udp::{Model, UdpTunnel};

    /// In the Low model with IPv6 disabled, exactly the two main IPv4 tunnels are dispatched.
    #[tokio::test]
    pub async fn create_udp_tunnel() {
        let config = crate::tunnel::config::UdpTunnelConfig::default()
            .set_main_udp_count(2)
            .set_sub_udp_count(10)
            .set_model(Model::Low)
            .set_use_v6(false);
        let mut dispatcher = crate::tunnel::udp::create_tunnel_dispatcher(config).unwrap();
        let mut handles = Vec::new();
        // Keep dispatching until no tunnel shows up within a second.
        while let Ok(tunnel) =
            tokio::time::timeout(Duration::from_secs(1), dispatcher.dispatch()).await
        {
            handles.push(tokio::spawn(tunnel_recv(tunnel.unwrap())));
        }
        assert_eq!(handles.len(), 2)
    }

    /// The High model dispatches 2 main and 10 sub tunnels. Switching back to
    /// Low closes exactly the 10 sub tunnels.
    #[tokio::test]
    pub async fn create_sub_udp_tunnel() {
        let config = crate::tunnel::config::UdpTunnelConfig::default()
            .set_main_udp_count(2)
            .set_sub_udp_count(10)
            .set_use_v6(false)
            .set_model(Model::High);
        let mut tunnel_factory = crate::tunnel::udp::create_tunnel_dispatcher(config).unwrap();
        let mut handles = Vec::new();
        while let Ok(tunnel) =
            tokio::time::timeout(Duration::from_secs(1), tunnel_factory.dispatch()).await
        {
            handles.push(tokio::spawn(tunnel_recv(tunnel.unwrap())));
        }
        let count = handles.len();
        tunnel_factory.manager().switch_low();

        let mut close_tunnel_count = 0;
        for handle in handles {
            // A handle that does not finish within a second belongs to a main
            // tunnel still waiting for data; it is not counted.
            if let Ok(joined) = tokio::time::timeout(Duration::from_secs(1), handle).await {
                if joined.unwrap() {
                    close_tunnel_count += 1;
                }
            }
        }
        assert_eq!(count, 12);
        assert_eq!(close_tunnel_count, 10);
    }

    /// `true` when the tunnel's `recv_from` ended because the tunnel was closed.
    async fn tunnel_recv(mut tunnel: UdpTunnel) -> bool {
        let mut buf = [0; 1400];
        tunnel.recv_from(&mut buf).await.is_none()
    }
}
1091}