1use std::sync::atomic::{AtomicUsize, Ordering};
2use std::sync::Arc;
3
4use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
5use tokio::net::{TcpListener, TcpStream};
6use tokio::sync::oneshot;
7
/// Returns `true` if `host` is permitted by the allowlist entry `pattern`.
///
/// * Anything from the first `:` in `host` onward (a port) is ignored.
/// * Matching is case-insensitive, and a single trailing dot (the fully
///   qualified form, e.g. `example.com.`) is ignored on both sides.
/// * `*` matches every host, including an empty one.
/// * `*.example.com` matches `example.com` itself and any subdomain of it,
///   but not look-alikes such as `badexample.com`.
/// * Any other pattern must equal the host exactly.
/// * An empty host matches nothing except `*`.
pub fn domain_matches(host: &str, pattern: &str) -> bool {
    let host = host.split(':').next().unwrap_or(host);
    let host = host.strip_suffix('.').unwrap_or(host).to_lowercase();
    let pattern = pattern.to_lowercase();

    if pattern == "*" {
        return true;
    }
    if host.is_empty() {
        return false;
    }

    let pattern = pattern.strip_suffix('.').unwrap_or(&pattern);
    match pattern.strip_prefix("*.") {
        // Require a '.' right before the suffix so `badexample.com` cannot
        // satisfy `*.example.com`.
        Some(suffix) => {
            host == suffix
                || matches!(host.strip_suffix(suffix), Some(rest) if rest.ends_with('.'))
        }
        None => host == pattern,
    }
}
25
26pub struct DomainFilterProxy {
29 port: u16,
30 shutdown_tx: Option<oneshot::Sender<()>>,
31 _thread: Option<std::thread::JoinHandle<()>>,
32 blocked_count: Arc<AtomicUsize>,
33}
34
35impl DomainFilterProxy {
36 pub fn start(
38 allowed_domains: Vec<String>,
39 quiet: bool,
40 ) -> Result<Self, Box<dyn std::error::Error>> {
41 let (port_tx, port_rx) = std::sync::mpsc::channel();
42 let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>();
43 let blocked_count = Arc::new(AtomicUsize::new(0));
44 let blocked_count_clone = blocked_count.clone();
45
46 let thread = std::thread::spawn(move || {
47 let rt = tokio::runtime::Builder::new_current_thread()
48 .enable_all()
49 .build()
50 .unwrap();
51
52 rt.block_on(async {
53 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
54 let port = listener.local_addr().unwrap().port();
55 let _ = port_tx.send(port);
56
57 let domains = Arc::new(allowed_domains);
58
59 tokio::select! {
60 _ = accept_loop(listener, domains, blocked_count_clone, quiet) => {}
61 _ = shutdown_rx => {}
62 }
63 });
64 });
65
66 let port = port_rx
67 .recv()
68 .map_err(|e| format!("Proxy failed to start: {e}"))?;
69
70 Ok(Self {
71 port,
72 shutdown_tx: Some(shutdown_tx),
73 _thread: Some(thread),
74 blocked_count,
75 })
76 }
77
78 pub fn port(&self) -> u16 {
79 self.port
80 }
81
82 pub fn blocked_count(&self) -> usize {
83 self.blocked_count.load(Ordering::Relaxed)
84 }
85}
86
87impl Drop for DomainFilterProxy {
88 fn drop(&mut self) {
89 if let Some(tx) = self.shutdown_tx.take() {
90 let _ = tx.send(());
91 }
92 }
93}
94
95async fn accept_loop(
96 listener: TcpListener,
97 domains: Arc<Vec<String>>,
98 blocked_count: Arc<AtomicUsize>,
99 quiet: bool,
100) {
101 while let Ok((stream, _)) = listener.accept().await {
102 let domains = domains.clone();
103 let blocked = blocked_count.clone();
104 tokio::spawn(async move {
105 if let Err(e) = handle_connection(stream, &domains, &blocked, quiet).await {
106 let msg = e.to_string();
107 if !msg.contains("Broken pipe") && !msg.contains("Connection reset") {
108 eprintln!("[safe-shell proxy] {msg}");
109 }
110 }
111 });
112 }
113}
114
115async fn handle_connection(
116 stream: TcpStream,
117 allowed: &[String],
118 blocked_count: &AtomicUsize,
119 quiet: bool,
120) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
121 let (reader, writer) = stream.into_split();
122 let mut reader = BufReader::new(reader);
123 let writer = writer;
124
125 let mut request_line = String::new();
127 reader.read_line(&mut request_line).await?;
128
129 let parts: Vec<&str> = request_line.split_whitespace().collect();
130 if parts.len() < 2 {
131 return Ok(());
132 }
133
134 let method = parts[0].to_uppercase();
135 let target = parts[1].to_string();
136
137 let mut headers = Vec::new();
139 loop {
140 let mut line = String::new();
141 reader.read_line(&mut line).await?;
142 if line.trim().is_empty() {
143 break;
144 }
145 headers.push(line);
146 }
147
148 if method == "CONNECT" {
149 handle_connect(reader, writer, &target, allowed, blocked_count, quiet).await
150 } else {
151 handle_http(
152 reader,
153 writer,
154 &request_line,
155 &target,
156 &headers,
157 allowed,
158 blocked_count,
159 quiet,
160 )
161 .await
162 }
163}
164
165async fn handle_connect(
166 reader: BufReader<tokio::net::tcp::OwnedReadHalf>,
167 mut writer: tokio::net::tcp::OwnedWriteHalf,
168 target: &str,
169 allowed: &[String],
170 blocked_count: &AtomicUsize,
171 quiet: bool,
172) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
173 let host = target.split(':').next().unwrap_or(target);
174
175 if !allowed.iter().any(|p| domain_matches(host, p)) {
176 blocked_count.fetch_add(1, Ordering::Relaxed);
177 if !quiet {
178 eprintln!("\x1b[33m\u{26a0}\x1b[0m safe-shell: blocked network: {host}");
179 }
180 let msg = format!(
181 "HTTP/1.1 403 Forbidden\r\nContent-Type: text/plain\r\nConnection: close\r\n\r\n\
182 [safe-shell] Network blocked: {host} is not in the allowlist\n"
183 );
184 writer.write_all(msg.as_bytes()).await?;
185 return Ok(());
186 }
187
188 match TcpStream::connect(target).await {
190 Ok(upstream) => {
191 writer
192 .write_all(b"HTTP/1.1 200 Connection Established\r\n\r\n")
193 .await?;
194
195 let mut client_reader = reader.into_inner();
196 let (mut upstream_reader, mut upstream_writer) = upstream.into_split();
197
198 let c2u = tokio::io::copy(&mut client_reader, &mut upstream_writer);
200 let u2c = tokio::io::copy(&mut upstream_reader, &mut writer);
201
202 tokio::select! {
203 _ = c2u => {}
204 _ = u2c => {}
205 }
206 }
207 Err(e) => {
208 let msg = format!(
209 "HTTP/1.1 502 Bad Gateway\r\nContent-Type: text/plain\r\nConnection: close\r\n\r\n\
210 [safe-shell] Cannot connect to {target}: {e}\n"
211 );
212 writer.write_all(msg.as_bytes()).await?;
213 }
214 }
215
216 Ok(())
217}
218
219async fn handle_http(
220 mut reader: BufReader<tokio::net::tcp::OwnedReadHalf>,
221 mut writer: tokio::net::tcp::OwnedWriteHalf,
222 request_line: &str,
223 target: &str,
224 headers: &[String],
225 allowed: &[String],
226 blocked_count: &AtomicUsize,
227 quiet: bool,
228) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
229 let (hostname, port, path) = parse_http_url(target);
231
232 if !allowed.iter().any(|p| domain_matches(&hostname, p)) {
233 blocked_count.fetch_add(1, Ordering::Relaxed);
234 if !quiet {
235 eprintln!("\x1b[33m\u{26a0}\x1b[0m safe-shell: blocked network: {hostname}");
236 }
237 let msg = format!(
238 "HTTP/1.1 403 Forbidden\r\nContent-Type: text/plain\r\nConnection: close\r\n\r\n\
239 [safe-shell] Network blocked: {hostname} is not in the allowlist\n"
240 );
241 writer.write_all(msg.as_bytes()).await?;
242 return Ok(());
243 }
244
245 let upstream_addr = format!("{hostname}:{port}");
246
247 match TcpStream::connect(&upstream_addr).await {
248 Ok(upstream) => {
249 let (mut upstream_reader, mut upstream_writer) = upstream.into_split();
250
251 let parts: Vec<&str> = request_line.split_whitespace().collect();
253 let rewritten = format!("{} {} {}\r\n", parts[0], path, parts[2]);
254 upstream_writer.write_all(rewritten.as_bytes()).await?;
255
256 for h in headers {
258 upstream_writer.write_all(h.as_bytes()).await?;
259 }
260 upstream_writer.write_all(b"\r\n").await?;
261
262 let c2u = tokio::io::copy(&mut reader, &mut upstream_writer);
264 let u2c = tokio::io::copy(&mut upstream_reader, &mut writer);
265
266 tokio::select! {
267 _ = c2u => {}
268 _ = u2c => {}
269 }
270 }
271 Err(e) => {
272 let msg = format!(
273 "HTTP/1.1 502 Bad Gateway\r\nContent-Type: text/plain\r\nConnection: close\r\n\r\n\
274 [safe-shell] Cannot connect to {upstream_addr}: {e}\n"
275 );
276 writer.write_all(msg.as_bytes()).await?;
277 }
278 }
279
280 Ok(())
281}
282
/// Splits an absolute-form HTTP request target into `(host, port, path)`.
///
/// * The `http://` / `https://` scheme is optional and matched
///   case-insensitively. Without a scheme, the input is treated as
///   `authority[/path]`.
/// * The port defaults to `443` for `https://` and to `80` otherwise. The
///   default also applies when the authority has an empty port (`host:`).
/// * Any `userinfo@` prefix on the authority is discarded.
/// * A bracketed IPv6 literal keeps its brackets in `host` (`[::1]`).
/// * For other hosts, the port is everything after the first `:`. It is
///   returned verbatim and not validated as a number.
/// * The path keeps the query string, drops any `#fragment`, and is `/` when
///   absent.
fn parse_http_url(url: &str) -> (String, String, String) {
    /// Returns the text after `scheme` if `url` starts with it (ASCII case-insensitive).
    fn strip_scheme<'a>(url: &'a str, scheme: &str) -> Option<&'a str> {
        let prefix = url.get(..scheme.len())?;
        if prefix.eq_ignore_ascii_case(scheme) {
            url.get(scheme.len()..)
        } else {
            None
        }
    }

    let (rest, default_port) = if let Some(rest) = strip_scheme(url, "http://") {
        (rest, "80")
    } else if let Some(rest) = strip_scheme(url, "https://") {
        (rest, "443")
    } else {
        (url, "80")
    };

    // The authority ends at the first '/', '?' or '#'.
    let (authority, path_and_query) = match rest.find(|c: char| matches!(c, '/' | '?' | '#')) {
        Some(i) => rest.split_at(i),
        None => (rest, ""),
    };
    let authority = authority
        .rsplit_once('@')
        .map_or(authority, |(_, host)| host);

    let (host, port) = if authority.starts_with('[') {
        // IPv6 literal: the port separator comes after the closing bracket.
        match authority.find(']') {
            Some(end) => {
                let (host, after) = authority.split_at(end + 1);
                (host, after.strip_prefix(':').unwrap_or(after))
            }
            None => (authority, ""),
        }
    } else {
        authority.split_once(':').unwrap_or((authority, ""))
    };
    let port = if port.is_empty() { default_port } else { port };

    // Fragments are client-side only and never sent to the origin.
    let path_and_query = path_and_query.split('#').next().unwrap_or("");
    let path = if path_and_query.starts_with('/') {
        path_and_query.to_string()
    } else {
        format!("/{path_and_query}")
    };

    (host.to_string(), port.to_string(), path)
}
301
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the owned `(host, port, path)` tuple returned by `parse_http_url`.
    fn parsed(host: &str, port: &str, path: &str) -> (String, String, String) {
        (host.to_string(), port.to_string(), path.to_string())
    }

    #[test]
    fn exact_match() {
        for host in ["registry.npmjs.org", "Registry.Npmjs.Org"] {
            assert!(domain_matches(host, "registry.npmjs.org"), "{host} should match");
        }
    }

    #[test]
    fn exact_no_match() {
        let cases = [
            ("untrusted.test", "npmjs.org"),
            ("registry.npmjs.org", "npmjs.org"),
        ];
        for (host, pattern) in cases {
            assert!(!domain_matches(host, pattern), "{host} must not match {pattern}");
        }
    }

    #[test]
    fn wildcard_subdomain() {
        for host in ["sub.npmjs.org", "deep.sub.npmjs.org"] {
            assert!(domain_matches(host, "*.npmjs.org"), "{host} should match");
        }
    }

    #[test]
    fn wildcard_matches_base() {
        assert!(domain_matches("npmjs.org", "*.npmjs.org"));
    }

    #[test]
    fn wildcard_no_match() {
        for host in ["untrusted.test", "npmjs.org.untrusted.test"] {
            assert!(!domain_matches(host, "*.npmjs.org"), "{host} must not match");
        }
    }

    #[test]
    fn strips_port() {
        let cases = [
            ("registry.npmjs.org:443", "registry.npmjs.org"),
            ("sub.npmjs.org:8080", "*.npmjs.org"),
        ];
        for (host, pattern) in cases {
            assert!(domain_matches(host, pattern), "{host} should match {pattern}");
        }
    }

    #[test]
    fn star_matches_everything() {
        for host in ["anything.com", "untrusted.test:8000"] {
            assert!(domain_matches(host, "*"), "{host} should match *");
        }
    }

    #[test]
    fn case_insensitive() {
        let cases = [
            ("REGISTRY.NPMJS.ORG", "*.npmjs.org"),
            ("GitHub.com", "github.com"),
        ];
        for (host, pattern) in cases {
            assert!(domain_matches(host, pattern), "{host} should match {pattern}");
        }
    }

    #[test]
    fn prevents_suffix_attack() {
        for host in ["bad-npmjs.org", "fakenpmjs.org"] {
            assert!(!domain_matches(host, "*.npmjs.org"), "{host} must not match");
        }
    }

    #[test]
    fn proxy_starts_and_stops() {
        let proxy = DomainFilterProxy::start(vec!["example.com".to_string()], true)
            .expect("proxy should start");
        assert_ne!(proxy.port(), 0);
        drop(proxy);
    }

    #[test]
    fn parse_url_with_path() {
        assert_eq!(
            parse_http_url("http://example.com/foo/bar"),
            parsed("example.com", "80", "/foo/bar")
        );
    }

    #[test]
    fn parse_url_with_port() {
        assert_eq!(
            parse_http_url("http://example.com:8080/api"),
            parsed("example.com", "8080", "/api")
        );
    }

    #[test]
    fn parse_url_no_path() {
        assert_eq!(
            parse_http_url("http://example.com"),
            parsed("example.com", "80", "/")
        );
    }

    #[test]
    fn empty_host_no_crash() {
        for pattern in ["example.com", "*.example.com"] {
            assert!(!domain_matches("", pattern), "empty host must not match {pattern}");
        }
    }

    #[test]
    fn empty_pattern_no_crash() {
        assert!(!domain_matches("example.com", ""));
    }

    #[test]
    fn subdomain_of_tld_not_confused() {
        for host in ["untrusted.com", "com"] {
            assert!(domain_matches(host, "*.com"), "{host} should match *.com");
        }
    }

    #[test]
    fn host_with_trailing_dot() {
        // Only checks that a trailing-dot host is handled without panicking;
        // the match result is intentionally not pinned here.
        let _ = domain_matches("untrusted.test.", "untrusted.test");
    }

    #[test]
    fn wildcard_pattern_with_port() {
        assert!(domain_matches("sub.example.com:8080", "*.example.com"));
    }

    #[test]
    fn multiple_ports_in_host_no_crash() {
        // Only checks that a malformed host:port:port does not panic.
        let _ = domain_matches("untrusted.test:80:443", "untrusted.test");
    }
}