1use crate::{
2 error::{Error, Result},
3 protocol::{Address, AddressType, AsyncStreamOperation, AuthMethod, Command, Reply, StreamOperation, UserKey, Version},
4};
5use std::{
6 fmt::Debug,
7 io::Cursor,
8 net::{SocketAddr, ToSocketAddrs},
9 time::Duration,
10};
11use tokio::{
12 io::{AsyncReadExt, AsyncWriteExt, BufStream},
13 net::{TcpStream, UdpSocket},
14};
15
16#[async_trait::async_trait]
17pub trait Socks5Reader: AsyncReadExt + Unpin {
18 async fn read_version(&mut self) -> Result<()> {
19 let value = Version::try_from(self.read_u8().await?)?;
20 match value {
21 Version::V4 => Err(Error::WrongVersion),
22 Version::V5 => Ok(()),
23 }
24 }
25
26 async fn read_method(&mut self) -> Result<AuthMethod> {
27 let value = AuthMethod::from(self.read_u8().await?);
28 match value {
29 AuthMethod::NoAuth | AuthMethod::UserPass => Ok(value),
30 _ => Err(Error::InvalidAuthMethod(value)),
31 }
32 }
33
34 async fn read_command(&mut self) -> Result<Command> {
35 let value = self.read_u8().await?;
36 Ok(Command::try_from(value)?)
37 }
38
39 async fn read_atyp(&mut self) -> Result<AddressType> {
40 let value = self.read_u8().await?;
41 Ok(AddressType::try_from(value)?)
42 }
43
44 async fn read_reserved(&mut self) -> Result<()> {
45 let value = self.read_u8().await?;
46 match value {
47 0x00 => Ok(()),
48 _ => Err(Error::InvalidReserved(value)),
49 }
50 }
51
52 async fn read_fragment_id(&mut self) -> Result<()> {
53 let value = self.read_u8().await?;
54 if value == 0x00 {
55 Ok(())
56 } else {
57 Err(Error::InvalidFragmentId(value))
58 }
59 }
60
61 async fn read_reply(&mut self) -> Result<()> {
62 let value = self.read_u8().await?;
63 match Reply::try_from(value)? {
64 Reply::Succeeded => Ok(()),
65 reply => Err(format!("{reply}").into()),
66 }
67 }
68
69 async fn read_address(&mut self) -> Result<Address> {
70 Ok(Address::retrieve_from_async_stream(self).await?)
71 }
72
73 async fn read_string(&mut self) -> Result<String> {
74 let len = self.read_u8().await? as usize;
75 let mut str = vec![0; len];
76 self.read_exact(&mut str).await?;
77 let str = String::from_utf8(str)?;
78 Ok(str)
79 }
80
81 async fn read_auth_version(&mut self) -> Result<()> {
82 let value = self.read_u8().await?;
83 if value != 0x01 {
84 return Err(Error::InvalidAuthSubnegotiation(value));
85 }
86 Ok(())
87 }
88
89 async fn read_auth_status(&mut self) -> Result<()> {
90 let value = self.read_u8().await?;
91 if value != 0x00 {
92 return Err(Error::InvalidAuthStatus(value));
93 }
94 Ok(())
95 }
96
97 async fn read_selection_msg(&mut self) -> Result<AuthMethod> {
98 self.read_version().await?;
99 self.read_method().await
100 }
101
102 async fn read_final(&mut self) -> Result<Address> {
103 self.read_version().await?;
104 self.read_reply().await?;
105 self.read_reserved().await?;
106 let addr = self.read_address().await?;
107 Ok(addr)
108 }
109}
110
// Blanket implementation: every `AsyncReadExt + Unpin` type receives the
// default `Socks5Reader` methods without further code.
#[async_trait::async_trait]
impl<T: AsyncReadExt + Unpin> Socks5Reader for T {}
113
114#[async_trait::async_trait]
115pub trait Socks5Writer: AsyncWriteExt + Unpin {
116 async fn write_version(&mut self) -> Result<()> {
117 self.write_u8(0x05).await?;
118 Ok(())
119 }
120
121 async fn write_method(&mut self, method: AuthMethod) -> Result<()> {
122 self.write_u8(u8::from(method)).await?;
123 Ok(())
124 }
125
126 async fn write_command(&mut self, command: Command) -> Result<()> {
127 self.write_u8(u8::from(command)).await?;
128 Ok(())
129 }
130
131 async fn write_atyp(&mut self, atyp: AddressType) -> Result<()> {
132 self.write_u8(u8::from(atyp)).await?;
133 Ok(())
134 }
135
136 async fn write_reserved(&mut self) -> Result<()> {
137 self.write_u8(0x00).await?;
138 Ok(())
139 }
140
141 async fn write_fragment_id(&mut self, id: u8) -> Result<()> {
142 self.write_u8(id).await?;
143 Ok(())
144 }
145
146 async fn write_address(&mut self, address: &Address) -> Result<()> {
147 address.write_to_async_stream(self).await?;
148 Ok(())
149 }
150
151 async fn write_string(&mut self, string: &str) -> Result<()> {
152 let bytes = string.as_bytes();
153 if bytes.len() > 255 {
154 return Err("Too long string".into());
155 }
156 self.write_u8(bytes.len() as u8).await?;
157 self.write_all(bytes).await?;
158 Ok(())
159 }
160
161 async fn write_auth_version(&mut self) -> Result<()> {
162 self.write_u8(0x01).await?;
163 Ok(())
164 }
165
166 async fn write_methods(&mut self, methods: &[AuthMethod]) -> Result<()> {
167 self.write_u8(methods.len() as u8).await?;
168 for method in methods {
169 self.write_method(*method).await?;
170 }
171 Ok(())
172 }
173
174 async fn write_selection_msg(&mut self, methods: &[AuthMethod]) -> Result<()> {
175 self.write_version().await?;
176 self.write_methods(methods).await?;
177 self.flush().await?;
178 Ok(())
179 }
180
181 async fn write_final(&mut self, command: Command, addr: &Address) -> Result<()> {
182 self.write_version().await?;
183 self.write_command(command).await?;
184 self.write_reserved().await?;
185 self.write_address(addr).await?;
186 self.flush().await?;
187 Ok(())
188 }
189}
190
// Blanket implementation: every `AsyncWriteExt + Unpin` type receives the
// default `Socks5Writer` methods without further code.
#[async_trait::async_trait]
impl<T: AsyncWriteExt + Unpin> Socks5Writer for T {}
193
194async fn username_password_auth<S>(stream: &mut S, auth: &UserKey) -> Result<()>
195where
196 S: Socks5Writer + Socks5Reader + Send,
197{
198 stream.write_auth_version().await?;
199 stream.write_string(&auth.username).await?;
200 stream.write_string(&auth.password).await?;
201 stream.flush().await?;
202
203 stream.read_auth_version().await?;
204 stream.read_auth_status().await
205}
206
207async fn init<S, A>(stream: &mut S, command: Command, addr: A, auth: Option<UserKey>) -> Result<Address>
208where
209 S: Socks5Writer + Socks5Reader + Send,
210 A: Into<Address>,
211{
212 let addr: Address = addr.into();
213
214 let mut methods = Vec::with_capacity(2);
215 methods.push(AuthMethod::NoAuth);
216 if auth.is_some() {
217 methods.push(AuthMethod::UserPass);
218 }
219 stream.write_selection_msg(&methods).await?;
220 stream.flush().await?;
221
222 let method: AuthMethod = stream.read_selection_msg().await?;
223 match method {
224 AuthMethod::NoAuth => {}
225 AuthMethod::UserPass if auth.is_some() => {
226 username_password_auth(stream, auth.as_ref().unwrap()).await?;
227 }
228 _ => return Err(Error::InvalidAuthMethod(method)),
229 }
230
231 stream.write_final(command, &addr).await?;
232 stream.read_final().await
233}
234
235pub async fn connect<S, A>(socket: &mut S, addr: A, auth: Option<UserKey>) -> Result<Address>
254where
255 S: AsyncWriteExt + AsyncReadExt + Send + Unpin,
256 A: Into<Address>,
257{
258 init(socket, Command::Connect, addr, auth).await
259}
260
/// State kept after a SOCKS5 BIND request has been answered by the proxy.
#[derive(Debug)]
pub struct SocksListener<S> {
    /// Stream on which the BIND handshake was performed.
    stream: S,
    /// Address returned by the proxy in its reply to the BIND request.
    proxy_addr: Address,
}
285
286impl<S> SocksListener<S>
287where
288 S: AsyncWriteExt + AsyncReadExt + Send + Unpin,
289{
290 pub async fn bind<A>(mut stream: S, addr: A, auth: Option<UserKey>) -> Result<Self>
294 where
295 A: Into<Address>,
296 {
297 let addr = init(&mut stream, Command::Bind, addr, auth).await?;
298 Ok(Self { stream, proxy_addr: addr })
299 }
300
301 pub fn proxy_addr(&self) -> &Address {
302 &self.proxy_addr
303 }
304
305 pub async fn accept(mut self) -> Result<(S, Address)> {
306 let addr = self.stream.read_final().await?;
307 Ok((self.stream, addr))
308 }
309}
310
/// UDP socket paired with the stream on which a SOCKS5 UDP ASSOCIATE request was made.
#[derive(Debug)]
pub struct SocksDatagram<S> {
    /// Local UDP socket, connected to the address the proxy returned for UDP ASSOCIATE.
    socket: UdpSocket,
    /// Address returned by the proxy in its reply to the UDP ASSOCIATE request.
    proxy_addr: Address,
    /// Stream on which the handshake ran; handed back by `into_inner`.
    stream: S,
}
318
319impl<S> SocksDatagram<S>
320where
321 S: AsyncWriteExt + AsyncReadExt + Send + Unpin,
322{
323 pub async fn udp_associate(mut stream: S, socket: UdpSocket, auth: Option<UserKey>) -> Result<Self> {
327 let addr = if socket.local_addr()?.is_ipv4() { "0.0.0.0:0" } else { "[::]:0" };
328 let addr = addr.parse::<SocketAddr>()?;
329 let proxy_addr = init(&mut stream, Command::UdpAssociate, addr, auth).await?;
330 let addr = proxy_addr.to_socket_addrs()?.next().ok_or("InvalidAddress")?;
331 socket.connect(addr).await?;
332 Ok(Self {
333 socket,
334 proxy_addr,
335 stream,
336 })
337 }
338
339 pub fn proxy_addr(&self) -> &Address {
341 &self.proxy_addr
342 }
343
344 pub fn get_ref(&self) -> &UdpSocket {
346 &self.socket
347 }
348
349 pub fn get_mut(&mut self) -> &mut UdpSocket {
351 &mut self.socket
352 }
353
354 pub fn into_inner(self) -> (S, UdpSocket) {
356 (self.stream, self.socket)
357 }
358
359 pub async fn build_socks5_udp_datagram(buf: &[u8], addr: &Address) -> Result<Vec<u8>> {
369 let bytes_size = Self::get_buf_size(addr.len(), buf.len());
370 let bytes = Vec::with_capacity(bytes_size);
371
372 let mut cursor = Cursor::new(bytes);
373 cursor.write_reserved().await?;
374 cursor.write_reserved().await?;
375 cursor.write_fragment_id(0x00).await?;
376 cursor.write_address(addr).await?;
377 cursor.write_all(buf).await?;
378
379 let bytes = cursor.into_inner();
380 Ok(bytes)
381 }
382
383 pub async fn send_to<A>(&self, buf: &[u8], addr: A) -> Result<usize>
385 where
386 A: Into<Address>,
387 {
388 let addr: Address = addr.into();
389 let bytes = Self::build_socks5_udp_datagram(buf, &addr).await?;
390 Ok(self.socket.send(&bytes).await?)
391 }
392
393 async fn parse_socks5_udp_response(bytes: &mut [u8], buf: &mut Vec<u8>) -> Result<(usize, Address)> {
395 let len = bytes.len();
396 let mut cursor = Cursor::new(bytes);
397 cursor.read_reserved().await?;
398 cursor.read_reserved().await?;
399 cursor.read_fragment_id().await?;
400 let addr = cursor.read_address().await?;
401 let header_len = cursor.position() as usize;
402 buf.resize(len - header_len, 0);
403 _ = cursor.read_exact(buf).await?;
404 Ok((len - header_len, addr))
405 }
406
407 pub async fn recv_from(&self, timeout: Duration, buf: &mut Vec<u8>) -> Result<(usize, Address)> {
409 const UDP_MTU: usize = 1500;
410 let bytes_size = UDP_MTU;
412 let mut bytes = vec![0; bytes_size];
413 let len = tokio::time::timeout(timeout, self.socket.recv(&mut bytes)).await??;
414 bytes.truncate(len);
415 let (read, addr) = Self::parse_socks5_udp_response(&mut bytes, buf).await?;
416 Ok((read, addr))
417 }
418
419 fn get_buf_size(addr_size: usize, buf_len: usize) -> usize {
420 2 + 1 + addr_size + buf_len
422 }
423}
424
/// A `TcpStream` wrapped in a `BufStream`.
pub type GuardTcpStream = BufStream<TcpStream>;
/// A [`SocksDatagram`] whose stream type is [`GuardTcpStream`].
pub type SocksUdpClient = SocksDatagram<GuardTcpStream>;
427
/// Minimal asynchronous interface for a UDP client: send a buffer to an
/// address, and receive a datagram with a timeout.
#[async_trait::async_trait]
pub trait UdpClientTrait {
    /// Sends `buf` to `addr`; on success yields a byte count.
    async fn send_to<A>(&mut self, buf: &[u8], addr: A) -> Result<usize>
    where
        A: Into<Address> + Send + Unpin;

