1use std::net::SocketAddr;
4use std::sync::Arc;
5
6use tokio::io::{AsyncReadExt, AsyncWriteExt};
7use tokio::net::{TcpListener, TcpStream};
8use tokio::sync::oneshot;
9
10use crate::error::SandboxError;
11use crate::proxy::filter::{DomainFilter, FilterDecision};
12
/// Protocol version byte carried at the start of every SOCKS5 message.
const SOCKS_VERSION: u8 = 0x05;

/// Authentication method "no authentication required" (the only one accepted).
const AUTH_NONE: u8 = 0x00;

/// Request command asking the proxy to open a TCP connection (the only one served).
const CMD_CONNECT: u8 = 0x01;

// Address types (ATYP) as they appear in requests and replies.
/// IPv4 address: four raw octets.
const ATYP_IPV4: u8 = 0x01;
/// Domain name: one length byte followed by that many name bytes.
const ATYP_DOMAIN: u8 = 0x03;
/// IPv6 address: sixteen raw octets.
const ATYP_IPV6: u8 = 0x04;

// Reply codes (REP) sent back to the client.
/// The request was granted.
const REP_SUCCESS: u8 = 0x00;
/// Catch-all failure.
const REP_GENERAL_FAILURE: u8 = 0x01;
/// The target was rejected by the domain filter.
const REP_CONNECTION_NOT_ALLOWED: u8 = 0x02;
/// The outbound connection to the target could not be established.
const REP_HOST_UNREACHABLE: u8 = 0x04;
24
25pub struct Socks5Proxy {
27 listener: Option<TcpListener>,
28 port: u16,
29 filter: Arc<DomainFilter>,
30 shutdown_tx: Option<oneshot::Sender<()>>,
31}
32
33impl Socks5Proxy {
34 pub async fn new(filter: DomainFilter) -> Result<Self, SandboxError> {
36 let listener = TcpListener::bind("127.0.0.1:0").await?;
37 let port = listener.local_addr()?.port();
38
39 tracing::debug!("SOCKS5 proxy listening on port {}", port);
40
41 Ok(Self {
42 listener: Some(listener),
43 port,
44 filter: Arc::new(filter),
45 shutdown_tx: None,
46 })
47 }
48
49 pub fn port(&self) -> u16 {
51 self.port
52 }
53
54 pub fn start(&mut self) -> Result<(), SandboxError> {
56 let listener = self
57 .listener
58 .take()
59 .ok_or_else(|| SandboxError::Proxy("Proxy already started".to_string()))?;
60
61 let filter = self.filter.clone();
62 let (shutdown_tx, mut shutdown_rx) = oneshot::channel();
63 self.shutdown_tx = Some(shutdown_tx);
64
65 tokio::spawn(async move {
66 loop {
67 tokio::select! {
68 accept_result = listener.accept() => {
69 match accept_result {
70 Ok((stream, addr)) => {
71 let filter = filter.clone();
72 tokio::spawn(async move {
73 if let Err(e) = handle_client(stream, addr, filter).await {
74 tracing::debug!("SOCKS5 error from {}: {}", addr, e);
75 }
76 });
77 }
78 Err(e) => {
79 tracing::error!("SOCKS5 accept error: {}", e);
80 }
81 }
82 }
83 _ = &mut shutdown_rx => {
84 tracing::debug!("SOCKS5 proxy shutting down");
85 break;
86 }
87 }
88 }
89 });
90
91 Ok(())
92 }
93
94 pub fn stop(&mut self) {
96 if let Some(tx) = self.shutdown_tx.take() {
97 let _ = tx.send(());
98 }
99 }
100}
101
102async fn handle_client(
104 mut stream: TcpStream,
105 _addr: SocketAddr,
106 filter: Arc<DomainFilter>,
107) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
108 let mut header = [0u8; 2];
110 stream.read_exact(&mut header).await?;
111
112 if header[0] != SOCKS_VERSION {
113 return Err("Invalid SOCKS version".into());
114 }
115
116 let nmethods = header[1] as usize;
117 let mut methods = vec![0u8; nmethods];
118 stream.read_exact(&mut methods).await?;
119
120 if !methods.contains(&AUTH_NONE) {
122 stream.write_all(&[SOCKS_VERSION, 0xFF]).await?;
123 return Err("No supported authentication method".into());
124 }
125
126 stream.write_all(&[SOCKS_VERSION, AUTH_NONE]).await?;
128
129 let mut request = [0u8; 4];
131 stream.read_exact(&mut request).await?;
132
133 if request[0] != SOCKS_VERSION {
134 return Err("Invalid SOCKS version in request".into());
135 }
136
137 let cmd = request[1];
138 let atyp = request[3];
140
141 if cmd != CMD_CONNECT {
142 send_reply(&mut stream, REP_GENERAL_FAILURE, "0.0.0.0", 0).await?;
143 return Err("Only CONNECT command is supported".into());
144 }
145
146 let (host, port) = match atyp {
148 ATYP_IPV4 => {
149 let mut addr = [0u8; 4];
150 stream.read_exact(&mut addr).await?;
151 let mut port_buf = [0u8; 2];
152 stream.read_exact(&mut port_buf).await?;
153 let port = u16::from_be_bytes(port_buf);
154 let host = format!("{}.{}.{}.{}", addr[0], addr[1], addr[2], addr[3]);
155 (host, port)
156 }
157 ATYP_DOMAIN => {
158 let mut len_buf = [0u8; 1];
159 stream.read_exact(&mut len_buf).await?;
160 let len = len_buf[0] as usize;
161 let mut domain = vec![0u8; len];
162 stream.read_exact(&mut domain).await?;
163 let mut port_buf = [0u8; 2];
164 stream.read_exact(&mut port_buf).await?;
165 let port = u16::from_be_bytes(port_buf);
166 let host = String::from_utf8_lossy(&domain).to_string();
167 (host, port)
168 }
169 ATYP_IPV6 => {
170 let mut addr = [0u8; 16];
171 stream.read_exact(&mut addr).await?;
172 let mut port_buf = [0u8; 2];
173 stream.read_exact(&mut port_buf).await?;
174 let port = u16::from_be_bytes(port_buf);
175 let host = format!(
177 "{:x}:{:x}:{:x}:{:x}:{:x}:{:x}:{:x}:{:x}",
178 u16::from_be_bytes([addr[0], addr[1]]),
179 u16::from_be_bytes([addr[2], addr[3]]),
180 u16::from_be_bytes([addr[4], addr[5]]),
181 u16::from_be_bytes([addr[6], addr[7]]),
182 u16::from_be_bytes([addr[8], addr[9]]),
183 u16::from_be_bytes([addr[10], addr[11]]),
184 u16::from_be_bytes([addr[12], addr[13]]),
185 u16::from_be_bytes([addr[14], addr[15]])
186 );
187 (host, port)
188 }
189 _ => {
190 send_reply(&mut stream, REP_GENERAL_FAILURE, "0.0.0.0", 0).await?;
191 return Err("Unsupported address type".into());
192 }
193 };
194
195 tracing::debug!("SOCKS5 CONNECT {}:{}", host, port);
196
197 let decision = filter.check(&host, port);
199
200 if matches!(decision, FilterDecision::Deny) {
201 tracing::debug!("SOCKS5 denied connection to {}:{}", host, port);
202 send_reply(&mut stream, REP_CONNECTION_NOT_ALLOWED, "0.0.0.0", 0).await?;
203 return Ok(());
204 }
205
206 let target = match TcpStream::connect(format!("{}:{}", host, port)).await {
208 Ok(s) => s,
209 Err(e) => {
210 tracing::debug!("SOCKS5 failed to connect to {}:{}: {}", host, port, e);
211 send_reply(&mut stream, REP_HOST_UNREACHABLE, "0.0.0.0", 0).await?;
212 return Ok(());
213 }
214 };
215
216 let local_addr = target.local_addr()?;
218 let (bind_addr, bind_port) = match local_addr {
219 SocketAddr::V4(addr) => (addr.ip().to_string(), addr.port()),
220 SocketAddr::V6(addr) => (addr.ip().to_string(), addr.port()),
221 };
222 send_reply(&mut stream, REP_SUCCESS, &bind_addr, bind_port).await?;
223
224 let (mut client_read, mut client_write) = stream.into_split();
226 let (mut target_read, mut target_write) = target.into_split();
227
228 let client_to_target = tokio::io::copy(&mut client_read, &mut target_write);
229 let target_to_client = tokio::io::copy(&mut target_read, &mut client_write);
230
231 tokio::select! {
232 _ = client_to_target => {}
233 _ = target_to_client => {}
234 }
235
236 Ok(())
237}
238
239async fn send_reply(
241 stream: &mut TcpStream,
242 rep: u8,
243 addr: &str,
244 port: u16,
245) -> Result<(), std::io::Error> {
246 let mut reply = vec![SOCKS_VERSION, rep, 0x00]; if let Ok(ipv4) = addr.parse::<std::net::Ipv4Addr>() {
250 reply.push(ATYP_IPV4);
251 reply.extend_from_slice(&ipv4.octets());
252 } else if let Ok(ipv6) = addr.parse::<std::net::Ipv6Addr>() {
253 reply.push(ATYP_IPV6);
254 reply.extend_from_slice(&ipv6.octets());
255 } else {
256 reply.push(ATYP_DOMAIN);
258 reply.push(addr.len() as u8);
259 reply.extend_from_slice(addr.as_bytes());
260 }
261
262 reply.extend_from_slice(&port.to_be_bytes());
263
264 stream.write_all(&reply).await
265}