// socks5_impl/client/mod.rs

1use crate::{
2    error::{Error, Result},
3    protocol::{Address, AddressType, AsyncStreamOperation, AuthMethod, Command, Reply, StreamOperation, UserKey, Version},
4};
5use std::{
6    fmt::Debug,
7    io::Cursor,
8    net::{SocketAddr, ToSocketAddrs},
9    time::Duration,
10};
11use tokio::{
12    io::{AsyncReadExt, AsyncWriteExt, BufStream},
13    net::{TcpStream, UdpSocket},
14};
15
16#[async_trait::async_trait]
17pub trait Socks5Reader: AsyncReadExt + Unpin {
18    async fn read_version(&mut self) -> Result<()> {
19        let value = Version::try_from(self.read_u8().await?)?;
20        match value {
21            Version::V4 => Err(Error::WrongVersion),
22            Version::V5 => Ok(()),
23        }
24    }
25
26    async fn read_method(&mut self) -> Result<AuthMethod> {
27        let value = AuthMethod::from(self.read_u8().await?);
28        match value {
29            AuthMethod::NoAuth | AuthMethod::UserPass => Ok(value),
30            _ => Err(Error::InvalidAuthMethod(value)),
31        }
32    }
33
34    async fn read_command(&mut self) -> Result<Command> {
35        let value = self.read_u8().await?;
36        Ok(Command::try_from(value)?)
37    }
38
39    async fn read_atyp(&mut self) -> Result<AddressType> {
40        let value = self.read_u8().await?;
41        Ok(AddressType::try_from(value)?)
42    }
43
44    async fn read_reserved(&mut self) -> Result<()> {
45        let value = self.read_u8().await?;
46        match value {
47            0x00 => Ok(()),
48            _ => Err(Error::InvalidReserved(value)),
49        }
50    }
51
52    async fn read_fragment_id(&mut self) -> Result<()> {
53        let value = self.read_u8().await?;
54        if value == 0x00 {
55            Ok(())
56        } else {
57            Err(Error::InvalidFragmentId(value))
58        }
59    }
60
61    async fn read_reply(&mut self) -> Result<()> {
62        let value = self.read_u8().await?;
63        match Reply::try_from(value)? {
64            Reply::Succeeded => Ok(()),
65            reply => Err(format!("{reply}").into()),
66        }
67    }
68
69    async fn read_address(&mut self) -> Result<Address> {
70        Ok(Address::retrieve_from_async_stream(self).await?)
71    }
72
73    async fn read_string(&mut self) -> Result<String> {
74        let len = self.read_u8().await? as usize;
75        let mut str = vec![0; len];
76        self.read_exact(&mut str).await?;
77        let str = String::from_utf8(str)?;
78        Ok(str)
79    }
80
81    async fn read_auth_version(&mut self) -> Result<()> {
82        let value = self.read_u8().await?;
83        if value != 0x01 {
84            return Err(Error::InvalidAuthSubnegotiation(value));
85        }
86        Ok(())
87    }
88
89    async fn read_auth_status(&mut self) -> Result<()> {
90        let value = self.read_u8().await?;
91        if value != 0x00 {
92            return Err(Error::InvalidAuthStatus(value));
93        }
94        Ok(())
95    }
96
97    async fn read_selection_msg(&mut self) -> Result<AuthMethod> {
98        self.read_version().await?;
99        self.read_method().await
100    }
101
102    async fn read_final(&mut self) -> Result<Address> {
103        self.read_version().await?;
104        self.read_reply().await?;
105        self.read_reserved().await?;
106        let addr = self.read_address().await?;
107        Ok(addr)
108    }
109}
110
111#[async_trait::async_trait]
112impl<T: AsyncReadExt + Unpin> Socks5Reader for T {}
113
114#[async_trait::async_trait]
115pub trait Socks5Writer: AsyncWriteExt + Unpin {
116    async fn write_version(&mut self) -> Result<()> {
117        self.write_u8(0x05).await?;
118        Ok(())
119    }
120
121    async fn write_method(&mut self, method: AuthMethod) -> Result<()> {
122        self.write_u8(u8::from(method)).await?;
123        Ok(())
124    }
125
126    async fn write_command(&mut self, command: Command) -> Result<()> {
127        self.write_u8(u8::from(command)).await?;
128        Ok(())
129    }
130
131    async fn write_atyp(&mut self, atyp: AddressType) -> Result<()> {
132        self.write_u8(u8::from(atyp)).await?;
133        Ok(())
134    }
135
136    async fn write_reserved(&mut self) -> Result<()> {
137        self.write_u8(0x00).await?;
138        Ok(())
139    }
140
141    async fn write_fragment_id(&mut self, id: u8) -> Result<()> {
142        self.write_u8(id).await?;
143        Ok(())
144    }
145
146    async fn write_address(&mut self, address: &Address) -> Result<()> {
147        address.write_to_async_stream(self).await?;
148        Ok(())
149    }
150
151    async fn write_string(&mut self, string: &str) -> Result<()> {
152        let bytes = string.as_bytes();
153        if bytes.len() > 255 {
154            return Err("Too long string".into());
155        }
156        self.write_u8(bytes.len() as u8).await?;
157        self.write_all(bytes).await?;
158        Ok(())
159    }
160
161    async fn write_auth_version(&mut self) -> Result<()> {
162        self.write_u8(0x01).await?;
163        Ok(())
164    }
165
166    async fn write_methods(&mut self, methods: &[AuthMethod]) -> Result<()> {
167        self.write_u8(methods.len() as u8).await?;
168        for method in methods {
169            self.write_method(*method).await?;
170        }
171        Ok(())
172    }
173
174    async fn write_selection_msg(&mut self, methods: &[AuthMethod]) -> Result<()> {
175        self.write_version().await?;
176        self.write_methods(methods).await?;
177        self.flush().await?;
178        Ok(())
179    }
180
181    async fn write_final(&mut self, command: Command, addr: &Address) -> Result<()> {
182        self.write_version().await?;
183        self.write_command(command).await?;
184        self.write_reserved().await?;
185        self.write_address(addr).await?;
186        self.flush().await?;
187        Ok(())
188    }
189}
190
191#[async_trait::async_trait]
192impl<T: AsyncWriteExt + Unpin> Socks5Writer for T {}
193
194async fn username_password_auth<S>(stream: &mut S, auth: &UserKey) -> Result<()>
195where
196    S: Socks5Writer + Socks5Reader + Send,
197{
198    stream.write_auth_version().await?;
199    stream.write_string(&auth.username).await?;
200    stream.write_string(&auth.password).await?;
201    stream.flush().await?;
202
203    stream.read_auth_version().await?;
204    stream.read_auth_status().await
205}
206
207async fn init<S, A>(stream: &mut S, command: Command, addr: A, auth: Option<UserKey>) -> Result<Address>
208where
209    S: Socks5Writer + Socks5Reader + Send,
210    A: Into<Address>,
211{
212    let addr: Address = addr.into();
213
214    let mut methods = Vec::with_capacity(2);
215    methods.push(AuthMethod::NoAuth);
216    if auth.is_some() {
217        methods.push(AuthMethod::UserPass);
218    }
219    stream.write_selection_msg(&methods).await?;
220    stream.flush().await?;
221
222    let method: AuthMethod = stream.read_selection_msg().await?;
223    match method {
224        AuthMethod::NoAuth => {}
225        AuthMethod::UserPass if auth.is_some() => {
226            username_password_auth(stream, auth.as_ref().unwrap()).await?;
227        }
228        _ => return Err(Error::InvalidAuthMethod(method)),
229    }
230
231    stream.write_final(command, &addr).await?;
232    stream.read_final().await
233}
234
235/// Proxifies a TCP connection. Performs the [`CONNECT`] command under the hood.
236///
237/// [`CONNECT`]: https://tools.ietf.org/html/rfc1928#page-6
238///
239/// ```no_run
240/// # use socks5_impl::Result;
241/// # #[tokio::main(flavor = "current_thread")]
242/// # async fn main() -> Result<()> {
243/// use socks5_impl::client;
244/// use tokio::{io::BufStream, net::TcpStream};
245///
246/// let stream = TcpStream::connect("my-proxy-server.com:54321").await?;
247/// let mut stream = BufStream::new(stream);
248/// client::connect(&mut stream, ("google.com", 80), None).await?;
249///
250/// # Ok(())
251/// # }
252/// ```
253pub async fn connect<S, A>(socket: &mut S, addr: A, auth: Option<UserKey>) -> Result<Address>
254where
255    S: AsyncWriteExt + AsyncReadExt + Send + Unpin,
256    A: Into<Address>,
257{
258    init(socket, Command::Connect, addr, auth).await
259}
260
261/// A listener that accepts TCP connections through a proxy.
262///
263/// ```no_run
264/// # use socks5_impl::Result;
265/// # #[tokio::main(flavor = "current_thread")]
266/// # async fn main() -> Result<()> {
267/// use socks5_impl::client::SocksListener;
268/// use tokio::{io::BufStream, net::TcpStream};
269///
270/// let stream = TcpStream::connect("my-proxy-server.com:54321").await?;
271/// let mut stream = BufStream::new(stream);
272/// let (stream, addr) = SocksListener::bind(stream, ("ftp-server.org", 21), None)
273///     .await?
274///     .accept()
275///     .await?;
276///
277/// # Ok(())
278/// # }
279/// ```
280#[derive(Debug)]
281pub struct SocksListener<S> {
282    stream: S,
283    proxy_addr: Address,
284}
285
286impl<S> SocksListener<S>
287where
288    S: AsyncWriteExt + AsyncReadExt + Send + Unpin,
289{
290    /// Creates `SocksListener`. Performs the [`BIND`] command under the hood.
291    ///
292    /// [`BIND`]: https://tools.ietf.org/html/rfc1928#page-6
293    pub async fn bind<A>(mut stream: S, addr: A, auth: Option<UserKey>) -> Result<Self>
294    where
295        A: Into<Address>,
296    {
297        let addr = init(&mut stream, Command::Bind, addr, auth).await?;
298        Ok(Self { stream, proxy_addr: addr })
299    }
300
301    pub fn proxy_addr(&self) -> &Address {
302        &self.proxy_addr
303    }
304
305    pub async fn accept(mut self) -> Result<(S, Address)> {
306        let addr = self.stream.read_final().await?;
307        Ok((self.stream, addr))
308    }
309}
310
311/// A UDP socket that sends packets through a proxy.
312#[derive(Debug)]
313pub struct SocksDatagram<S> {
314    socket: UdpSocket,
315    proxy_addr: Address,
316    stream: S,
317}
318
319impl<S> SocksDatagram<S>
320where
321    S: AsyncWriteExt + AsyncReadExt + Send + Unpin,
322{
323    /// Creates `SocksDatagram`. Performs [`UDP ASSOCIATE`] under the hood.
324    ///
325    /// [`UDP ASSOCIATE`]: https://tools.ietf.org/html/rfc1928#page-7
326    pub async fn udp_associate(mut stream: S, socket: UdpSocket, auth: Option<UserKey>) -> Result<Self> {
327        let addr = if socket.local_addr()?.is_ipv4() { "0.0.0.0:0" } else { "[::]:0" };
328        let addr = addr.parse::<SocketAddr>()?;
329        let proxy_addr = init(&mut stream, Command::UdpAssociate, addr, auth).await?;
330        let addr = proxy_addr.to_socket_addrs()?.next().ok_or("InvalidAddress")?;
331        socket.connect(addr).await?;
332        Ok(Self {
333            socket,
334            proxy_addr,
335            stream,
336        })
337    }
338
339    /// Returns the address of the associated udp address.
340    pub fn proxy_addr(&self) -> &Address {
341        &self.proxy_addr
342    }
343
344    /// Returns a reference to the underlying udp socket.
345    pub fn get_ref(&self) -> &UdpSocket {
346        &self.socket
347    }
348
349    /// Returns a mutable reference to the underlying udp socket.
350    pub fn get_mut(&mut self) -> &mut UdpSocket {
351        &mut self.socket
352    }
353
354    /// Returns the associated stream and udp socket.
355    pub fn into_inner(self) -> (S, UdpSocket) {
356        (self.stream, self.socket)
357    }
358
359    //  Builds a udp-based client request packet, the format is as follows:
360    //  +----+------+------+----------+----------+----------+
361    //  |RSV | FRAG | ATYP | DST.ADDR | DST.PORT |   DATA   |
362    //  +----+------+------+----------+----------+----------+
363    //  | 2  |  1   |  1   | Variable |    2     | Variable |
364    //  +----+------+------+----------+----------+----------+
365    //  The reference link is as follows:
366    //  https://tools.ietf.org/html/rfc1928#page-8
367    //
368    pub async fn build_socks5_udp_datagram(buf: &[u8], addr: &Address) -> Result<Vec<u8>> {
369        let bytes_size = Self::get_buf_size(addr.len(), buf.len());
370        let bytes = Vec::with_capacity(bytes_size);
371
372        let mut cursor = Cursor::new(bytes);
373        cursor.write_reserved().await?;
374        cursor.write_reserved().await?;
375        cursor.write_fragment_id(0x00).await?;
376        cursor.write_address(addr).await?;
377        cursor.write_all(buf).await?;
378
379        let bytes = cursor.into_inner();
380        Ok(bytes)
381    }
382
383    /// Sends data via the udp socket to the given address.
384    pub async fn send_to<A>(&self, buf: &[u8], addr: A) -> Result<usize>
385    where
386        A: Into<Address>,
387    {
388        let addr: Address = addr.into();
389        let bytes = Self::build_socks5_udp_datagram(buf, &addr).await?;
390        Ok(self.socket.send(&bytes).await?)
391    }
392
393    /// Parses the udp-based server response packet, the format is same as the client request packet.
394    async fn parse_socks5_udp_response(bytes: &mut [u8], buf: &mut Vec<u8>) -> Result<(usize, Address)> {
395        let len = bytes.len();
396        let mut cursor = Cursor::new(bytes);
397        cursor.read_reserved().await?;
398        cursor.read_reserved().await?;
399        cursor.read_fragment_id().await?;
400        let addr = cursor.read_address().await?;
401        let header_len = cursor.position() as usize;
402        buf.resize(len - header_len, 0);
403        _ = cursor.read_exact(buf).await?;
404        Ok((len - header_len, addr))
405    }
406
407    /// Receives data from the udp socket and returns the number of bytes read and the origin of the data.
408    pub async fn recv_from(&self, timeout: Duration, buf: &mut Vec<u8>) -> Result<(usize, Address)> {
409        const UDP_MTU: usize = 1500;
410        // let bytes_size = Self::get_buf_size(Address::max_serialized_len(), buf.len());
411        let bytes_size = UDP_MTU;
412        let mut bytes = vec![0; bytes_size];
413        let len = tokio::time::timeout(timeout, self.socket.recv(&mut bytes)).await??;
414        bytes.truncate(len);
415        let (read, addr) = Self::parse_socks5_udp_response(&mut bytes, buf).await?;
416        Ok((read, addr))
417    }
418
419    fn get_buf_size(addr_size: usize, buf_len: usize) -> usize {
420        // reserved + fragment id + addr_size + buf_len
421        2 + 1 + addr_size + buf_len
422    }
423}
424
425pub type GuardTcpStream = BufStream<TcpStream>;
426pub type SocksUdpClient = SocksDatagram<GuardTcpStream>;
427
428#[async_trait::async_trait]
429pub trait UdpClientTrait {
430    async fn send_to<A>(&mut self, buf: &[u8], addr: A) -> Result<usize>
431    where
432        A: Into<Address> + Send + Unpin;
433
434    async fn recv_from(&mut self, timeout: Duration, buf: &mut Vec<u8>) -> Result<(usize, Address)>;
435}
436
437#[async_trait::async_trait]
438impl UdpClientTrait for SocksUdpClient {
439    async fn send_to<A>(&mut self, buf: &[u8], addr: A) -> Result<usize, Error>
440    where
441        A: Into<Address> + Send + Unpin,
442    {
443        SocksDatagram::send_to(self, buf, addr).await
444    }
445
446    async fn recv_from(&mut self, timeout: Duration, buf: &mut Vec<u8>) -> Result<(usize, Address), Error> {
447        SocksDatagram::recv_from(self, timeout, buf).await
448    }
449}
450
451pub async fn create_udp_client<A: Into<SocketAddr>>(proxy_addr: A, auth: Option<UserKey>) -> Result<SocksUdpClient> {
452    let proxy_addr = proxy_addr.into();
453    let client_addr = if proxy_addr.is_ipv4() { "0.0.0.0:0" } else { "[::]:0" };
454    let proxy = TcpStream::connect(proxy_addr).await?;
455    let proxy = BufStream::new(proxy);
456    let client = UdpSocket::bind(client_addr).await?;
457    SocksDatagram::udp_associate(proxy, client, auth).await
458}
459
460pub struct UdpClientImpl<C> {
461    client: C,
462    server_addr: Address,
463}
464
465impl UdpClientImpl<SocksUdpClient> {
466    pub async fn transfer_data(&self, data: &[u8], timeout: Duration) -> Result<Vec<u8>> {
467        let len = self.client.send_to(data, &self.server_addr).await?;
468        let buf = SocksDatagram::<GuardTcpStream>::build_socks5_udp_datagram(data, &self.server_addr).await?;
469        assert_eq!(len, buf.len());
470
471        let mut buf = Vec::with_capacity(data.len());
472        let (_len, _) = self.client.recv_from(timeout, &mut buf).await?;
473        Ok(buf)
474    }
475
476    pub async fn datagram<A1, A2>(proxy_addr: A1, udp_server_addr: A2, auth: Option<UserKey>) -> Result<Self>
477    where
478        A1: Into<SocketAddr>,
479        A2: Into<Address>,
480    {
481        let client = create_udp_client(proxy_addr, auth).await?;
482
483        let server_addr = udp_server_addr.into();
484
485        Ok(Self { client, server_addr })
486    }
487}
488
// Integration tests. All are `#[ignore]`d because they presumably need live
// SOCKS5 proxies:
// - one without auth on `PROXY_ADDR`;
// - one requiring the "hyper"/"proxy" credentials on `PROXY_AUTH_ADDR`.
// TODO: confirm against the test setup.
#[cfg(test)]
mod tests {
    use crate::{
        Error, Result,
        client::{self, SocksListener, SocksUdpClient, UdpClientTrait},
        protocol::{Address, UserKey},
    };
    use std::{
        net::{SocketAddr, ToSocketAddrs},
        sync::Arc,
        time::Duration,
    };
    use tokio::{
        io::{AsyncReadExt, AsyncWriteExt, BufStream},
        net::{TcpStream, UdpSocket},
    };