    /// Receives into `buf`, given a `timeout`; on success yields a byte count
    /// and an address.
    async fn recv_from(&mut self, timeout: Duration, buf: &mut Vec<u8>) -> Result<(usize, Address)>;
}
436
437#[async_trait::async_trait]
438impl UdpClientTrait for SocksUdpClient {
439 async fn send_to<A>(&mut self, buf: &[u8], addr: A) -> Result<usize, Error>
440 where
441 A: Into<Address> + Send + Unpin,
442 {
443 SocksDatagram::send_to(self, buf, addr).await
444 }
445
446 async fn recv_from(&mut self, timeout: Duration, buf: &mut Vec<u8>) -> Result<(usize, Address), Error> {
447 SocksDatagram::recv_from(self, timeout, buf).await
448 }
449}
450
451pub async fn create_udp_client<A: Into<SocketAddr>>(proxy_addr: A, auth: Option<UserKey>) -> Result<SocksUdpClient> {
452 let proxy_addr = proxy_addr.into();
453 let client_addr = if proxy_addr.is_ipv4() { "0.0.0.0:0" } else { "[::]:0" };
454 let proxy = TcpStream::connect(proxy_addr).await?;
455 let proxy = BufStream::new(proxy);
456 let client = UdpSocket::bind(client_addr).await?;
457 SocksDatagram::udp_associate(proxy, client, auth).await
458}
459
/// Pairs a UDP client with the address of the server it exchanges datagrams with.
pub struct UdpClientImpl<C> {
    /// Client used to send and receive datagrams.
    client: C,
    /// Destination address used by `transfer_data`.
    server_addr: Address,
}
464
465impl UdpClientImpl<SocksUdpClient> {
466 pub async fn transfer_data(&self, data: &[u8], timeout: Duration) -> Result<Vec<u8>> {
467 let len = self.client.send_to(data, &self.server_addr).await?;
468 let buf = SocksDatagram::<GuardTcpStream>::build_socks5_udp_datagram(data, &self.server_addr).await?;
469 assert_eq!(len, buf.len());
470
471 let mut buf = Vec::with_capacity(data.len());
472 let (_len, _) = self.client.recv_from(timeout, &mut buf).await?;
473 Ok(buf)
474 }
475
476 pub async fn datagram<A1, A2>(proxy_addr: A1, udp_server_addr: A2, auth: Option<UserKey>) -> Result<Self>
477 where
478 A1: Into<SocketAddr>,
479 A2: Into<Address>,
480 {
481 let client = create_udp_client(proxy_addr, auth).await?;
482
483 let server_addr = udp_server_addr.into();
484
485 Ok(Self { client, server_addr })
486 }
487}
488
#[cfg(test)]
mod tests {
    use crate::{
        Error, Result,
        client::{self, SocksListener, SocksUdpClient, UdpClientTrait},
        protocol::{Address, UserKey},
    };
    use std::{
        net::{SocketAddr, ToSocketAddrs},
        sync::Arc,
        time::Duration,
    };
    use tokio::{
        io::{AsyncReadExt, AsyncWriteExt, BufStream},
        net::{TcpStream, UdpSocket},
    };

