1use crate::utils::*;
2use log::{error, info};
3use std::{
4 convert::TryInto,
5 net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6, ToSocketAddrs},
6 ops::{Deref, DerefMut},
7 sync::Arc,
8};
9use thiserror::Error;
10use tokio::io::{self, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
11use tokio::net::{TcpSocket, TcpStream};
12
13type Result<T> = std::result::Result<T, Socks5ServerError>;
14
15#[derive(Debug, Error)]
16pub enum Socks5ServerError {
17 #[error("unrecognized protocol")]
18 UnknowProtocol,
19 #[error("unsupport authenticate method")]
20 UnsupportAuth,
21 #[error("unsupport socks5 command {0:#04X}")]
22 UnsupportCommand(u8),
23 #[error("unknow destination type {0:#04X}")]
24 UnknowAddrType(u8),
25 #[error("invalid hostname received")]
26 InvalidHost(#[from] std::str::Utf8Error),
27 #[error("DNS lookup error: {0}")]
28 DNSError(String),
29 #[error(transparent)]
30 IOError(#[from] io::Error),
31}
32pub struct Socks5Server {
33 conn: TcpSocket,
34 auth: Arc<AuthMethod>,
35}
36pub fn new(addr: SocketAddr, auth: Option<AuthMethod>) -> Result<Socks5Server> {
37 let conn = match addr {
38 SocketAddr::V4(_) => TcpSocket::new_v4()?,
39 SocketAddr::V6(_) => TcpSocket::new_v6()?,
40 };
41 conn.bind(addr)?;
42
43 let auth = auth.unwrap_or(AuthMethod::NoAuth);
44 let auth = Arc::new(auth);
45 Ok(Socks5Server { conn, auth })
46}
47
48impl Socks5Server {
49 pub async fn run(self) -> Result<()> {
50 let conn = self.conn.listen(1024)?;
51 loop {
52 let (conn, source) = conn.accept().await?;
53
54 let auth = self.auth.clone();
55 tokio::spawn(async move {
56 let result = handle_client(conn, auth).await;
57 if let Err(e) = result {
58 error!("{:?}, source {}", e, source);
59 }
60 });
61 }
62 }
63}
64
65impl_deref!(PendingHandshake, TcpStream);
66impl PendingHandshake {
67 async fn handshake(mut self, auth: &Arc<AuthMethod>) -> Result<PendingAuthenticate> {
68 let mut header = [0u8; 2];
69 self.read_exact(&mut header).await?;
70 if header[0] != SOCKS_VER {
71 return Err(Socks5ServerError::UnknowProtocol);
72 }
73 let mut matched = false;
74 for _ in 0..header[1] {
75 let mut m = [0u8; 1];
76 self.read_exact(&mut m).await?;
77 if m[0] == auth.to_code() {
78 matched = true;
79 }
80 }
81 if !matched {
82 return Err(Socks5ServerError::UnsupportAuth);
83 }
84
85 self.write_all(&[SOCKS_VER, auth.to_code()]).await?;
86 self.flush().await?;
87
88 Ok(PendingAuthenticate(self.0))
89 }
90}
91
92impl_deref!(PendingAuthenticate, TcpStream);
93impl PendingAuthenticate {
94 async fn authenticate(self, auth: &Arc<AuthMethod>) -> Result<PendingCommand> {
95 match **auth {
96 AuthMethod::NoAuth => Ok(PendingCommand(self.0)),
97 _ => Err(Socks5ServerError::UnsupportAuth),
98 }
99 }
100}
101
102impl_deref!(PendingCommand, TcpStream);
103impl PendingCommand {
104 async fn handle_command(&mut self) -> Result<SocketAddr> {
105 let mut header = [0u8; 4];
106 self.read_exact(&mut header).await?;
107 if header[0] != SOCKS_VER || header[2] != SOCKS_RSV {
108 return Err(Socks5ServerError::UnknowProtocol);
109 } else if header[1] != SOCKS_COMMAND_CONNECT {
110 return Err(Socks5ServerError::UnsupportCommand(header[1]));
111 }
112
113 match header[3] {
114 SOCKS_ADDR_IPV4 => {
115 let mut buffer = [0u8; 4 + 2];
116 self.read_exact(&mut buffer).await?;
117 let ip: [u8; 4] = buffer[..4].try_into().unwrap();
118 let ip: Ipv4Addr = Ipv4Addr::from(ip);
119 let port = u16::from_be_bytes([buffer[4], buffer[5]]);
120 let addr = SocketAddr::V4(SocketAddrV4::new(ip, port));
121 info!("connecting to {}", addr);
122 Ok(addr)
123 }
124 SOCKS_ADDR_IPV6 => {
125 let mut buffer = [0u8; 16 + 2];
126 self.read_exact(&mut buffer).await?;
127 let ip: [u8; 16] = buffer[..16].try_into().unwrap();
128 let ip = Ipv6Addr::from(ip);
129 let port = u16::from_be_bytes([buffer[16], buffer[17]]);
130 let addr = SocketAddr::V6(SocketAddrV6::new(ip, port, 0, 0));
131 info!("connecting to {}", addr);
132 Ok(addr)
133 }
134 SOCKS_ADDR_DOMAINNAME => {
135 let mut buffer = [0u8; 255];
136 self.read_exact(&mut buffer[..1]).await?;
137 let len = buffer[0];
138 self.read_exact(&mut buffer[..len as usize]).await?;
139 let mut port = [0u8; 2];
140 self.read_exact(&mut port).await?;
141 let port = u16::from_be_bytes(port);
142 let host = std::str::from_utf8(&buffer[..len as usize])?;
143 let sock = (host, port).to_socket_addrs()?.next();
144 if let None = sock {
145 return Err(Socks5ServerError::DNSError(host.into()));
146 }
147 let addr = sock.unwrap();
148 info!("connecting to {}:{}", host, port);
149 Ok(addr)
150 }
151 _ => Err(Socks5ServerError::UnknowAddrType(header[3])),
152 }
153 }
154 async fn reply(mut self, content: &[u8]) -> Result<TcpStream> {
155 self.write_all(&content).await?;
156 self.flush().await?;
157 Ok(self.0)
158 }
159}
160async fn handle_client(conn: TcpStream, auth: Arc<AuthMethod>) -> Result<()> {
161 let mut conn = PendingHandshake(conn)
162 .handshake(&auth)
163 .await?
164 .authenticate(&auth)
165 .await?;
166 let addr = conn.handle_command().await;
167 let mut rep = [
168 SOCKS_VER,
169 SocksError::SUCCESS as u8,
170 SOCKS_RSV,
171 SOCKS_ADDR_IPV4,
172 0,
173 0,
174 0,
175 0,
176 0,
177 0,
178 ];
179 let addr = match addr {
180 Ok(c) => c,
181 Err(e) => {
182 rep[1] = match e {
183 Socks5ServerError::DNSError(_) => SocksError::HOST,
184 Socks5ServerError::UnsupportCommand(_) => SocksError::COMMAND,
185 Socks5ServerError::UnknowAddrType(_) => SocksError::ADDRESS,
186 _ => SocksError::FAIL,
187 } as u8;
188 conn.reply(&rep).await?;
189 return Err(e);
190 }
191 };
192
193 let delegate = TcpStream::connect(addr).await;
195 let delegate = match delegate {
196 Ok(c) => c,
197 Err(e) => {
198 rep[1] = SocksError::NETWORK as u8;
199 conn.reply(&rep).await?;
200 return Err(e.into());
201 }
202 };
203
204 let conn = conn.reply(&rep).await?;
205
206 let (conn_r, conn_w) = conn.into_split();
207 let (delegate_r, delegate_w) = delegate.into_split();
208
209 tokio::spawn(async move {
210 copy(conn_r, delegate_w).await;
211 });
212
213 tokio::spawn(async move {
214 copy(delegate_r, conn_w).await;
215 });
216
217 Ok(())
218}
219
220async fn copy(mut r: impl AsyncRead + Unpin, mut w: impl AsyncWrite + Unpin) {
221 tokio::io::copy(&mut r, &mut w).await.unwrap_or(0);
222
223 w.shutdown().await.unwrap_or(());
224}