socks5_proxy/
client.rs

1use crate::utils::*;
2
3use std::{
4    io,
5    net::SocketAddr,
6    ops::{Deref, DerefMut},
7};
8use tokio::{
9    io::{AsyncReadExt, AsyncWriteExt, Result},
10    net::{TcpStream, ToSocketAddrs},
11};
12
13pub async fn new(
14    server: impl ToSocketAddrs,
15    dest: &Addr,
16    auth: Option<AuthMethod>,
17) -> Result<TcpStream> {
18    let conn = TcpStream::connect(server).await?;
19    let auth = auth.unwrap_or(AuthMethod::NoAuth);
20
21    let client = PendingHandshake(conn);
22    let client = client.handshake(&auth).await?;
23    let client = client.authenticate(&auth).await?;
24    let client = client.connect(dest).await?;
25
26    Ok(client)
27}
28
29impl_deref!(PendingHandshake, TcpStream);
30impl PendingHandshake {
31    #[inline]
32    async fn handshake(mut self, method: &AuthMethod) -> Result<PendingAuthenticate> {
33        let msg: &[u8] = &[SOCKS_VER, 0x01, method.to_code()];
34        self.write_all(msg).await?;
35        self.flush().await?;
36
37        let mut buffer = [0; 2];
38        self.read_exact(&mut buffer).await?;
39
40        if buffer[0] != SOCKS_VER {
41            return Err(io::Error::new(
42                io::ErrorKind::ConnectionAborted,
43                "unsupported protocol",
44            ));
45        }
46
47        let auth = AuthMethod::from_code(buffer[1])?;
48
49        if let AuthMethod::NoAvailable = auth {
50            Err(io::Error::new(
51                io::ErrorKind::ConnectionRefused,
52                "no supported authenticate method available",
53            ))
54        } else if auth.to_code() != method.to_code() {
55            Err(io::Error::new(
56                io::ErrorKind::ConnectionAborted,
57                "unsupported protocol",
58            ))
59        } else {
60            Ok(PendingAuthenticate(self.0))
61        }
62    }
63}
64
65impl_deref!(PendingAuthenticate, TcpStream);
66impl PendingAuthenticate {
67    #[inline]
68    async fn authenticate(self, auth: &AuthMethod) -> Result<PendingConnect> {
69        match auth {
70            AuthMethod::NoAuth => Ok(PendingConnect(self.0)),
71            _ => Err(io::Error::new(
72                io::ErrorKind::Other,
73                format!("authenticate method {:?} not implemented", &auth),
74            )),
75        }
76    }
77}
78
79impl_deref!(PendingConnect, TcpStream);
80impl PendingConnect {
81    #[inline]
82    async fn connect(mut self, dest: &Addr) -> Result<TcpStream> {
83        let mut buffer = [0u8; 4 + 255 + 2];
84        let mut request = Buffer::from(&mut buffer);
85        request.extend(&[SOCKS_RSV, SOCKS_COMMAND_CONNECT, SOCKS_RSV]);
86
87        parse_dest(&mut request, dest)?;
88
89        self.write_all(request.content()).await?;
90        self.flush().await?;
91
92        let header: &mut [u8] = &mut buffer[..4];
93
94        self.read_exact(header).await?;
95
96        if header[0] != SOCKS_VER || header[02] != SOCKS_RSV {
97            return Err(io::Error::new(
98                io::ErrorKind::ConnectionAborted,
99                "unsupported protocol",
100            ));
101        }
102        if header[1] != SocksError::SUCCESS as u8 {
103            return Err(SocksError::from(header[1]).into());
104        }
105
106        self.extract_address(header[3], &mut buffer).await?;
107
108        Ok(self.0)
109    }
110
111    async fn extract_address(&mut self, addr_type: u8, buffer: &mut [u8]) -> Result<()> {
112        match addr_type {
113            SOCKS_ADDR_IPV4 => self.read_exact(&mut buffer[..4 + 2]).await?,
114            SOCKS_ADDR_IPV6 => self.read_exact(&mut buffer[..16 + 2]).await?,
115            SOCKS_ADDR_DOMAINNAME => {
116                self.read_exact(&mut buffer[..1]).await?;
117                let len = buffer[0] as usize;
118                self.read_exact(&mut buffer[..(len + 2)]).await?
119            }
120            _ => {
121                return Err(io::Error::new(
122                    io::ErrorKind::ConnectionAborted,
123                    "unsupported address type",
124                ))
125            }
126        };
127        Ok(())
128    }
129}
130
/// Appends a binary-encoded socket address to `$buffer`. The output is the address-type
/// byte `$addr_type`, then the raw IP octets of `$addr`, then its port in big-endian
/// (network) byte order.
///
/// `$buffer` must support `push(u8)` and `extend(&[u8; N])`. `$addr` is something with
/// `ip().octets()` and `port()`, such as a `SocketAddrV4` / `SocketAddrV6` (or a reference
/// to one).
macro_rules! write_addr_binary {
    ($buffer:ident, $addr_type:ident, $addr:ident) => {{
        let ip_octets = $addr.ip().octets();
        let port_be = $addr.port().to_be_bytes();
        $buffer.push($addr_type);
        $buffer.extend(&ip_octets);
        $buffer.extend(&port_be);
    }};
}
138
139#[inline]
140fn parse_dest(request: &mut Buffer, dest: &Addr) -> Result<()> {
141    match dest {
142        Addr::SocketAddr(addr) => {
143            match addr {
144                SocketAddr::V4(v4) => write_addr_binary!(request, SOCKS_ADDR_IPV4, v4),
145                SocketAddr::V6(v6) => write_addr_binary!(request, SOCKS_ADDR_IPV6, v6),
146            };
147        }
148        Addr::HostnamePort(hostname_port) => {
149            request.push(SOCKS_ADDR_DOMAINNAME);
150            let mut hostname_port = hostname_port.split(":");
151            let parse_err =
152                io::Error::new(io::ErrorKind::InvalidInput, "bad pattern in hostname:port");
153            let hostname = hostname_port.next();
154            let port = hostname_port.next();
155            let none = hostname_port.next();
156
157            if let (Some(hostname), Some(port), None) = (hostname, port, none) {
158                let hostname = hostname.as_bytes();
159                if hostname.len() > u8::MAX as usize {
160                    return Err(io::Error::new(
161                        io::ErrorKind::InvalidInput,
162                        "hostname too long",
163                    ));
164                }
165                request.push(hostname.len() as u8);
166                request.extend(hostname);
167                let port = port.parse::<u16>().map_err(|_| parse_err)?;
168                request.extend(&port.to_be_bytes());
169            } else {
170                return Err(parse_err);
171            }
172        }
173    }
174    Ok(())
175}