1use std::{
2 io::{self, Cursor},
3 net::SocketAddr,
4};
5
6use bytes::BytesMut;
7use tokio::{
8 io::{copy, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt},
9 net::{ToSocketAddrs, UdpSocket},
10};
11use tracing::error;
12
13pub use crate::crypto::util::*;
14use crate::Error;
15use crate::{crypto::PacketCipher, Address};
16
17pub async fn copy_bidirectional<SA, SB>(a: SA, b: SB) -> Result<(u64, u64), Error>
18where
19 SA: AsyncRead + AsyncWrite + Unpin + Send + 'static,
20 SB: AsyncRead + AsyncWrite + Unpin + Send + 'static,
21{
22 let (mut ar, mut aw) = tokio::io::split(a);
23 let (mut br, mut bw) = tokio::io::split(b);
24
25 let handle = tokio::spawn(async move {
27 let rn = copy(&mut br, &mut aw).await;
28 let result = aw.shutdown().await;
29 if let Err(e) = result {
30 error!("shutdown stream a err {}", e);
31 }
32 let n = match rn {
33 Ok(n) => n,
34 Err(e) => return Err(Error::CopyError(e, "b -> a".into())),
35 };
36 Ok::<u64, Error>(n)
37 });
38
39 let rn = copy(&mut ar, &mut bw).await;
41 let result = bw.shutdown().await;
42 if let Err(e) = result {
43 error!("shutdown stream b err {}", e);
44 }
45
46 let b2a = handle.await.unwrap()?;
47
48 let a2b = match rn {
49 Ok(n) => n,
50 Err(e) => return Err(Error::CopyError(e, "a -> b".into())),
51 };
52
53 Ok((a2b, b2a))
54}
55
56impl PacketCipher {
58 pub async fn send_to<A: ToSocketAddrs>(
60 &self,
61 socket: &UdpSocket,
62 buf: &[u8],
63 target: A,
64 socks5_address: SocketAddr,
65 ) -> Result<usize, Error> {
66 let mut addr = BytesMut::new();
67
68 Address::write_socket_addr_to_buf(&socks5_address, &mut addr);
69
70 let data = self.encrypt_vec_slice_to(vec![&addr, buf])?;
71
72 let n = socket.send_to(&data, target).await?;
73 Ok(n)
74 }
75
76 pub async fn recv_from(
77 &self,
78 socket: &UdpSocket,
79 buf: &mut [u8],
80 ) -> Result<(usize, SocketAddr, Address), Error> {
81 let (n, peer) = socket.recv_from(buf).await?;
82
83 let data_size = self.decrypt_from(&mut buf[..n])?;
84
85 let mut cur = Cursor::new(&mut buf[..data_size]);
86
87 let target = Address::read_from(&mut cur).await?;
88
89 let pos = cur.position() as usize;
90 let payload = cur.into_inner();
91 payload.copy_within(pos.., 0);
92
93 Ok((payload.len() - pos, peer, target))
94 }
95}
96
97pub async fn read_forever<R>(reader: &mut R) -> io::Result<()>
98where
99 R: AsyncRead + Unpin,
100{
101 static mut READ_FOREVER_BUF: &mut [u8] = &mut [0u8; 1024];
102 loop {
103 let n = unsafe { reader.read(READ_FOREVER_BUF).await? };
104 if n == 0 {
105 break;
106 }
107 }
108 Ok(())
109}