// socks_abstract5/lib.rs

use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
use thiserror::Error;
use tokio::{
    io::{AsyncRead, AsyncReadExt, AsyncWriteExt},
    net::{TcpListener, TcpStream, UdpSocket},
};

8#[derive(Debug, Error)]
9pub enum Error {
10    #[error("IO error: {0}")]
11    Io(#[from] std::io::Error),
12    #[error("Invalid SOCKS version")]
13    InvalidVersion,
14    #[error("Authentication failed")]
15    AuthFailed,
16    #[error("Invalid command: {0}")]
17    InvalidCommand(u8),
18    #[error("Invalid address type: {0}")]
19    InvalidAddrType(u8),
20    #[error("UTF-8 decode error: {0}")]
21    Utf8Error(#[from] std::string::FromUtf8Error),
22    #[error("Connection not allowed by ruleset")]
23    ConnectionNotAllowed,
24    #[error("Network unreachable")]
25    NetworkUnreachable,
26    #[error("Host unreachable")]
27    HostUnreachable,
28    #[error("Connection refused")]
29    ConnectionRefused,
30    #[error("TTL expired")]
31    TtlExpired,
32    #[error("Protocol error")]
33    ProtocolError,
34    #[error("Address type not supported")]
35    AddrTypeNotSupported,
36}
37
38pub type Result<T> = std::result::Result<T, Error>;
39
40#[derive(Debug)]
41pub enum Command {
42    Connect = 0x01,
43    Bind = 0x02,
44    UdpAssociate = 0x03,
45}
46
47impl TryFrom<u8> for Command {
48    type Error = Error;
49    fn try_from(value: u8) -> Result<Self> {
50        match value {
51            0x01 => Ok(Command::Connect),
52            0x02 => Ok(Command::Bind),
53            0x03 => Ok(Command::UdpAssociate),
54            cmd => Err(Error::InvalidCommand(cmd)),
55        }
56    }
57}
58
59#[derive(Debug)]
60pub enum Address {
61    Ipv4(Ipv4Addr),
62    Ipv6(Ipv6Addr),
63    Domain(String),
64}
65
66impl Address {
67    async fn read_from(stream: &mut TcpStream) -> Result<(Self, u16)> {
68        let atyp = stream.read_u8().await?;
69        
70        let addr = match atyp {
71            0x01 => { // IPv4
72                let mut buf = [0u8; 4];
73                stream.read_exact(&mut buf).await?;
74                Address::Ipv4(Ipv4Addr::from(buf))
75            }
76            0x03 => { // Domain
77                let len = stream.read_u8().await? as usize;
78                let mut domain = vec![0u8; len];
79                stream.read_exact(&mut domain).await?;
80                Address::Domain(String::from_utf8(domain)?)
81            }
82            0x04 => { // IPv6
83                let mut buf = [0u8; 16];
84                stream.read_exact(&mut buf).await?;
85                Address::Ipv6(Ipv6Addr::from(buf))
86            }
87            _ => return Err(Error::InvalidAddrType(atyp)),
88        };
89
90        let mut port_buf = [0u8; 2];
91        stream.read_exact(&mut port_buf).await?;
92        let port = u16::from_be_bytes(port_buf);
93
94        Ok((addr, port))
95    }
96
97    async fn to_socket_addr(&self, port: u16) -> Result<SocketAddr> {
98        match self {
99            Address::Ipv4(addr) => Ok(SocketAddr::new(IpAddr::V4(*addr), port)),
100            Address::Ipv6(addr) => Ok(SocketAddr::new(IpAddr::V6(*addr), port)),
101            Address::Domain(domain) => {
102                let addrs = tokio::net::lookup_host(format!("{}:{}", domain, port)).await?;
103                Ok(addrs.into_iter().next().ok_or(Error::HostUnreachable)?)
104            }
105        }
106    }
107}
108
109pub struct Socks5Server {
110    listener: TcpListener,
111}
112
113impl Socks5Server {
114    pub async fn new(addr: &str) -> Result<Self> {
115        let listener = TcpListener::bind(addr).await?;
116        Ok(Self { listener })
117    }
118
119    pub async fn run(&self) -> Result<()> {
120        loop {
121            let (client, _) = self.listener.accept().await?;
122            tokio::spawn(async move {
123                if let Err(e) = handle_client(client).await {
124                    eprintln!("Client error: {}", e);
125                }
126            });
127        }
128    }
129}
130
131async fn handle_client(mut client: TcpStream) -> Result<()> {
132    // Authentication negotiation
133    let ver = client.read_u8().await?;
134    if ver != 0x05 {
135        return Err(Error::InvalidVersion);
136    }
137
138    let nmethods = client.read_u8().await?;
139    let mut methods = vec![0u8; nmethods as usize];
140    client.read_exact(&mut methods).await?;
141
142    // For now, only support no authentication
143    if !methods.contains(&0x00) {
144        client.write_all(&[0x05, 0xFF]).await?;
145        return Err(Error::AuthFailed);
146    }
147    client.write_all(&[0x05, 0x00]).await?;
148
149    // Request
150    let ver = client.read_u8().await?;
151    if ver != 0x05 {
152        return Err(Error::InvalidVersion);
153    }
154
155    let cmd = Command::try_from(client.read_u8().await?)?;
156    let _rsv = client.read_u8().await?; // Reserved byte
157    let (target_addr, target_port) = Address::read_from(&mut client).await?;
158
159    match cmd {
160        Command::Connect => handle_connect(&mut client, target_addr, target_port).await,
161        Command::Bind => handle_bind(&mut client, target_addr, target_port).await,
162        Command::UdpAssociate => handle_udp_associate(&mut client).await,
163    }
164}
165
166async fn handle_connect(client: &mut TcpStream, addr: Address, port: u16) -> Result<()> {
167    let target_addr = addr.to_socket_addr(port).await?;
168    let mut target = match TcpStream::connect(target_addr).await {
169        Ok(stream) => stream,
170        Err(e) => {
171            let reply = match e.kind() {
172                std::io::ErrorKind::ConnectionRefused => 0x05,
173                std::io::ErrorKind::TimedOut => 0x06,
174                _ => 0x01,
175            };
176            send_error_reply(client, reply).await?;
177            return Err(e.into());
178        }
179    };
180
181    // Send success response
182    let bind_addr = target.local_addr()?;
183    send_reply(client, 0x00, &bind_addr).await?;
184
185    // Start proxying
186    let (mut cr, mut cw) = client.split();
187    let (mut tr, mut tw) = target.split();
188
189    tokio::select! {
190        res = tokio::io::copy(&mut cr, &mut tw) => {
191            if let Err(e) = res {
192                return Err(e.into());
193            }
194        }
195        res = tokio::io::copy(&mut tr, &mut cw) => {
196            if let Err(e) = res {
197                return Err(e.into());
198            }
199        }
200    }
201
202    Ok(())
203}
204
205async fn handle_bind(client: &mut TcpStream, addr: Address, port: u16) -> Result<()> {
206    // Create listener for incoming connection
207    let listener = TcpListener::bind("0.0.0.0:0").await?;
208    let bind_addr = listener.local_addr()?;
209
210    // Send first reply with bind address
211    send_reply(client, 0x00, &bind_addr).await?;
212
213    // Wait for incoming connection
214    let (mut target, target_addr) = listener.accept().await?;
215    
216    // Verify connection is from expected address
217    let expected_addr = addr.to_socket_addr(port).await?;
218    if target_addr.ip() != expected_addr.ip() {
219        send_error_reply(client, 0x02).await?;
220        return Err(Error::ConnectionNotAllowed);
221    }
222
223    // Send second reply with target address
224    send_reply(client, 0x00, &target_addr).await?;
225
226    // Start proxying
227    let (mut cr, mut cw) = client.split();
228    let (mut tr, mut tw) = target.split();
229
230    tokio::select! {
231        res = tokio::io::copy(&mut cr, &mut tw) => {
232            if let Err(e) = res {
233                return Err(e.into());
234            }
235        }
236        res = tokio::io::copy(&mut tr, &mut cw) => {
237            if let Err(e) = res {
238                return Err(e.into());
239            }
240        }
241    }
242
243    Ok(())
244}
245
246async fn handle_udp_associate(client: &mut TcpStream) -> Result<()> {
247    // Create UDP socket for relaying
248    let relay = UdpSocket::bind("0.0.0.0:0").await?;
249    let relay_addr = relay.local_addr()?;
250
251    // Send reply with relay address
252    send_reply(client, 0x00, &relay_addr).await?;
253
254    // Keep TCP connection open and monitor it
255    let mut keep_alive = [0u8; 1];
256    loop {
257        match client.read(&mut keep_alive).await {
258            Ok(0) | Err(_) => break, // Connection closed
259            _ => continue,
260        }
261    }
262
263    Ok(())
264}
265
266async fn send_reply(stream: &mut TcpStream, reply: u8, addr: &SocketAddr) -> Result<()> {
267    stream.write_u8(0x05).await?; // VER
268    stream.write_u8(reply).await?; // REP
269    stream.write_u8(0x00).await?; // RSV
270
271    match addr {
272        SocketAddr::V4(addr) => {
273            stream.write_u8(0x01).await?; // ATYP
274            stream.write_all(&addr.ip().octets()).await?;
275            stream.write_all(&addr.port().to_be_bytes()).await?;
276        }
277        SocketAddr::V6(addr) => {
278            stream.write_u8(0x04).await?; // ATYP
279            stream.write_all(&addr.ip().octets()).await?;
280            stream.write_all(&addr.port().to_be_bytes()).await?;
281        }
282    }
283
284    Ok(())
285}
286
287async fn send_error_reply(stream: &mut TcpStream, reply: u8) -> Result<()> {
288    let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 0);
289    send_reply(stream, reply, &addr).await
290}