sandbox_runtime/proxy/
socks5.rs

1//! SOCKS5 proxy server (RFC 1928).
2
3use std::net::SocketAddr;
4use std::sync::Arc;
5
6use tokio::io::{AsyncReadExt, AsyncWriteExt};
7use tokio::net::{TcpListener, TcpStream};
8use tokio::sync::oneshot;
9
10use crate::error::SandboxError;
11use crate::proxy::filter::{DomainFilter, FilterDecision};
12
// ---- SOCKS5 wire constants (RFC 1928) ----

/// Protocol version byte carried in every SOCKS5 message.
const SOCKS_VERSION: u8 = 0x05;

/// Method selection (RFC 1928 §3): "NO AUTHENTICATION REQUIRED".
const AUTH_NONE: u8 = 0x00;

/// Request command (RFC 1928 §4): CONNECT.
const CMD_CONNECT: u8 = 0x01;

/// Address type (RFC 1928 §5): 4-byte IPv4 address.
const ATYP_IPV4: u8 = 0x01;
/// Address type (RFC 1928 §5): length-prefixed domain name.
const ATYP_DOMAIN: u8 = 0x03;
/// Address type (RFC 1928 §5): 16-byte IPv6 address.
const ATYP_IPV6: u8 = 0x04;

/// Reply code (RFC 1928 §6): succeeded.
const REP_SUCCESS: u8 = 0x00;
/// Reply code (RFC 1928 §6): general SOCKS server failure.
const REP_GENERAL_FAILURE: u8 = 0x01;
/// Reply code (RFC 1928 §6): connection not allowed by ruleset.
const REP_CONNECTION_NOT_ALLOWED: u8 = 0x02;
/// Reply code (RFC 1928 §6): host unreachable.
const REP_HOST_UNREACHABLE: u8 = 0x04;
24
25/// SOCKS5 proxy server.
26pub struct Socks5Proxy {
27    listener: Option<TcpListener>,
28    port: u16,
29    filter: Arc<DomainFilter>,
30    shutdown_tx: Option<oneshot::Sender<()>>,
31}
32
33impl Socks5Proxy {
34    /// Create a new SOCKS5 proxy server.
35    pub async fn new(filter: DomainFilter) -> Result<Self, SandboxError> {
36        let listener = TcpListener::bind("127.0.0.1:0").await?;
37        let port = listener.local_addr()?.port();
38
39        tracing::debug!("SOCKS5 proxy listening on port {}", port);
40
41        Ok(Self {
42            listener: Some(listener),
43            port,
44            filter: Arc::new(filter),
45            shutdown_tx: None,
46        })
47    }
48
49    /// Get the port the proxy is listening on.
50    pub fn port(&self) -> u16 {
51        self.port
52    }
53
54    /// Start the proxy server.
55    pub fn start(&mut self) -> Result<(), SandboxError> {
56        let listener = self
57            .listener
58            .take()
59            .ok_or_else(|| SandboxError::Proxy("Proxy already started".to_string()))?;
60
61        let filter = self.filter.clone();
62        let (shutdown_tx, mut shutdown_rx) = oneshot::channel();
63        self.shutdown_tx = Some(shutdown_tx);
64
65        tokio::spawn(async move {
66            loop {
67                tokio::select! {
68                    accept_result = listener.accept() => {
69                        match accept_result {
70                            Ok((stream, addr)) => {
71                                let filter = filter.clone();
72                                tokio::spawn(async move {
73                                    if let Err(e) = handle_client(stream, addr, filter).await {
74                                        tracing::debug!("SOCKS5 error from {}: {}", addr, e);
75                                    }
76                                });
77                            }
78                            Err(e) => {
79                                tracing::error!("SOCKS5 accept error: {}", e);
80                            }
81                        }
82                    }
83                    _ = &mut shutdown_rx => {
84                        tracing::debug!("SOCKS5 proxy shutting down");
85                        break;
86                    }
87                }
88            }
89        });
90
91        Ok(())
92    }
93
94    /// Stop the proxy server.
95    pub fn stop(&mut self) {
96        if let Some(tx) = self.shutdown_tx.take() {
97            let _ = tx.send(());
98        }
99    }
100}
101
102/// Handle a SOCKS5 client connection.
103async fn handle_client(
104    mut stream: TcpStream,
105    _addr: SocketAddr,
106    filter: Arc<DomainFilter>,
107) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
108    // Read version and authentication methods
109    let mut header = [0u8; 2];
110    stream.read_exact(&mut header).await?;
111
112    if header[0] != SOCKS_VERSION {
113        return Err("Invalid SOCKS version".into());
114    }
115
116    let nmethods = header[1] as usize;
117    let mut methods = vec![0u8; nmethods];
118    stream.read_exact(&mut methods).await?;
119
120    // We only support no authentication
121    if !methods.contains(&AUTH_NONE) {
122        stream.write_all(&[SOCKS_VERSION, 0xFF]).await?;
123        return Err("No supported authentication method".into());
124    }
125
126    // Send auth method selection
127    stream.write_all(&[SOCKS_VERSION, AUTH_NONE]).await?;
128
129    // Read connection request
130    let mut request = [0u8; 4];
131    stream.read_exact(&mut request).await?;
132
133    if request[0] != SOCKS_VERSION {
134        return Err("Invalid SOCKS version in request".into());
135    }
136
137    let cmd = request[1];
138    // request[2] is reserved
139    let atyp = request[3];
140
141    if cmd != CMD_CONNECT {
142        send_reply(&mut stream, REP_GENERAL_FAILURE, "0.0.0.0", 0).await?;
143        return Err("Only CONNECT command is supported".into());
144    }
145
146    // Parse destination address
147    let (host, port) = match atyp {
148        ATYP_IPV4 => {
149            let mut addr = [0u8; 4];
150            stream.read_exact(&mut addr).await?;
151            let mut port_buf = [0u8; 2];
152            stream.read_exact(&mut port_buf).await?;
153            let port = u16::from_be_bytes(port_buf);
154            let host = format!("{}.{}.{}.{}", addr[0], addr[1], addr[2], addr[3]);
155            (host, port)
156        }
157        ATYP_DOMAIN => {
158            let mut len_buf = [0u8; 1];
159            stream.read_exact(&mut len_buf).await?;
160            let len = len_buf[0] as usize;
161            let mut domain = vec![0u8; len];
162            stream.read_exact(&mut domain).await?;
163            let mut port_buf = [0u8; 2];
164            stream.read_exact(&mut port_buf).await?;
165            let port = u16::from_be_bytes(port_buf);
166            let host = String::from_utf8_lossy(&domain).to_string();
167            (host, port)
168        }
169        ATYP_IPV6 => {
170            let mut addr = [0u8; 16];
171            stream.read_exact(&mut addr).await?;
172            let mut port_buf = [0u8; 2];
173            stream.read_exact(&mut port_buf).await?;
174            let port = u16::from_be_bytes(port_buf);
175            // Format as IPv6 address
176            let host = format!(
177                "{:x}:{:x}:{:x}:{:x}:{:x}:{:x}:{:x}:{:x}",
178                u16::from_be_bytes([addr[0], addr[1]]),
179                u16::from_be_bytes([addr[2], addr[3]]),
180                u16::from_be_bytes([addr[4], addr[5]]),
181                u16::from_be_bytes([addr[6], addr[7]]),
182                u16::from_be_bytes([addr[8], addr[9]]),
183                u16::from_be_bytes([addr[10], addr[11]]),
184                u16::from_be_bytes([addr[12], addr[13]]),
185                u16::from_be_bytes([addr[14], addr[15]])
186            );
187            (host, port)
188        }
189        _ => {
190            send_reply(&mut stream, REP_GENERAL_FAILURE, "0.0.0.0", 0).await?;
191            return Err("Unsupported address type".into());
192        }
193    };
194
195    tracing::debug!("SOCKS5 CONNECT {}:{}", host, port);
196
197    // Check filter
198    let decision = filter.check(&host, port);
199
200    if matches!(decision, FilterDecision::Deny) {
201        tracing::debug!("SOCKS5 denied connection to {}:{}", host, port);
202        send_reply(&mut stream, REP_CONNECTION_NOT_ALLOWED, "0.0.0.0", 0).await?;
203        return Ok(());
204    }
205
206    // Connect to target
207    let target = match TcpStream::connect(format!("{}:{}", host, port)).await {
208        Ok(s) => s,
209        Err(e) => {
210            tracing::debug!("SOCKS5 failed to connect to {}:{}: {}", host, port, e);
211            send_reply(&mut stream, REP_HOST_UNREACHABLE, "0.0.0.0", 0).await?;
212            return Ok(());
213        }
214    };
215
216    // Send success reply
217    let local_addr = target.local_addr()?;
218    let (bind_addr, bind_port) = match local_addr {
219        SocketAddr::V4(addr) => (addr.ip().to_string(), addr.port()),
220        SocketAddr::V6(addr) => (addr.ip().to_string(), addr.port()),
221    };
222    send_reply(&mut stream, REP_SUCCESS, &bind_addr, bind_port).await?;
223
224    // Pipe data
225    let (mut client_read, mut client_write) = stream.into_split();
226    let (mut target_read, mut target_write) = target.into_split();
227
228    let client_to_target = tokio::io::copy(&mut client_read, &mut target_write);
229    let target_to_client = tokio::io::copy(&mut target_read, &mut client_write);
230
231    tokio::select! {
232        _ = client_to_target => {}
233        _ = target_to_client => {}
234    }
235
236    Ok(())
237}
238
239/// Send a SOCKS5 reply.
240async fn send_reply(
241    stream: &mut TcpStream,
242    rep: u8,
243    addr: &str,
244    port: u16,
245) -> Result<(), std::io::Error> {
246    let mut reply = vec![SOCKS_VERSION, rep, 0x00]; // VER, REP, RSV
247
248    // Parse address
249    if let Ok(ipv4) = addr.parse::<std::net::Ipv4Addr>() {
250        reply.push(ATYP_IPV4);
251        reply.extend_from_slice(&ipv4.octets());
252    } else if let Ok(ipv6) = addr.parse::<std::net::Ipv6Addr>() {
253        reply.push(ATYP_IPV6);
254        reply.extend_from_slice(&ipv6.octets());
255    } else {
256        // Domain name
257        reply.push(ATYP_DOMAIN);
258        reply.push(addr.len() as u8);
259        reply.extend_from_slice(addr.as_bytes());
260    }
261
262    reply.extend_from_slice(&port.to_be_bytes());
263
264    stream.write_all(&reply).await
265}