    // All tests below are `#[ignore]`d: they talk to proxies at these addresses.
    const PROXY_ADDR: &str = "127.0.0.1:1080";
    const PROXY_AUTH_ADDR: &str = "127.0.0.1:1081";
    const DATA: &[u8] = b"Hello, world!";

    /// Opens a buffered TCP connection to the proxy at `addr` and issues a
    /// CONNECT request; panics on any failure.
    async fn connect(addr: &str, auth: Option<UserKey>) {
        let mut proxy = BufStream::new(TcpStream::connect(addr).await.unwrap());
        let target = Address::from(("baidu.com", 80));
        client::connect(&mut proxy, target, auth).await.unwrap();
    }

    #[ignore]
    #[tokio::test]
    async fn connect_auth() {
        let credentials = UserKey::new("hyper", "proxy");
        connect(PROXY_AUTH_ADDR, Some(credentials)).await;
    }

    #[ignore]
    #[tokio::test]
    async fn connect_no_auth() {
        connect(PROXY_ADDR, None).await;
    }

    #[ignore]
    #[should_panic = "InvalidAuthMethod(NoAcceptableMethods)"]
    #[tokio::test]
    async fn connect_no_auth_panic() {
        connect(PROXY_AUTH_ADDR, None).await;
    }

    /// Exercises BIND end to end. Errors inside the scenario are printed, not
    /// propagated, so they do not fail the test.
    #[ignore]
    #[tokio::test]
    async fn bind() {
        let scenario = async {
            let requested = Address::from(("127.0.0.1", 8000));

            let proxy = BufStream::new(TcpStream::connect(PROXY_ADDR).await?);
            let listener = SocksListener::bind(proxy, requested, None).await?;

            // Connect to the address the proxy reported for the BIND request.
            let bound = listener.proxy_addr.to_socket_addrs()?.next().ok_or("Invalid address")?;
            let mut peer = TcpStream::connect(&bound).await?;

            let (mut accepted, _) = listener.accept().await?;

            peer.write_all(DATA).await?;

            let mut received = [0; DATA.len()];
            accepted.read_exact(&mut received).await?;
            assert_eq!(received, DATA);
            Ok::<_, Error>(())
        };
        if let Err(e) = scenario.await {
            println!("{e:?}");
        }
    }

