1use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
2use thiserror::Error;
3use tokio::{
4 io::{AsyncReadExt, AsyncWriteExt},
5 net::{TcpListener, TcpStream, UdpSocket},
6};
7
8#[derive(Debug, Error)]
9pub enum Error {
10 #[error("IO error: {0}")]
11 Io(#[from] std::io::Error),
12 #[error("Invalid SOCKS version")]
13 InvalidVersion,
14 #[error("Authentication failed")]
15 AuthFailed,
16 #[error("Invalid command: {0}")]
17 InvalidCommand(u8),
18 #[error("Invalid address type: {0}")]
19 InvalidAddrType(u8),
20 #[error("UTF-8 decode error: {0}")]
21 Utf8Error(#[from] std::string::FromUtf8Error),
22 #[error("Connection not allowed by ruleset")]
23 ConnectionNotAllowed,
24 #[error("Network unreachable")]
25 NetworkUnreachable,
26 #[error("Host unreachable")]
27 HostUnreachable,
28 #[error("Connection refused")]
29 ConnectionRefused,
30 #[error("TTL expired")]
31 TtlExpired,
32 #[error("Protocol error")]
33 ProtocolError,
34 #[error("Address type not supported")]
35 AddrTypeNotSupported,
36}
37
38pub type Result<T> = std::result::Result<T, Error>;
39
/// A SOCKS5 request command.
///
/// Each discriminant equals the `CMD` byte that decodes to it (see the
/// `TryFrom<u8>` impl), so `cmd as u8` yields the wire value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Command {
    /// `CONNECT` (`0x01`): open an outbound TCP connection to the target.
    Connect = 0x01,
    /// `BIND` (`0x02`): accept an inbound TCP connection from the target.
    Bind = 0x02,
    /// `UDP ASSOCIATE` (`0x03`): relay UDP datagrams for the client.
    UdpAssociate = 0x03,
}
46
47impl TryFrom<u8> for Command {
48 type Error = Error;
49 fn try_from(value: u8) -> Result<Self> {
50 match value {
51 0x01 => Ok(Command::Connect),
52 0x02 => Ok(Command::Bind),
53 0x03 => Ok(Command::UdpAssociate),
54 cmd => Err(Error::InvalidCommand(cmd)),
55 }
56 }
57}
58
/// Destination address carried in a SOCKS5 request (`DST.ADDR`).
///
/// The port travels separately on the wire and is kept separate here too.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum Address {
    /// `ATYP = 0x01`: four-byte IPv4 address.
    Ipv4(Ipv4Addr),
    /// `ATYP = 0x04`: sixteen-byte IPv6 address.
    Ipv6(Ipv6Addr),
    /// `ATYP = 0x03`: length-prefixed host name, resolved by the server.
    Domain(String),
}
65
66impl Address {
67 async fn read_from(stream: &mut TcpStream) -> Result<(Self, u16)> {
68 let atyp = stream.read_u8().await?;
69
70 let addr = match atyp {
71 0x01 => { let mut buf = [0u8; 4];
73 stream.read_exact(&mut buf).await?;
74 Address::Ipv4(Ipv4Addr::from(buf))
75 }
76 0x03 => { let len = stream.read_u8().await? as usize;
78 let mut domain = vec![0u8; len];
79 stream.read_exact(&mut domain).await?;
80 Address::Domain(String::from_utf8(domain)?)
81 }
82 0x04 => { let mut buf = [0u8; 16];
84 stream.read_exact(&mut buf).await?;
85 Address::Ipv6(Ipv6Addr::from(buf))
86 }
87 _ => return Err(Error::InvalidAddrType(atyp)),
88 };
89
90 let mut port_buf = [0u8; 2];
91 stream.read_exact(&mut port_buf).await?;
92 let port = u16::from_be_bytes(port_buf);
93
94 Ok((addr, port))
95 }
96
97 async fn to_socket_addr(&self, port: u16) -> Result<SocketAddr> {
98 match self {
99 Address::Ipv4(addr) => Ok(SocketAddr::new(IpAddr::V4(*addr), port)),
100 Address::Ipv6(addr) => Ok(SocketAddr::new(IpAddr::V6(*addr), port)),
101 Address::Domain(domain) => {
102 let addrs = tokio::net::lookup_host(format!("{}:{}", domain, port)).await?;
103 Ok(addrs.into_iter().next().ok_or(Error::HostUnreachable)?)
104 }
105 }
106 }
107}
108
109pub struct Socks5Server {
110 listener: TcpListener,
111}
112
113impl Socks5Server {
114 pub async fn new(addr: &str) -> Result<Self> {
115 let listener = TcpListener::bind(addr).await?;
116 Ok(Self { listener })
117 }
118
119 pub async fn run(&self) -> Result<()> {
120 loop {
121 let (client, _) = self.listener.accept().await?;
122 tokio::spawn(async move {
123 if let Err(e) = handle_client(client).await {
124 eprintln!("Client error: {}", e);
125 }
126 });
127 }
128 }
129}
130
131async fn handle_client(mut client: TcpStream) -> Result<()> {
132 let ver = client.read_u8().await?;
134 if ver != 0x05 {
135 return Err(Error::InvalidVersion);
136 }
137
138 let nmethods = client.read_u8().await?;
139 let mut methods = vec![0u8; nmethods as usize];
140 client.read_exact(&mut methods).await?;
141
142 if !methods.contains(&0x00) {
144 client.write_all(&[0x05, 0xFF]).await?;
145 return Err(Error::AuthFailed);
146 }
147 client.write_all(&[0x05, 0x00]).await?;
148
149 let ver = client.read_u8().await?;
151 if ver != 0x05 {
152 return Err(Error::InvalidVersion);
153 }
154
155 let cmd = Command::try_from(client.read_u8().await?)?;
156 let _rsv = client.read_u8().await?; let (target_addr, target_port) = Address::read_from(&mut client).await?;
158
159 match cmd {
160 Command::Connect => handle_connect(&mut client, target_addr, target_port).await,
161 Command::Bind => handle_bind(&mut client, target_addr, target_port).await,
162 Command::UdpAssociate => handle_udp_associate(&mut client).await,
163 }
164}
165
166async fn handle_connect(client: &mut TcpStream, addr: Address, port: u16) -> Result<()> {
167 let target_addr = addr.to_socket_addr(port).await?;
168 let mut target = match TcpStream::connect(target_addr).await {
169 Ok(stream) => stream,
170 Err(e) => {
171 let reply = match e.kind() {
172 std::io::ErrorKind::ConnectionRefused => 0x05,
173 std::io::ErrorKind::TimedOut => 0x06,
174 _ => 0x01,
175 };
176 send_error_reply(client, reply).await?;
177 return Err(e.into());
178 }
179 };
180
181 let bind_addr = target.local_addr()?;
183 send_reply(client, 0x00, &bind_addr).await?;
184
185 let (mut cr, mut cw) = client.split();
187 let (mut tr, mut tw) = target.split();
188
189 tokio::select! {
190 res = tokio::io::copy(&mut cr, &mut tw) => {
191 if let Err(e) = res {
192 return Err(e.into());
193 }
194 }
195 res = tokio::io::copy(&mut tr, &mut cw) => {
196 if let Err(e) = res {
197 return Err(e.into());
198 }
199 }
200 }
201
202 Ok(())
203}
204
205async fn handle_bind(client: &mut TcpStream, addr: Address, port: u16) -> Result<()> {
206 let listener = TcpListener::bind("0.0.0.0:0").await?;
208 let bind_addr = listener.local_addr()?;
209
210 send_reply(client, 0x00, &bind_addr).await?;
212
213 let (mut target, target_addr) = listener.accept().await?;
215
216 let expected_addr = addr.to_socket_addr(port).await?;
218 if target_addr.ip() != expected_addr.ip() {
219 send_error_reply(client, 0x02).await?;
220 return Err(Error::ConnectionNotAllowed);
221 }
222
223 send_reply(client, 0x00, &target_addr).await?;
225
226 let (mut cr, mut cw) = client.split();
228 let (mut tr, mut tw) = target.split();
229
230 tokio::select! {
231 res = tokio::io::copy(&mut cr, &mut tw) => {
232 if let Err(e) = res {
233 return Err(e.into());
234 }
235 }
236 res = tokio::io::copy(&mut tr, &mut cw) => {
237 if let Err(e) = res {
238 return Err(e.into());
239 }
240 }
241 }
242
243 Ok(())
244}
245
246async fn handle_udp_associate(client: &mut TcpStream) -> Result<()> {
247 let relay = UdpSocket::bind("0.0.0.0:0").await?;
249 let relay_addr = relay.local_addr()?;
250
251 send_reply(client, 0x00, &relay_addr).await?;
253
254 let mut keep_alive = [0u8; 1];
256 loop {
257 match client.read(&mut keep_alive).await {
258 Ok(0) | Err(_) => break, _ => continue,
260 }
261 }
262
263 Ok(())
264}
265
266async fn send_reply(stream: &mut TcpStream, reply: u8, addr: &SocketAddr) -> Result<()> {
267 stream.write_u8(0x05).await?; stream.write_u8(reply).await?; stream.write_u8(0x00).await?; match addr {
272 SocketAddr::V4(addr) => {
273 stream.write_u8(0x01).await?; stream.write_all(&addr.ip().octets()).await?;
275 stream.write_all(&addr.port().to_be_bytes()).await?;
276 }
277 SocketAddr::V6(addr) => {
278 stream.write_u8(0x04).await?; stream.write_all(&addr.ip().octets()).await?;
280 stream.write_all(&addr.port().to_be_bytes()).await?;
281 }
282 }
283
284 Ok(())
285}
286
287async fn send_error_reply(stream: &mut TcpStream, reply: u8) -> Result<()> {
288 let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 0);
289 send_reply(stream, reply, &addr).await
290}