Skip to main content

tun2proxy/
lib.rs

1#[cfg(target_os = "linux")]
2extern crate bincode_next as bincode;
3
4#[cfg(feature = "udpgw")]
5use crate::udpgw::UdpGwClient;
6use crate::{
7    directions::{IncomingDataEvent, IncomingDirection, OutgoingDirection},
8    http::HttpManager,
9    no_proxy::NoProxyManager,
10    session_info::{IpProtocol, SessionInfo},
11    virtual_dns::VirtualDns,
12};
13pub use clap::ValueEnum;
14use ipstack::{IpStackStream, IpStackTcpStream, IpStackUdpStream};
15use proxy_handler::{ProxyHandler, ProxyHandlerManager};
16use socks::SocksProxyManager;
17pub use socks5_impl::protocol::UserKey;
18#[cfg(feature = "udpgw")]
19use std::net::{Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddrV6};
20use std::{
21    collections::VecDeque,
22    io::ErrorKind,
23    net::{IpAddr, SocketAddr},
24    sync::Arc,
25};
26use tokio::{
27    io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt},
28    net::{TcpSocket, TcpStream, UdpSocket},
29    sync::{Mutex, mpsc::Receiver},
30};
31pub use tokio_util::sync::CancellationToken;
32use tproxy_config::is_private_ip;
33pub use tun::DEFAULT_MTU;
34use udp_stream::UdpStream;
35#[cfg(feature = "udpgw")]
36use udpgw::{UDPGW_KEEPALIVE_TIME, UDPGW_MAX_CONNECTIONS, UdpGwClientStream, UdpGwResponse};
37
38pub use {
39    args::{ArgDns, ArgProxy, ArgVerbosity, Args, ProxyType},
40    error::{BoxError, Error, Result},
41    traffic_status::{TrafficStatus, tun2proxy_set_traffic_status_callback},
42};
43
44pub use general_api::general_run_async;
45
/// A fixed duration of two seconds.
///
/// NOTE(review): the name suggests this bounds how long a shutdown may take before the process
/// is forced to exit; the constant is not used in this file, so confirm against its callers.
pub const FORCE_EXIT_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(2);
47
48mod android;
49mod args;
50mod directions;
51mod dns;
52mod dump_logger;
53mod error;
54mod general_api;
55mod http;
56mod no_proxy;
57mod proxy_handler;
58mod session_info;
59pub mod socket_transfer;
60mod socks;
61mod traffic_status;
62#[cfg(feature = "udpgw")]
63pub mod udpgw;
64mod virtual_dns;
65#[doc(hidden)]
66pub mod win_svc;
67
/// Destination port that marks a UDP session as DNS traffic.
const DNS_PORT: u16 = 53;
69
70#[allow(unused)]
71#[derive(Hash, Copy, Clone, Eq, PartialEq, Debug)]
72#[cfg_attr(
73    target_os = "linux",
74    derive(bincode::Encode, bincode::Decode, serde::Serialize, serde::Deserialize)
75)]
76pub enum SocketProtocol {
77    Tcp,
78    Udp,
79}
80
81#[allow(unused)]
82#[derive(Hash, Copy, Clone, Eq, PartialEq, Debug)]
83#[cfg_attr(
84    target_os = "linux",
85    derive(bincode::Encode, bincode::Decode, serde::Serialize, serde::Deserialize)
86)]
87pub enum SocketDomain {
88    IpV4,
89    IpV6,
90}
91
92impl From<IpAddr> for SocketDomain {
93    fn from(value: IpAddr) -> Self {
94        match value {
95            IpAddr::V4(_) => Self::IpV4,
96            IpAddr::V6(_) => Self::IpV6,
97        }
98    }
99}
100
101struct SocketQueue {
102    tcp_v4: Mutex<Receiver<TcpSocket>>,
103    tcp_v6: Mutex<Receiver<TcpSocket>>,
104    udp_v4: Mutex<Receiver<UdpSocket>>,
105    udp_v6: Mutex<Receiver<UdpSocket>>,
106}
107
108impl SocketQueue {
109    async fn recv_tcp(&self, domain: SocketDomain) -> Result<TcpSocket, std::io::Error> {
110        match domain {
111            SocketDomain::IpV4 => &self.tcp_v4,
112            SocketDomain::IpV6 => &self.tcp_v6,
113        }
114        .lock()
115        .await
116        .recv()
117        .await
118        .ok_or(ErrorKind::Other.into())
119    }
120    async fn recv_udp(&self, domain: SocketDomain) -> Result<UdpSocket, std::io::Error> {
121        match domain {
122            SocketDomain::IpV4 => &self.udp_v4,
123            SocketDomain::IpV6 => &self.udp_v6,
124        }
125        .lock()
126        .await
127        .recv()
128        .await
129        .ok_or(ErrorKind::Other.into())
130    }
131}
132
133async fn create_tcp_stream(socket_queue: &Option<Arc<SocketQueue>>, peer: SocketAddr) -> std::io::Result<TcpStream> {
134    match &socket_queue {
135        None => TcpStream::connect(peer).await,
136        Some(queue) => queue.recv_tcp(peer.ip().into()).await?.connect(peer).await,
137    }
138}
139
140async fn create_udp_stream(socket_queue: &Option<Arc<SocketQueue>>, peer: SocketAddr) -> std::io::Result<UdpStream> {
141    match &socket_queue {
142        None => {
143            let bind_addr = match peer {
144                SocketAddr::V4(_) => SocketAddr::from((std::net::Ipv4Addr::UNSPECIFIED, 0)),
145                SocketAddr::V6(_) => SocketAddr::from((std::net::Ipv6Addr::UNSPECIFIED, 0)),
146            };
147            let socket = UdpSocket::bind(bind_addr).await?;
148            socket.connect(peer).await?;
149            UdpStream::from_tokio(socket, peer).await
150        }
151        Some(queue) => {
152            let socket = queue.recv_udp(peer.ip().into()).await?;
153            socket.connect(peer).await?;
154            UdpStream::from_tokio(socket, peer).await
155        }
156    }
157}
158
159/// Run the proxy server
160/// # Arguments
161/// * `device` - The network device to use
162/// * `mtu` - The MTU of the network device
163/// * `args` - The arguments to use
164/// * `shutdown_token` - The token to exit the server
165/// # Returns
166/// * The number of sessions while exiting
167pub async fn run<D>(device: D, mtu: u16, args: Args, shutdown_token: CancellationToken) -> crate::Result<usize>
168where
169    D: AsyncRead + AsyncWrite + Unpin + Send + 'static,
170{
171    log::info!("{} {} starting...", env!("CARGO_PKG_NAME"), version_info!());
172    log::info!("Proxy {} server: {}", args.proxy.proxy_type, args.proxy.addr);
173
174    let server_addr = args.proxy.addr;
175    let key = args.proxy.credentials.clone();
176    let dns_addr = args.dns_addr;
177    let ipv6_enabled = args.ipv6_enabled;
178    let virtual_dns = if args.dns == ArgDns::Virtual {
179        Some(Arc::new(Mutex::new(VirtualDns::new(args.virtual_dns_pool))))
180    } else {
181        None
182    };
183
184    #[cfg(target_os = "linux")]
185    let socket_queue = match args.socket_transfer_fd {
186        None => None,
187        Some(fd) => {
188            use crate::socket_transfer::{reconstruct_socket, reconstruct_transfer_socket, request_sockets};
189            use tokio::sync::mpsc::channel;
190
191            let fd = reconstruct_socket(fd)?;
192            let socket = reconstruct_transfer_socket(fd)?;
193            let socket = Arc::new(Mutex::new(socket));
194
195            macro_rules! create_socket_queue {
196                ($domain:ident) => {{
197                    const SOCKETS_PER_REQUEST: usize = 64;
198
199                    let socket = socket.clone();
200                    let (tx, rx) = channel(SOCKETS_PER_REQUEST);
201                    tokio::spawn(async move {
202                        loop {
203                            let sockets =
204                                match request_sockets(socket.lock().await, SocketDomain::$domain, SOCKETS_PER_REQUEST as u32).await {
205                                    Ok(sockets) => sockets,
206                                    Err(err) => {
207                                        log::warn!("Socket allocation request failed: {err}");
208                                        continue;
209                                    }
210                                };
211                            for s in sockets {
212                                if let Err(_) = tx.send(s).await {
213                                    return;
214                                }
215                            }
216                        }
217                    });
218                    Mutex::new(rx)
219                }};
220            }
221
222            Some(Arc::new(SocketQueue {
223                tcp_v4: create_socket_queue!(IpV4),
224                tcp_v6: create_socket_queue!(IpV6),
225                udp_v4: create_socket_queue!(IpV4),
226                udp_v6: create_socket_queue!(IpV6),
227            }))
228        }
229    };
230
231    #[cfg(not(target_os = "linux"))]
232    let socket_queue = None;
233
234    use socks5_impl::protocol::Version::{V4, V5};
235    let mgr: Arc<dyn ProxyHandlerManager> = match args.proxy.proxy_type {
236        ProxyType::Socks5 => Arc::new(SocksProxyManager::new(server_addr, V5, key)),
237        ProxyType::Socks4 => Arc::new(SocksProxyManager::new(server_addr, V4, key)),
238        ProxyType::Http => Arc::new(HttpManager::new(server_addr, key)),
239        ProxyType::None => Arc::new(NoProxyManager::new()),
240    };
241
242    let mut ipstack_config = ipstack::IpStackConfig::default();
243    ipstack_config.mtu(mtu)?;
244    let mut tcp_cfg = ipstack::TcpConfig::default();
245    tcp_cfg.timeout = std::time::Duration::from_secs(args.tcp_timeout);
246    ipstack_config.with_tcp_config(tcp_cfg);
247    ipstack_config.udp_timeout(std::time::Duration::from_secs(args.udp_timeout));
248
249    let mut ip_stack = ipstack::IpStack::new(ipstack_config, device);
250
251    #[cfg(feature = "udpgw")]
252    let udpgw_client = args.udpgw_server.map(|addr| {
253        log::info!("UDP Gateway enabled, server: {addr}");
254        use std::time::Duration;
255        let client = Arc::new(UdpGwClient::new(
256            mtu,
257            args.udpgw_connections.unwrap_or(UDPGW_MAX_CONNECTIONS),
258            args.udpgw_keepalive.map(Duration::from_secs).unwrap_or(UDPGW_KEEPALIVE_TIME),
259            args.udp_timeout,
260            addr,
261        ));
262        let client_keepalive = client.clone();
263        let shutdown_clone = shutdown_token.clone();
264        tokio::spawn(async move {
265            if let Err(err) = client_keepalive.heartbeat_task(shutdown_clone).await {
266                log::error!("UDP Gateway heartbeat task error: {err}");
267            }
268        });
269        client
270    });
271
272    let task_count = std::sync::Arc::new(std::sync::atomic::AtomicUsize::new(0));
273    use std::sync::atomic::Ordering::Relaxed;
274
275    loop {
276        let task_count = task_count.clone();
277        let virtual_dns = virtual_dns.clone();
278        let ip_stack_stream = tokio::select! {
279            _ = shutdown_token.cancelled() => {
280                log::info!("Shutdown received");
281                break;
282            }
283            ip_stack_stream = ip_stack.accept() => {
284                ip_stack_stream?
285            }
286        };
287        let max_sessions = args.max_sessions;
288        match ip_stack_stream {
289            IpStackStream::Tcp(tcp) => {
290                if task_count.load(Relaxed) >= max_sessions {
291                    if args.exit_on_fatal_error {
292                        log::info!("Too many sessions that over {max_sessions}, exiting...");
293                        break;
294                    }
295                    log::warn!("Too many sessions that over {max_sessions}, dropping new session");
296                    continue;
297                }
298                log::trace!("Session count {}", task_count.fetch_add(1, Relaxed).saturating_add(1));
299                let info = SessionInfo::new(tcp.local_addr(), tcp.peer_addr(), IpProtocol::Tcp);
300                let domain_name = if let Some(virtual_dns) = &virtual_dns {
301                    let mut virtual_dns = virtual_dns.lock().await;
302                    virtual_dns.touch_ip(&tcp.peer_addr().ip());
303                    virtual_dns.resolve_ip(&tcp.peer_addr().ip()).cloned()
304                } else {
305                    None
306                };
307                let proxy_handler = mgr.new_proxy_handler(info, domain_name, false).await?;
308                let socket_queue = socket_queue.clone();
309                tokio::spawn(async move {
310                    if let Err(err) = handle_tcp_session(tcp, proxy_handler, socket_queue).await {
311                        log::error!("{info} error \"{err}\"");
312                    }
313                    log::trace!("Session count {}", task_count.fetch_sub(1, Relaxed).saturating_sub(1));
314                });
315            }
316            IpStackStream::Udp(udp) => {
317                if task_count.load(Relaxed) >= max_sessions {
318                    if args.exit_on_fatal_error {
319                        log::info!("Too many sessions that over {max_sessions}, exiting...");
320                        break;
321                    }
322                    log::warn!("Too many sessions that over {max_sessions}, dropping new session");
323                    continue;
324                }
325                log::trace!("Session count {}", task_count.fetch_add(1, Relaxed).saturating_add(1));
326                let mut info = SessionInfo::new(udp.local_addr(), udp.peer_addr(), IpProtocol::Udp);
327                if info.dst.port() == DNS_PORT {
328                    if is_private_ip(info.dst.ip()) {
329                        info.dst.set_ip(dns_addr); // !!! Here we change the destination address to remote DNS server!!!
330                    }
331                    if args.dns == ArgDns::OverTcp {
332                        info.protocol = IpProtocol::Tcp;
333                        let proxy_handler = mgr.new_proxy_handler(info, None, false).await?;
334                        let socket_queue = socket_queue.clone();
335                        tokio::spawn(async move {
336                            if let Err(err) = handle_dns_over_tcp_session(udp, proxy_handler, socket_queue, ipv6_enabled).await {
337                                log::error!("{info} error \"{err}\"");
338                            }
339                            log::trace!("Session count {}", task_count.fetch_sub(1, Relaxed).saturating_sub(1));
340                        });
341                        continue;
342                    }
343                    if args.dns == ArgDns::Virtual {
344                        tokio::spawn(async move {
345                            if let Some(virtual_dns) = virtual_dns {
346                                if let Err(err) = handle_virtual_dns_session(udp, virtual_dns).await {
347                                    log::error!("{info} error \"{err}\"");
348                                }
349                            }
350                            log::trace!("Session count {}", task_count.fetch_sub(1, Relaxed).saturating_sub(1));
351                        });
352                        continue;
353                    }
354                    assert_eq!(args.dns, ArgDns::Direct);
355                }
356                let domain_name = if let Some(virtual_dns) = &virtual_dns {
357                    let mut virtual_dns = virtual_dns.lock().await;
358                    virtual_dns.touch_ip(&udp.peer_addr().ip());
359                    virtual_dns.resolve_ip(&udp.peer_addr().ip()).cloned()
360                } else {
361                    None
362                };
363                #[cfg(feature = "udpgw")]
364                if let Some(udpgw) = udpgw_client.clone() {
365                    let tcp_src = match udp.peer_addr() {
366                        SocketAddr::V4(_) => SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0)),
367                        SocketAddr::V6(_) => SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::UNSPECIFIED, 0, 0, 0)),
368                    };
369                    let tcpinfo = SessionInfo::new(tcp_src, udpgw.get_udpgw_server_addr(), IpProtocol::Tcp);
370                    let proxy_handler = mgr.new_proxy_handler(tcpinfo, None, false).await?;
371                    let queue = socket_queue.clone();
372                    tokio::spawn(async move {
373                        let dst = info.dst; // real UDP destination address
374                        let dst_addr = match domain_name {
375                            Some(ref d) => socks5_impl::protocol::Address::from((d.clone(), dst.port())),
376                            None => dst.into(),
377                        };
378                        if let Err(e) = handle_udp_gateway_session(udp, udpgw, &dst_addr, proxy_handler, queue, ipv6_enabled).await {
379                            log::info!("Ending {info} with \"{e}\"");
380                        }
381                        log::trace!("Session count {}", task_count.fetch_sub(1, Relaxed).saturating_sub(1));
382                    });
383                    continue;
384                }
385                match mgr.new_proxy_handler(info, domain_name, true).await {
386                    Ok(proxy_handler) => {
387                        let socket_queue = socket_queue.clone();
388                        tokio::spawn(async move {
389                            let ty = args.proxy.proxy_type;
390                            if let Err(err) = handle_udp_associate_session(udp, ty, proxy_handler, socket_queue, ipv6_enabled).await {
391                                log::info!("Ending {info} with \"{err}\"");
392                            }
393                            log::trace!("Session count {}", task_count.fetch_sub(1, Relaxed).saturating_sub(1));
394                        });
395                    }
396                    Err(e) => {
397                        log::error!("Failed to create UDP connection: {e}");
398                    }
399                }
400            }
401            IpStackStream::UnknownTransport(u) => {
402                let len = u.payload().len();
403                log::info!("#0 unhandled transport - Ip Protocol {:?}, length {}", u.ip_protocol(), len);
404                continue;
405            }
406            IpStackStream::UnknownNetwork(pkt) => {
407                log::info!("#0 unknown transport - {} bytes", pkt.len());
408                continue;
409            }
410        }
411    }
412    Ok(task_count.load(Relaxed))
413}
414
415async fn handle_virtual_dns_session(mut udp: IpStackUdpStream, dns: Arc<Mutex<VirtualDns>>) -> crate::Result<()> {
416    let mut buf = [0_u8; 4096];
417    loop {
418        let len = match udp.read(&mut buf).await {
419            Err(e) => {
420                // indicate UDP read fails not an error.
421                log::debug!("Virtual DNS session error: {e}");
422                break;
423            }
424            Ok(len) => len,
425        };
426        if len == 0 {
427            break;
428        }
429        let (msg, qname, ip) = dns.lock().await.generate_query(&buf[..len])?;
430        udp.write_all(&msg).await?;
431        log::debug!("Virtual DNS query: {qname} -> {ip}");
432    }
433    Ok(())
434}
435
436async fn copy_and_record_traffic<R, W>(reader: &mut R, writer: &mut W, is_tx: bool) -> tokio::io::Result<u64>
437where
438    R: tokio::io::AsyncRead + Unpin + ?Sized,
439    W: tokio::io::AsyncWrite + Unpin + ?Sized,
440{
441    let mut buf = vec![0; 8192];
442    let mut total = 0;
443    loop {
444        match reader.read(&mut buf).await? {
445            0 => break, // EOF
446            n => {
447                total += n as u64;
448                let (tx, rx) = if is_tx { (n, 0) } else { (0, n) };
449                if let Err(e) = crate::traffic_status::traffic_status_update(tx, rx) {
450                    log::debug!("Record traffic status error: {e}");
451                }
452                writer.write_all(&buf[..n]).await?;
453            }
454        }
455    }
456    Ok(total)
457}
458
459async fn handle_tcp_session(
460    mut tcp_stack: IpStackTcpStream,
461    proxy_handler: Arc<Mutex<dyn ProxyHandler>>,
462    socket_queue: Option<Arc<SocketQueue>>,
463) -> crate::Result<()> {
464    let (session_info, server_addr) = {
465        let handler = proxy_handler.lock().await;
466
467        (handler.get_session_info(), handler.get_server_addr())
468    };
469
470    let mut server = create_tcp_stream(&socket_queue, server_addr).await?;
471
472    log::info!("Beginning {session_info}");
473
474    if let Err(e) = handle_proxy_session(&mut server, proxy_handler).await {
475        tcp_stack.shutdown().await?;
476        return Err(e);
477    }
478
479    let (mut t_rx, mut t_tx) = tokio::io::split(tcp_stack);
480    let (mut s_rx, mut s_tx) = tokio::io::split(server);
481
482    let res = tokio::join!(
483        async move {
484            let r = copy_and_record_traffic(&mut t_rx, &mut s_tx, true).await;
485            if let Err(err) = s_tx.shutdown().await {
486                log::trace!("{session_info} s_tx shutdown error {err}");
487            }
488            r
489        },
490        async move {
491            let r = copy_and_record_traffic(&mut s_rx, &mut t_tx, false).await;
492            if let Err(err) = t_tx.shutdown().await {
493                log::trace!("{session_info} t_tx shutdown error {err}");
494            }
495            r
496        },
497    );
498    log::info!("Ending {session_info} with {res:?}");
499
500    Ok(())
501}
502
/// Relays one UDP session through a UDP gateway connection.
///
/// A gateway connection is taken from the client's queue (closed ones are skipped); when the
/// queue yields none, a new TCP connection to the proxy server is created and the proxy handshake
/// is performed on it. Datagrams read from `udp_stack` are sent to the gateway addressed to
/// `udp_dst`, and data packets received from the gateway are written back to `udp_stack`. When
/// the loop ends and the connection has not been closed, it is handed back to the client.
#[cfg(feature = "udpgw")]
async fn handle_udp_gateway_session(
    mut udp_stack: IpStackUdpStream,
    udpgw_client: Arc<UdpGwClient>,
    udp_dst: &socks5_impl::protocol::Address,
    proxy_handler: Arc<Mutex<dyn ProxyHandler>>,
    socket_queue: Option<Arc<SocketQueue>>,
    ipv6_enabled: bool,
) -> crate::Result<()> {
    let proxy_server_addr = proxy_handler.lock().await.get_server_addr();
    let udp_mtu = udpgw_client.get_udp_mtu();
    let udp_timeout = udpgw_client.get_udp_timeout();

    let mut stream = loop {
        match udpgw_client.pop_server_connection_from_queue().await {
            // Skip queued connections that were closed in the meantime.
            Some(queued) if queued.is_closed() => continue,
            Some(queued) => break queued,
            None => {
                let mut tcp_server_stream = create_tcp_stream(&socket_queue, proxy_server_addr).await?;
                if let Err(e) = handle_proxy_session(&mut tcp_server_stream, proxy_handler).await {
                    return Err(format!("udpgw connection error: {e}").into());
                }
                break UdpGwClientStream::new(tcp_server_stream);
            }
        }
    };

    let tcp_local_addr = stream.local_addr();
    let sn = stream.serial_number();

    log::info!("[UdpGw] Beginning stream {} {} -> {}", sn, &tcp_local_addr, udp_dst);

    let mut reader = match stream.get_reader() {
        Some(reader) => reader,
        None => return Err("get reader failed".into()),
    };

    let mut writer = match stream.get_writer() {
        Some(writer) => writer,
        None => return Err("get writer failed".into()),
    };

    let mut tmp_buf = vec![0; udp_mtu.into()];

    loop {
        tokio::select! {
            // IP stack -> gateway
            len = udp_stack.read(&mut tmp_buf) => {
                let read_len = match len {
                    Ok(0) => {
                        log::info!("[UdpGw] Ending stream {} {} <> {}", sn, &tcp_local_addr, udp_dst);
                        break;
                    }
                    Ok(n) => n,
                    Err(e) => {
                        log::info!("[UdpGw] Ending stream {} {} <> {} with udp stack \"{}\"", sn, &tcp_local_addr, udp_dst, e);
                        break;
                    }
                };
                crate::traffic_status::traffic_status_update(read_len, 0)?;
                let sn = stream.serial_number();
                let payload = &tmp_buf[0..read_len];
                if let Err(e) = UdpGwClient::send_udpgw_packet(ipv6_enabled, payload, udp_dst, sn, &mut writer).await {
                    log::info!("[UdpGw] Ending stream {} {} <> {} with send_udpgw_packet {}", sn, &tcp_local_addr, udp_dst, e);
                    break;
                }
                log::debug!("[UdpGw] stream {} {} -> {} send len {}", sn, &tcp_local_addr, udp_dst, read_len);
                stream.update_activity();
            }
            // gateway -> IP stack
            ret = UdpGwClient::recv_udpgw_packet(udp_mtu, udp_timeout, &mut reader) => {
                let packet = match ret {
                    Ok((len, packet)) => {
                        crate::traffic_status::traffic_status_update(0, len)?;
                        packet
                    }
                    Err(e) => {
                        log::warn!("[UdpGw] Ending stream {} {} <> {} with recv_udpgw_packet {}", sn, &tcp_local_addr, udp_dst, e);
                        stream.close();
                        break;
                    }
                };
                match packet {
                    //should not received keepalive
                    UdpGwResponse::KeepAlive => {
                        log::error!("[UdpGw] Ending stream {} {} <> {} with recv keepalive", sn, &tcp_local_addr, udp_dst);
                        stream.close();
                        break;
                    }
                    //server udp may be timeout,can continue to receive udp data?
                    UdpGwResponse::Error => {
                        log::info!("[UdpGw] Ending stream {} {} <> {} with recv udp error", sn, &tcp_local_addr, udp_dst);
                        stream.update_activity();
                        continue;
                    }
                    UdpGwResponse::TcpClose => {
                        log::error!("[UdpGw] Ending stream {} {} <> {} with tcp closed", sn, &tcp_local_addr, udp_dst);
                        stream.close();
                        break;
                    }
                    UdpGwResponse::Data(data) => {
                        use socks5_impl::protocol::StreamOperation;
                        let len = data.len();
                        let f = data.header.flags;
                        log::debug!("[UdpGw] stream {sn} {} <- {} receive {f} len {len}", &tcp_local_addr, udp_dst);
                        if let Err(e) = udp_stack.write_all(&data.data).await {
                            log::error!("[UdpGw] Ending stream {} {} <> {} with send_udp_packet {}", sn, &tcp_local_addr, udp_dst, e);
                            break;
                        }
                    }
                }
                stream.update_activity();
            }
        }
    }

    // Only a connection that was not closed is handed back to the client.
    if !stream.is_closed() {
        udpgw_client.store_server_connection_full(stream, reader, writer).await;
    }

    Ok(())
}
624
625async fn handle_udp_associate_session(
626    mut udp_stack: IpStackUdpStream,
627    proxy_type: ProxyType,
628    proxy_handler: Arc<Mutex<dyn ProxyHandler>>,
629    socket_queue: Option<Arc<SocketQueue>>,
630    ipv6_enabled: bool,
631) -> crate::Result<()> {
632    use socks5_impl::protocol::{Address, StreamOperation, UdpHeader};
633
634    let (session_info, server_addr, domain_name, udp_addr) = {
635        let handler = proxy_handler.lock().await;
636        (
637            handler.get_session_info(),
638            handler.get_server_addr(),
639            handler.get_domain_name(),
640            handler.get_udp_associate(),
641        )
642    };
643
644    log::info!("Beginning {session_info}");
645
646    // `_server` is meaningful here, it must be alive all the time
647    // to ensure that UDP transmission will not be interrupted accidentally.
648    let (_server, udp_addr) = match udp_addr {
649        Some(udp_addr) => (None, udp_addr),
650        None => {
651            let mut server = create_tcp_stream(&socket_queue, server_addr).await?;
652            let udp_addr = handle_proxy_session(&mut server, proxy_handler).await?;
653            (Some(server), udp_addr.ok_or("udp associate failed")?)
654        }
655    };
656
657    let mut udp_server = create_udp_stream(&socket_queue, udp_addr).await?;
658
659    let mut buf1 = [0_u8; 4096];
660    let mut buf2 = [0_u8; 4096];
661    loop {
662        tokio::select! {
663            len = udp_stack.read(&mut buf1) => {
664                let len = len?;
665                if len == 0 {
666                    break;
667                }
668                let buf1 = &buf1[..len];
669
670                crate::traffic_status::traffic_status_update(len, 0)?;
671
672                if let ProxyType::Socks4 | ProxyType::Socks5 = proxy_type {
673                    let s5addr = if let Some(domain_name) = &domain_name {
674                        Address::DomainAddress(domain_name.clone().into(), session_info.dst.port())
675                    } else {
676                        session_info.dst.into()
677                    };
678
679                    // Add SOCKS5 UDP header to the incoming data
680                    let mut s5_udp_data = Vec::<u8>::new();
681                    UdpHeader::new(0, s5addr).write_to_stream(&mut s5_udp_data)?;
682                    s5_udp_data.extend_from_slice(buf1);
683
684                    udp_server.write_all(&s5_udp_data).await?;
685                } else {
686                    udp_server.write_all(buf1).await?;
687                }
688            }
689            len = udp_server.read(&mut buf2) => {
690                let len = len?;
691                if len == 0 {
692                    break;
693                }
694                let buf2 = &buf2[..len];
695
696                crate::traffic_status::traffic_status_update(0, len)?;
697
698                if let ProxyType::Socks4 | ProxyType::Socks5 = proxy_type {
699                    // Remove SOCKS5 UDP header from the server data
700                    let header = UdpHeader::retrieve_from_stream(&mut &buf2[..])?;
701                    let data = &buf2[header.len()..];
702
703                    let buf = if session_info.dst.port() == DNS_PORT {
704                        let mut message = dns::parse_data_to_dns_message(data, false)?;
705                        if !ipv6_enabled {
706                            dns::remove_ipv6_entries(&mut message);
707                        }
708                        message.to_vec()?
709                    } else {
710                        data.to_vec()
711                    };
712
713                    udp_stack.write_all(&buf).await?;
714                } else {
715                    udp_stack.write_all(buf2).await?;
716                }
717            }
718        }
719    }
720
721    log::info!("Ending {session_info}");
722
723    Ok(())
724}
725
726async fn handle_dns_over_tcp_session(
727    mut udp_stack: IpStackUdpStream,
728    proxy_handler: Arc<Mutex<dyn ProxyHandler>>,
729    socket_queue: Option<Arc<SocketQueue>>,
730    ipv6_enabled: bool,
731) -> crate::Result<()> {
732    let (session_info, server_addr) = {
733        let handler = proxy_handler.lock().await;
734
735        (handler.get_session_info(), handler.get_server_addr())
736    };
737
738    let mut server = create_tcp_stream(&socket_queue, server_addr).await?;
739
740    log::info!("Beginning {session_info}");
741
742    let _ = handle_proxy_session(&mut server, proxy_handler).await?;
743
744    let mut buf1 = [0_u8; 4096];
745    let mut buf2 = [0_u8; 4096];
746    loop {
747        tokio::select! {
748            len = udp_stack.read(&mut buf1) => {
749                let len = len?;
750                if len == 0 {
751                    break;
752                }
753                let buf1 = &buf1[..len];
754
755                _ = dns::parse_data_to_dns_message(buf1, false)?;
756
757                // Insert the DNS message length in front of the payload
758                let len = u16::try_from(buf1.len())?;
759                let mut buf = Vec::with_capacity(std::mem::size_of::<u16>() + usize::from(len));
760                buf.extend_from_slice(&len.to_be_bytes());
761                buf.extend_from_slice(buf1);
762
763                server.write_all(&buf).await?;
764
765                crate::traffic_status::traffic_status_update(buf.len(), 0)?;
766            }
767            len = server.read(&mut buf2) => {
768                let len = len?;
769                if len == 0 {
770                    break;
771                }
772                let mut buf = buf2[..len].to_vec();
773
774                crate::traffic_status::traffic_status_update(0, len)?;
775
776                let mut to_send: VecDeque<Vec<u8>> = VecDeque::new();
777                loop {
778                    if buf.len() < 2 {
779                        break;
780                    }
781                    let len = u16::from_be_bytes([buf[0], buf[1]]) as usize;
782                    if buf.len() < len + 2 {
783                        break;
784                    }
785
786                    // remove the length field
787                    let data = buf[2..len + 2].to_vec();
788
789                    let mut message = dns::parse_data_to_dns_message(&data, false)?;
790
791                    let name = dns::extract_domain_from_dns_message(&message)?;
792                    let ip = dns::extract_ipaddr_from_dns_message(&message);
793                    log::trace!("DNS over TCP query result: {name} -> {ip:?}");
794
795                    if !ipv6_enabled {
796                        dns::remove_ipv6_entries(&mut message);
797                    }
798
799                    to_send.push_back(message.to_vec()?);
800                    if len + 2 == buf.len() {
801                        break;
802                    }
803                    buf = buf[len + 2..].to_vec();
804                }
805
806                while let Some(packet) = to_send.pop_front() {
807                    udp_stack.write_all(&packet).await?;
808                }
809            }
810        }
811    }
812
813    log::info!("Ending {session_info}");
814
815    Ok(())
816}
817
818/// This function is used to handle the business logic of tun2proxy and SOCKS5 server.
819/// When handling UDP proxy, the return value UDP associate IP address is the result of this business logic.
820/// However, when handling TCP business logic, the return value Ok(None) is meaningless, just indicating that the operation was successful.
821async fn handle_proxy_session(server: &mut TcpStream, proxy_handler: Arc<Mutex<dyn ProxyHandler>>) -> crate::Result<Option<SocketAddr>> {
822    let mut launched = false;
823    let mut proxy_handler = proxy_handler.lock().await;
824    let dir = OutgoingDirection::ToServer;
825    let (mut tx, mut rx) = (0, 0);
826
827    loop {
828        if proxy_handler.connection_established() {
829            break;
830        }
831
832        if !launched {
833            let data = proxy_handler.peek_data(dir).buffer;
834            let len = data.len();
835            if len == 0 {
836                return Err("proxy_handler launched went wrong".into());
837            }
838            server.write_all(data).await?;
839            proxy_handler.consume_data(dir, len);
840            tx += len;
841
842            launched = true;
843        }
844
845        let mut buf = [0_u8; 4096];
846        let len = server.read(&mut buf).await?;
847        if len == 0 {
848            return Err("server closed accidentially".into());
849        }
850        rx += len;
851        let event = IncomingDataEvent {
852            direction: IncomingDirection::FromServer,
853            buffer: &buf[..len],
854        };
855        proxy_handler.push_data(event).await?;
856
857        let data = proxy_handler.peek_data(dir).buffer;
858        let len = data.len();
859        if len > 0 {
860            server.write_all(data).await?;
861            proxy_handler.consume_data(dir, len);
862            tx += len;
863        }
864    }
865    crate::traffic_status::traffic_status_update(tx, rx)?;
866    Ok(proxy_handler.get_udp_associate())
867}