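//! socks5_proxy/server.rs: a SOCKS5 proxy server that accepts CONNECT requests
//! and relays traffic between the client and the requested destination.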

use crate::utils::*;
use log::{error, info};
use std::{
    convert::TryInto,
    net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6, ToSocketAddrs},
    ops::{Deref, DerefMut},
    sync::Arc,
};
use thiserror::Error;
use tokio::io::{self, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio::net::{lookup_host, TcpSocket, TcpStream};

13type Result<T> = std::result::Result<T, Socks5ServerError>;
14
15#[derive(Debug, Error)]
16pub enum Socks5ServerError {
17    #[error("unrecognized protocol")]
18    UnknowProtocol,
19    #[error("unsupport authenticate method")]
20    UnsupportAuth,
21    #[error("unsupport socks5 command {0:#04X}")]
22    UnsupportCommand(u8),
23    #[error("unknow destination type {0:#04X}")]
24    UnknowAddrType(u8),
25    #[error("invalid hostname received")]
26    InvalidHost(#[from] std::str::Utf8Error),
27    #[error("DNS lookup error: {0}")]
28    DNSError(String),
29    #[error(transparent)]
30    IOError(#[from] io::Error),
31}
32pub struct Socks5Server {
33    conn: TcpSocket,
34    auth: Arc<AuthMethod>,
35}
36pub fn new(addr: SocketAddr, auth: Option<AuthMethod>) -> Result<Socks5Server> {
37    let conn = match addr {
38        SocketAddr::V4(_) => TcpSocket::new_v4()?,
39        SocketAddr::V6(_) => TcpSocket::new_v6()?,
40    };
41    conn.bind(addr)?;
42
43    let auth = auth.unwrap_or(AuthMethod::NoAuth);
44    let auth = Arc::new(auth);
45    Ok(Socks5Server { conn, auth })
46}
47
48impl Socks5Server {
49    pub async fn run(self) -> Result<()> {
50        let conn = self.conn.listen(1024)?;
51        loop {
52            let (conn, source) = conn.accept().await?;
53
54            let auth = self.auth.clone();
55            tokio::spawn(async move {
56                let result = handle_client(conn, auth).await;
57                if let Err(e) = result {
58                    error!("{:?}, source {}", e, source);
59                }
60            });
61        }
62    }
63}
64
65impl_deref!(PendingHandshake, TcpStream);
66impl PendingHandshake {
67    async fn handshake(mut self, auth: &Arc<AuthMethod>) -> Result<PendingAuthenticate> {
68        let mut header = [0u8; 2];
69        self.read_exact(&mut header).await?;
70        if header[0] != SOCKS_VER {
71            return Err(Socks5ServerError::UnknowProtocol);
72        }
73        let mut matched = false;
74        for _ in 0..header[1] {
75            let mut m = [0u8; 1];
76            self.read_exact(&mut m).await?;
77            if m[0] == auth.to_code() {
78                matched = true;
79            }
80        }
81        if !matched {
82            return Err(Socks5ServerError::UnsupportAuth);
83        }
84
85        self.write_all(&[SOCKS_VER, auth.to_code()]).await?;
86        self.flush().await?;
87
88        Ok(PendingAuthenticate(self.0))
89    }
90}
91
92impl_deref!(PendingAuthenticate, TcpStream);
93impl PendingAuthenticate {
94    async fn authenticate(self, auth: &Arc<AuthMethod>) -> Result<PendingCommand> {
95        match **auth {
96            AuthMethod::NoAuth => Ok(PendingCommand(self.0)),
97            _ => Err(Socks5ServerError::UnsupportAuth),
98        }
99    }
100}
101
102impl_deref!(PendingCommand, TcpStream);
103impl PendingCommand {
104    async fn handle_command(&mut self) -> Result<SocketAddr> {
105        let mut header = [0u8; 4];
106        self.read_exact(&mut header).await?;
107        if header[0] != SOCKS_VER || header[2] != SOCKS_RSV {
108            return Err(Socks5ServerError::UnknowProtocol);
109        } else if header[1] != SOCKS_COMMAND_CONNECT {
110            return Err(Socks5ServerError::UnsupportCommand(header[1]));
111        }
112
113        match header[3] {
114            SOCKS_ADDR_IPV4 => {
115                let mut buffer = [0u8; 4 + 2];
116                self.read_exact(&mut buffer).await?;
117                let ip: [u8; 4] = buffer[..4].try_into().unwrap();
118                let ip: Ipv4Addr = Ipv4Addr::from(ip);
119                let port = u16::from_be_bytes([buffer[4], buffer[5]]);
120                let addr = SocketAddr::V4(SocketAddrV4::new(ip, port));
121                info!("connecting to {}", addr);
122                Ok(addr)
123            }
124            SOCKS_ADDR_IPV6 => {
125                let mut buffer = [0u8; 16 + 2];
126                self.read_exact(&mut buffer).await?;
127                let ip: [u8; 16] = buffer[..16].try_into().unwrap();
128                let ip = Ipv6Addr::from(ip);
129                let port = u16::from_be_bytes([buffer[16], buffer[17]]);
130                let addr = SocketAddr::V6(SocketAddrV6::new(ip, port, 0, 0));
131                info!("connecting to {}", addr);
132                Ok(addr)
133            }
134            SOCKS_ADDR_DOMAINNAME => {
135                let mut buffer = [0u8; 255];
136                self.read_exact(&mut buffer[..1]).await?;
137                let len = buffer[0];
138                self.read_exact(&mut buffer[..len as usize]).await?;
139                let mut port = [0u8; 2];
140                self.read_exact(&mut port).await?;
141                let port = u16::from_be_bytes(port);
142                let host = std::str::from_utf8(&buffer[..len as usize])?;
143                let sock = (host, port).to_socket_addrs()?.next();
144                if let None = sock {
145                    return Err(Socks5ServerError::DNSError(host.into()));
146                }
147                let addr = sock.unwrap();
148                info!("connecting to {}:{}", host, port);
149                Ok(addr)
150            }
151            _ => Err(Socks5ServerError::UnknowAddrType(header[3])),
152        }
153    }
154    async fn reply(mut self, content: &[u8]) -> Result<TcpStream> {
155        self.write_all(&content).await?;
156        self.flush().await?;
157        Ok(self.0)
158    }
159}
160async fn handle_client(conn: TcpStream, auth: Arc<AuthMethod>) -> Result<()> {
161    let mut conn = PendingHandshake(conn)
162        .handshake(&auth)
163        .await?
164        .authenticate(&auth)
165        .await?;
166    let addr = conn.handle_command().await;
167    let mut rep = [
168        SOCKS_VER,
169        SocksError::SUCCESS as u8,
170        SOCKS_RSV,
171        SOCKS_ADDR_IPV4,
172        0,
173        0,
174        0,
175        0,
176        0,
177        0,
178    ];
179    let addr = match addr {
180        Ok(c) => c,
181        Err(e) => {
182            rep[1] = match e {
183                Socks5ServerError::DNSError(_) => SocksError::HOST,
184                Socks5ServerError::UnsupportCommand(_) => SocksError::COMMAND,
185                Socks5ServerError::UnknowAddrType(_) => SocksError::ADDRESS,
186                _ => SocksError::FAIL,
187            } as u8;
188            conn.reply(&rep).await?;
189            return Err(e);
190        }
191    };
192
193    // --------------------------------
194    let delegate = TcpStream::connect(addr).await;
195    let delegate = match delegate {
196        Ok(c) => c,
197        Err(e) => {
198            rep[1] = SocksError::NETWORK as u8;
199            conn.reply(&rep).await?;
200            return Err(e.into());
201        }
202    };
203
204    let conn = conn.reply(&rep).await?;
205
206    let (conn_r, conn_w) = conn.into_split();
207    let (delegate_r, delegate_w) = delegate.into_split();
208
209    tokio::spawn(async move {
210        copy(conn_r, delegate_w).await;
211    });
212
213    tokio::spawn(async move {
214        copy(delegate_r, conn_w).await;
215    });
216
217    Ok(())
218}
219
220async fn copy(mut r: impl AsyncRead + Unpin, mut w: impl AsyncWrite + Unpin) {
221    tokio::io::copy(&mut r, &mut w).await.unwrap_or(0);
222
223    w.shutdown().await.unwrap_or(());
224}