Skip to main content

rustgate/
socks5.rs

1use crate::error::{ProxyError, Result};
2use tokio::io::{AsyncReadExt, AsyncWriteExt};
3use tokio::net::{TcpListener, TcpStream};
4use tracing::debug;
5
/// Protocol version byte that opens every SOCKS5 greeting and request.
const SOCKS5_VERSION: u8 = 0x05;
/// The only authentication method this server offers: none.
const AUTH_NONE: u8 = 0x00;

/// CMD value for CONNECT, the only command handled here.
const CMD_CONNECT: u8 = 0x01;

/// ATYP: 4-byte IPv4 address follows.
const ATYP_IPV4: u8 = 0x01;
/// ATYP: one length byte, then that many bytes of domain name.
const ATYP_DOMAIN: u8 = 0x03;
/// ATYP: 16-byte IPv6 address follows.
const ATYP_IPV6: u8 = 0x04;

/// REP: request granted.
const REPLY_SUCCESS: u8 = 0x00;
/// REP: unspecified server failure.
const REPLY_GENERAL_FAILURE: u8 = 0x01;
/// REP: the requested command is not implemented.
const REPLY_CMD_NOT_SUPPORTED: u8 = 0x07;
15
/// Parsed SOCKS5 CONNECT request.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Socks5Request {
    /// Requested destination formatted as `host:port`; IPv6 hosts are wrapped
    /// in brackets so the port separator stays unambiguous.
    pub target_addr: String,
}
20
21/// A minimal SOCKS5 listener that only supports the CONNECT command.
22pub struct Socks5Listener {
23    listener: TcpListener,
24    pub tunnel_id: u32,
25}
26
27impl Socks5Listener {
28    pub async fn bind(addr: &str, tunnel_id: u32) -> Result<Self> {
29        let listener = TcpListener::bind(addr).await?;
30        Ok(Self {
31            listener,
32            tunnel_id,
33        })
34    }
35
36    /// Accept a raw TCP connection (no handshake — caller handles it per-connection).
37    pub async fn accept_raw(&self) -> Result<TcpStream> {
38        let (stream, peer) = self.listener.accept().await?;
39        debug!("SOCKS5 connection from {peer}");
40        Ok(stream)
41    }
42
43    /// Accept one SOCKS5 connection, perform handshake, return the stream and target.
44    pub async fn accept(&self) -> Result<(TcpStream, Socks5Request)> {
45        let (stream, peer) = self.listener.accept().await?;
46        debug!("SOCKS5 connection from {peer}");
47        socks5_handshake(stream).await
48    }
49}
50
51/// Perform SOCKS5 server-side handshake. Returns the TCP stream and CONNECT target.
52pub async fn socks5_handshake(mut stream: TcpStream) -> Result<(TcpStream, Socks5Request)> {
53    // 1. Method selection: [VER][NMETHODS][METHODS...]
54    let ver = stream.read_u8().await?;
55    if ver != SOCKS5_VERSION {
56        return Err(ProxyError::Protocol(format!(
57            "SOCKS: expected version 5, got {ver}"
58        )));
59    }
60    let nmethods = stream.read_u8().await? as usize;
61    let mut methods = vec![0u8; nmethods];
62    stream.read_exact(&mut methods).await?;
63
64    if !methods.contains(&AUTH_NONE) {
65        // No acceptable auth method
66        stream.write_all(&[SOCKS5_VERSION, 0xFF]).await?;
67        return Err(ProxyError::Protocol(
68            "SOCKS: no acceptable auth method".into(),
69        ));
70    }
71    // Reply: no auth required
72    stream.write_all(&[SOCKS5_VERSION, AUTH_NONE]).await?;
73
74    // 2. Request: [VER][CMD][RSV][ATYP][ADDR][PORT]
75    let ver = stream.read_u8().await?;
76    if ver != SOCKS5_VERSION {
77        return Err(ProxyError::Protocol(format!(
78            "SOCKS: expected version 5 in request, got {ver}"
79        )));
80    }
81    let cmd = stream.read_u8().await?;
82    let _rsv = stream.read_u8().await?;
83
84    if cmd != CMD_CONNECT {
85        // Reply with "command not supported"
86        stream
87            .write_all(&[SOCKS5_VERSION, REPLY_CMD_NOT_SUPPORTED, 0x00, ATYP_IPV4, 0, 0, 0, 0, 0, 0])
88            .await?;
89        return Err(ProxyError::Protocol(format!(
90            "SOCKS: unsupported command {cmd}"
91        )));
92    }
93
94    let atyp = stream.read_u8().await?;
95    let host = match atyp {
96        ATYP_IPV4 => {
97            let mut addr = [0u8; 4];
98            stream.read_exact(&mut addr).await?;
99            format!("{}.{}.{}.{}", addr[0], addr[1], addr[2], addr[3])
100        }
101        ATYP_DOMAIN => {
102            let len = stream.read_u8().await? as usize;
103            let mut domain = vec![0u8; len];
104            stream.read_exact(&mut domain).await?;
105            String::from_utf8(domain)
106                .map_err(|e| ProxyError::Protocol(format!("SOCKS: invalid domain: {e}")))?
107        }
108        ATYP_IPV6 => {
109            let mut addr = [0u8; 16];
110            stream.read_exact(&mut addr).await?;
111            let segments: Vec<String> = addr
112                .chunks(2)
113                .map(|c| format!("{:02x}{:02x}", c[0], c[1]))
114                .collect();
115            format!("[{}]", segments.join(":"))
116        }
117        _ => {
118            stream
119                .write_all(&[SOCKS5_VERSION, REPLY_GENERAL_FAILURE, 0x00, ATYP_IPV4, 0, 0, 0, 0, 0, 0])
120                .await?;
121            return Err(ProxyError::Protocol(format!(
122                "SOCKS: unsupported address type {atyp}"
123            )));
124        }
125    };
126
127    let port = stream.read_u16().await?;
128    let target_addr = format!("{host}:{port}");
129    debug!("SOCKS5 CONNECT to {target_addr}");
130
131    Ok((stream, Socks5Request { target_addr }))
132}
133
134/// Send the SOCKS5 success reply after the remote side is ready.
135pub async fn send_socks5_success(stream: &mut TcpStream) -> Result<()> {
136    stream
137        .write_all(&[SOCKS5_VERSION, REPLY_SUCCESS, 0x00, ATYP_IPV4, 0, 0, 0, 0, 0, 0])
138        .await?;
139    Ok(())
140}
141
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::io::AsyncWriteExt;
    use tokio::net::TcpStream;

    #[tokio::test]
    async fn test_socks5_connect_ipv4() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        let client = tokio::spawn(async move {
            let mut stream = TcpStream::connect(addr).await.unwrap();
            // Method selection: version 5, 1 method, no auth
            stream.write_all(&[0x05, 0x01, 0x00]).await.unwrap();
            let mut resp = [0u8; 2];
            stream.read_exact(&mut resp).await.unwrap();
            assert_eq!(resp, [0x05, 0x00]);

            // CONNECT to 93.184.216.34:80 (example.com)
            stream
                .write_all(&[0x05, 0x01, 0x00, 0x01, 93, 184, 216, 34, 0x00, 0x50])
                .await
                .unwrap();
            let mut resp = [0u8; 10];
            stream.read_exact(&mut resp).await.unwrap();
            assert_eq!(resp[0], 0x05); // version
            assert_eq!(resp[1], 0x00); // success
        });

        let (stream, _peer) = listener.accept().await.unwrap();
        let (mut stream, req) = socks5_handshake(stream).await.unwrap();
        assert_eq!(req.target_addr, "93.184.216.34:80");
        send_socks5_success(&mut stream).await.unwrap();

        client.await.unwrap();
    }

    #[tokio::test]
    async fn test_socks5_connect_domain() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        let client = tokio::spawn(async move {
            let mut stream = TcpStream::connect(addr).await.unwrap();
            stream.write_all(&[0x05, 0x01, 0x00]).await.unwrap();
            let mut resp = [0u8; 2];
            stream.read_exact(&mut resp).await.unwrap();
            assert_eq!(resp, [0x05, 0x00]);

            // CONNECT to example.com:443 (domain)
            let domain = b"example.com";
            let mut req = vec![0x05, 0x01, 0x00, 0x03, domain.len() as u8];
            req.extend_from_slice(domain);
            req.extend_from_slice(&443u16.to_be_bytes());
            stream.write_all(&req).await.unwrap();

            let mut resp = [0u8; 10];
            stream.read_exact(&mut resp).await.unwrap();
            assert_eq!(resp[1], 0x00);
        });

        let (stream, _peer) = listener.accept().await.unwrap();
        let (mut stream, req) = socks5_handshake(stream).await.unwrap();
        assert_eq!(req.target_addr, "example.com:443");
        send_socks5_success(&mut stream).await.unwrap();

        client.await.unwrap();
    }

    /// A client that does not offer "no authentication" gets `[5, 0xFF]` and
    /// the handshake fails with a protocol error.
    #[tokio::test]
    async fn test_socks5_rejects_unsupported_auth() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        let client = tokio::spawn(async move {
            let mut stream = TcpStream::connect(addr).await.unwrap();
            // Offer only method 0x02, which this server does not support.
            stream.write_all(&[0x05, 0x01, 0x02]).await.unwrap();
            let mut resp = [0u8; 2];
            stream.read_exact(&mut resp).await.unwrap();
            assert_eq!(resp, [0x05, 0xFF]);
        });

        let (stream, _peer) = listener.accept().await.unwrap();
        assert!(matches!(
            socks5_handshake(stream).await,
            Err(ProxyError::Protocol(_))
        ));

        client.await.unwrap();
    }

    /// A greeting with a version byte other than 5 is rejected and the
    /// connection is closed without any reply bytes.
    #[tokio::test]
    async fn test_socks5_rejects_wrong_version() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        let client = tokio::spawn(async move {
            let mut stream = TcpStream::connect(addr).await.unwrap();
            // Only the version byte is sent, so the server leaves nothing unread.
            stream.write_all(&[0x04]).await.unwrap();
            let mut rest = Vec::new();
            stream.read_to_end(&mut rest).await.unwrap();
            assert!(rest.is_empty());
        });

        let (stream, _peer) = listener.accept().await.unwrap();
        assert!(matches!(
            socks5_handshake(stream).await,
            Err(ProxyError::Protocol(_))
        ));

        client.await.unwrap();
    }
}