    const PROXY_ADDR: &str = "127.0.0.1:1080";
    const PROXY_AUTH_ADDR: &str = "127.0.0.1:1081";
    const DATA: &[u8] = b"Hello, world!";

    /// Connects to the proxy at `addr` and issues CONNECT to baidu.com:80.
    async fn connect(addr: &str, auth: Option<UserKey>) {
        let socket = TcpStream::connect(addr).await.unwrap();
        let mut socket = BufStream::new(socket);
        client::connect(&mut socket, Address::from(("baidu.com", 80)), auth).await.unwrap();
    }

    #[ignore]
    #[tokio::test]
    async fn connect_auth() {
        connect(PROXY_AUTH_ADDR, Some(UserKey::new("hyper", "proxy"))).await;
    }

    #[ignore]
    #[tokio::test]
    async fn connect_no_auth() {
        connect(PROXY_ADDR, None).await;
    }

    #[ignore]
    #[should_panic = "InvalidAuthMethod(NoAcceptableMethods)"]
    #[tokio::test]
    async fn connect_no_auth_panic() {
        connect(PROXY_AUTH_ADDR, None).await;
    }

    #[ignore]
    #[tokio::test]
    async fn bind() {
        let run_block = async {
            let server_addr = Address::from(("127.0.0.1", 8000));

            let client = TcpStream::connect(PROXY_ADDR).await?;
            let client = BufStream::new(client);
            let client = SocksListener::bind(client, server_addr, None).await?;

            let server_addr = client.proxy_addr.to_socket_addrs()?.next().ok_or("Invalid address")?;
            let mut server = TcpStream::connect(&server_addr).await?;

            let (mut client, _) = client.accept().await?;

            server.write_all(DATA).await?;

            let mut buf = [0; DATA.len()];
            client.read_exact(&mut buf).await?;
            assert_eq!(buf, DATA);
            Ok::<_, Error>(())
        };
        // Fail the test on error; merely printing it meant this test could never fail.
        run_block.await.unwrap();
    }