    /// Two handles to the same client: `.0` receives, `.1` sends.
    type TestHalves = (Arc<SocksUdpClient>, Arc<SocksUdpClient>);

    #[async_trait::async_trait]
    impl UdpClientTrait for TestHalves {
        async fn send_to<A>(&mut self, buf: &[u8], addr: A) -> Result<usize, Error>
        where
            A: Into<Address> + Send,
        {
            let sent = self.1.send_to(buf, addr).await?;
            Ok(sent)
        }

        async fn recv_from(&mut self, timeout: Duration, buf: &mut Vec<u8>) -> Result<(usize, Address), Error> {
            let received = self.0.recv_from(timeout, buf).await?;
            Ok(received)
        }
    }

    const SERVER_ADDR: &str = "127.0.0.1:23456";

    /// Fixture: a client under test, a local UDP server socket and its address.
    struct UdpTest<C> {
        client: C,
        server: UdpSocket,
        server_addr: Address,
    }

    impl<C: UdpClientTrait> UdpTest<C> {
        /// Round trip: client -> server, then server -> client, checking the
        /// payload and length in both directions.
        async fn test(self) {
            let Self {
                mut client,
                server,
                server_addr,
            } = self;

            // Client to server.
            let mut inbound = vec![0; DATA.len()];
            client.send_to(DATA, server_addr).await.unwrap();
            let (len, sender) = server.recv_from(&mut inbound).await.unwrap();
            assert_eq!(len, inbound.len());
            assert_eq!(inbound.as_slice(), DATA);

            // Server back to client.
            let mut outbound = vec![0; DATA.len()];
            server.send_to(DATA, sender).await.unwrap();
            let timeout = Duration::from_secs(5);
            let (len, _) = client.recv_from(timeout, &mut outbound).await.unwrap();
            assert_eq!(len, outbound.len());
            assert_eq!(outbound.as_slice(), DATA);
        }
    }

    impl UdpTest<SocksUdpClient> {
        /// Builds the fixture with a UDP client created through `PROXY_ADDR`
        /// and a server socket bound to `SERVER_ADDR`.
        async fn datagram() -> Self {
            let proxy = PROXY_ADDR.parse::<SocketAddr>().unwrap();
            let client = client::create_udp_client(proxy, None).await.unwrap();

            let bind_addr: SocketAddr = SERVER_ADDR.parse().unwrap();
            let server = UdpSocket::bind(bind_addr).await.unwrap();

            Self {
                client,
                server,
                server_addr: Address::from(bind_addr),
            }
        }
    }

    impl UdpTest<TestHalves> {
        /// Builds the fixture with the client shared behind two `Arc` handles.
        async fn halves() -> Self {
            let UdpTest {
                client,
                server,
                server_addr,
            } = UdpTest::<SocksUdpClient>::datagram().await;
            let shared = Arc::new(client);
            Self {
                client: (Arc::clone(&shared), shared),
                server,
                server_addr,
            }
        }
    }

    #[ignore]
    #[tokio::test]
    async fn udp_datagram_halves() {
        UdpTest::halves().await.test().await
    }
}