    type TestHalves = (Arc<SocksUdpClient>, Arc<SocksUdpClient>);

    /// Receives through the first half and sends through the second.
    #[async_trait::async_trait]
    impl UdpClientTrait for TestHalves {
        async fn send_to<A>(&mut self, buf: &[u8], addr: A) -> Result<usize, Error>
        where
            A: Into<Address> + Send + Unpin,
        {
            self.1.send_to(buf, addr).await
        }

        async fn recv_from(&mut self, timeout: Duration, buf: &mut Vec<u8>) -> Result<(usize, Address), Error> {
            self.0.recv_from(timeout, buf).await
        }
    }

    const SERVER_ADDR: &str = "127.0.0.1:23456";

    struct UdpTest<C> {
        client: C,
        server: UdpSocket,
        server_addr: Address,
    }

    impl<C: UdpClientTrait> UdpTest<C> {
        /// Round-trips `DATA` client -> server -> client through the proxy.
        async fn test(mut self) {
            let mut buf = vec![0; DATA.len()];
            self.client.send_to(DATA, self.server_addr).await.unwrap();
            let (len, addr) = self.server.recv_from(&mut buf).await.unwrap();
            assert_eq!(len, buf.len());
            assert_eq!(buf.as_slice(), DATA);

            let mut buf = vec![0; DATA.len()];
            self.server.send_to(DATA, addr).await.unwrap();
            let timeout = Duration::from_secs(5);
            let (len, _) = self.client.recv_from(timeout, &mut buf).await.unwrap();
            assert_eq!(len, buf.len());
            assert_eq!(buf.as_slice(), DATA);
        }
    }

    impl UdpTest<SocksUdpClient> {
        async fn datagram() -> Self {
            let addr = PROXY_ADDR.parse::<SocketAddr>().unwrap();
            let client = client::create_udp_client(addr, None).await.unwrap();

            let server_addr: SocketAddr = SERVER_ADDR.parse().unwrap();
            let server = UdpSocket::bind(server_addr).await.unwrap();
            let server_addr = Address::from(server_addr);

            Self {
                client,
                server,
                server_addr,
            }
        }
    }

    impl UdpTest<TestHalves> {
        async fn halves() -> Self {
            let this = UdpTest::<SocksUdpClient>::datagram().await;
            let client = Arc::new(this.client);
            Self {
                client: (client.clone(), client),
                server: this.server,
                server_addr: this.server_addr,
            }
        }
    }

    #[ignore]
    #[tokio::test]
    async fn udp_datagram_halves() {
        UdpTest::halves().await.test().await
    }
}