// safe_shell_sandbox/proxy.rs

use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;

use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::oneshot;

/// Check if a hostname matches an allowlist pattern.
///
/// Any `:port` suffix on `host` (everything from the first `:` on) is ignored,
/// and both sides are lowercased before comparing. Supported patterns:
///
/// - `*` matches every host;
/// - `*.example.com` matches `example.com` itself and any name ending in
///   `.example.com` (but not `badexample.com`);
/// - anything else must equal the host exactly.
pub fn domain_matches(host: &str, pattern: &str) -> bool {
    let pattern = pattern.to_lowercase();
    if pattern == "*" {
        return true;
    }

    // Drop an optional ":port" before comparing names.
    let bare_host = host
        .split_once(':')
        .map_or(host, |(name, _port)| name)
        .to_lowercase();

    match pattern.strip_prefix("*.") {
        // Either the base domain itself, or a label boundary ('.') right before it.
        Some(base) => {
            bare_host == base
                || matches!(bare_host.strip_suffix(base), Some(head) if head.ends_with('.'))
        }
        None => bare_host == pattern,
    }
}

26/// A local HTTP proxy that filters requests by domain allowlist.
27/// Runs its own tokio runtime in a background thread.
28pub struct DomainFilterProxy {
29    port: u16,
30    shutdown_tx: Option<oneshot::Sender<()>>,
31    _thread: Option<std::thread::JoinHandle<()>>,
32    blocked_count: Arc<AtomicUsize>,
33}
34
35impl DomainFilterProxy {
36    /// Start the proxy on a random port. Returns immediately with the bound port.
37    pub fn start(
38        allowed_domains: Vec<String>,
39        quiet: bool,
40    ) -> Result<Self, Box<dyn std::error::Error>> {
41        let (port_tx, port_rx) = std::sync::mpsc::channel();
42        let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>();
43        let blocked_count = Arc::new(AtomicUsize::new(0));
44        let blocked_count_clone = blocked_count.clone();
45
46        let thread = std::thread::spawn(move || {
47            let rt = tokio::runtime::Builder::new_current_thread()
48                .enable_all()
49                .build()
50                .unwrap();
51
52            rt.block_on(async {
53                let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
54                let port = listener.local_addr().unwrap().port();
55                let _ = port_tx.send(port);
56
57                let domains = Arc::new(allowed_domains);
58
59                tokio::select! {
60                    _ = accept_loop(listener, domains, blocked_count_clone, quiet) => {}
61                    _ = shutdown_rx => {}
62                }
63            });
64        });
65
66        let port = port_rx
67            .recv()
68            .map_err(|e| format!("Proxy failed to start: {e}"))?;
69
70        Ok(Self {
71            port,
72            shutdown_tx: Some(shutdown_tx),
73            _thread: Some(thread),
74            blocked_count,
75        })
76    }
77
78    pub fn port(&self) -> u16 {
79        self.port
80    }
81
82    pub fn blocked_count(&self) -> usize {
83        self.blocked_count.load(Ordering::Relaxed)
84    }
85}
86
87impl Drop for DomainFilterProxy {
88    fn drop(&mut self) {
89        if let Some(tx) = self.shutdown_tx.take() {
90            let _ = tx.send(());
91        }
92    }
93}
94
95async fn accept_loop(
96    listener: TcpListener,
97    domains: Arc<Vec<String>>,
98    blocked_count: Arc<AtomicUsize>,
99    quiet: bool,
100) {
101    while let Ok((stream, _)) = listener.accept().await {
102        let domains = domains.clone();
103        let blocked = blocked_count.clone();
104        tokio::spawn(async move {
105            if let Err(e) = handle_connection(stream, &domains, &blocked, quiet).await {
106                let msg = e.to_string();
107                if !msg.contains("Broken pipe") && !msg.contains("Connection reset") {
108                    eprintln!("[safe-shell proxy] {msg}");
109                }
110            }
111        });
112    }
113}
114
115async fn handle_connection(
116    stream: TcpStream,
117    allowed: &[String],
118    blocked_count: &AtomicUsize,
119    quiet: bool,
120) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
121    let (reader, writer) = stream.into_split();
122    let mut reader = BufReader::new(reader);
123    let writer = writer;
124
125    // Read the request line
126    let mut request_line = String::new();
127    reader.read_line(&mut request_line).await?;
128
129    let parts: Vec<&str> = request_line.split_whitespace().collect();
130    if parts.len() < 2 {
131        return Ok(());
132    }
133
134    let method = parts[0].to_uppercase();
135    let target = parts[1].to_string();
136
137    // Read remaining headers
138    let mut headers = Vec::new();
139    loop {
140        let mut line = String::new();
141        reader.read_line(&mut line).await?;
142        if line.trim().is_empty() {
143            break;
144        }
145        headers.push(line);
146    }
147
148    if method == "CONNECT" {
149        handle_connect(reader, writer, &target, allowed, blocked_count, quiet).await
150    } else {
151        handle_http(
152            reader,
153            writer,
154            &request_line,
155            &target,
156            &headers,
157            allowed,
158            blocked_count,
159            quiet,
160        )
161        .await
162    }
163}
164
165async fn handle_connect(
166    reader: BufReader<tokio::net::tcp::OwnedReadHalf>,
167    mut writer: tokio::net::tcp::OwnedWriteHalf,
168    target: &str,
169    allowed: &[String],
170    blocked_count: &AtomicUsize,
171    quiet: bool,
172) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
173    let host = target.split(':').next().unwrap_or(target);
174
175    if !allowed.iter().any(|p| domain_matches(host, p)) {
176        blocked_count.fetch_add(1, Ordering::Relaxed);
177        if !quiet {
178            eprintln!("\x1b[33m\u{26a0}\x1b[0m safe-shell: blocked network: {host}");
179        }
180        let msg = format!(
181            "HTTP/1.1 403 Forbidden\r\nContent-Type: text/plain\r\nConnection: close\r\n\r\n\
182             [safe-shell] Network blocked: {host} is not in the allowlist\n"
183        );
184        writer.write_all(msg.as_bytes()).await?;
185        return Ok(());
186    }
187
188    // Connect to upstream
189    match TcpStream::connect(target).await {
190        Ok(upstream) => {
191            writer
192                .write_all(b"HTTP/1.1 200 Connection Established\r\n\r\n")
193                .await?;
194
195            let mut client_reader = reader.into_inner();
196            let (mut upstream_reader, mut upstream_writer) = upstream.into_split();
197
198            // Bidirectional tunnel
199            let c2u = tokio::io::copy(&mut client_reader, &mut upstream_writer);
200            let u2c = tokio::io::copy(&mut upstream_reader, &mut writer);
201
202            tokio::select! {
203                _ = c2u => {}
204                _ = u2c => {}
205            }
206        }
207        Err(e) => {
208            let msg = format!(
209                "HTTP/1.1 502 Bad Gateway\r\nContent-Type: text/plain\r\nConnection: close\r\n\r\n\
210                 [safe-shell] Cannot connect to {target}: {e}\n"
211            );
212            writer.write_all(msg.as_bytes()).await?;
213        }
214    }
215
216    Ok(())
217}
218
219async fn handle_http(
220    mut reader: BufReader<tokio::net::tcp::OwnedReadHalf>,
221    mut writer: tokio::net::tcp::OwnedWriteHalf,
222    request_line: &str,
223    target: &str,
224    headers: &[String],
225    allowed: &[String],
226    blocked_count: &AtomicUsize,
227    quiet: bool,
228) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
229    // Extract host from URL: "http://host:port/path"
230    let (hostname, port, path) = parse_http_url(target);
231
232    if !allowed.iter().any(|p| domain_matches(&hostname, p)) {
233        blocked_count.fetch_add(1, Ordering::Relaxed);
234        if !quiet {
235            eprintln!("\x1b[33m\u{26a0}\x1b[0m safe-shell: blocked network: {hostname}");
236        }
237        let msg = format!(
238            "HTTP/1.1 403 Forbidden\r\nContent-Type: text/plain\r\nConnection: close\r\n\r\n\
239             [safe-shell] Network blocked: {hostname} is not in the allowlist\n"
240        );
241        writer.write_all(msg.as_bytes()).await?;
242        return Ok(());
243    }
244
245    let upstream_addr = format!("{hostname}:{port}");
246
247    match TcpStream::connect(&upstream_addr).await {
248        Ok(upstream) => {
249            let (mut upstream_reader, mut upstream_writer) = upstream.into_split();
250
251            // Rewrite request line: "GET http://host/path HTTP/1.1" → "GET /path HTTP/1.1"
252            let parts: Vec<&str> = request_line.split_whitespace().collect();
253            let rewritten = format!("{} {} {}\r\n", parts[0], path, parts[2]);
254            upstream_writer.write_all(rewritten.as_bytes()).await?;
255
256            // Forward headers
257            for h in headers {
258                upstream_writer.write_all(h.as_bytes()).await?;
259            }
260            upstream_writer.write_all(b"\r\n").await?;
261
262            // Bidirectional copy
263            let c2u = tokio::io::copy(&mut reader, &mut upstream_writer);
264            let u2c = tokio::io::copy(&mut upstream_reader, &mut writer);
265
266            tokio::select! {
267                _ = c2u => {}
268                _ = u2c => {}
269            }
270        }
271        Err(e) => {
272            let msg = format!(
273                "HTTP/1.1 502 Bad Gateway\r\nContent-Type: text/plain\r\nConnection: close\r\n\r\n\
274                 [safe-shell] Cannot connect to {upstream_addr}: {e}\n"
275            );
276            writer.write_all(msg.as_bytes()).await?;
277        }
278    }
279
280    Ok(())
281}
282
/// Split an absolute-form proxy target into `(host, port, path)`.
///
/// Accepts `[http://|https://][userinfo@]host[:port][/path][?query]`:
///
/// - the `http://` / `https://` prefix is optional and removed when present;
/// - the authority ends at the first `/`, `?` or `#`, so `http://a.test?x=1`
///   yields host `a.test` and path `/?x=1` rather than a host of `a.test?x=1`;
/// - any `userinfo@` prefix is dropped, so the returned host is the one that
///   will actually be dialed (`http://allowed.test:80@other.test/` yields
///   `other.test`);
/// - the port defaults to `"80"` when absent or empty (also for `https://`),
///   and the path defaults to `"/"`.
fn parse_http_url(url: &str) -> (String, String, String) {
    let rest = url
        .strip_prefix("http://")
        .or_else(|| url.strip_prefix("https://"))
        .unwrap_or(url);

    let (authority, path) = match rest.find(|c: char| matches!(c, '/' | '?' | '#')) {
        Some(i) if rest[i..].starts_with('/') => (&rest[..i], rest[i..].to_string()),
        // Query or fragment with no path: origin form still needs a leading '/'.
        Some(i) => (&rest[..i], format!("/{}", &rest[i..])),
        None => (rest, "/".to_string()),
    };

    // The real host follows the last '@'; anything before it is credentials.
    let host_port = match authority.rfind('@') {
        Some(i) => &authority[i + 1..],
        None => authority,
    };

    let (host, port) = match host_port.split_once(':') {
        Some((host, port)) if !port.is_empty() => (host, port),
        Some((host, _)) => (host, "80"),
        None => (host_port, "80"),
    };

    (host.to_string(), port.to_string(), path)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn exact_match() {
        assert!(domain_matches("registry.npmjs.org", "registry.npmjs.org"));
        assert!(domain_matches("Registry.Npmjs.Org", "registry.npmjs.org"));
    }

    #[test]
    fn exact_no_match() {
        assert!(!domain_matches("untrusted.test", "npmjs.org"));
        assert!(!domain_matches("registry.npmjs.org", "npmjs.org"));
    }

    #[test]
    fn wildcard_subdomain() {
        assert!(domain_matches("sub.npmjs.org", "*.npmjs.org"));
        assert!(domain_matches("deep.sub.npmjs.org", "*.npmjs.org"));
    }

    #[test]
    fn wildcard_matches_base() {
        assert!(domain_matches("npmjs.org", "*.npmjs.org"));
    }

    #[test]
    fn wildcard_no_match() {
        assert!(!domain_matches("untrusted.test", "*.npmjs.org"));
        assert!(!domain_matches("npmjs.org.untrusted.test", "*.npmjs.org"));
    }

    #[test]
    fn strips_port() {
        assert!(domain_matches(
            "registry.npmjs.org:443",
            "registry.npmjs.org"
        ));
        assert!(domain_matches("sub.npmjs.org:8080", "*.npmjs.org"));
    }

    #[test]
    fn star_matches_everything() {
        assert!(domain_matches("anything.com", "*"));
        assert!(domain_matches("untrusted.test:8000", "*"));
    }

    #[test]
    fn case_insensitive() {
        assert!(domain_matches("REGISTRY.NPMJS.ORG", "*.npmjs.org"));
        assert!(domain_matches("GitHub.com", "github.com"));
    }

    #[test]
    fn prevents_suffix_attack() {
        assert!(!domain_matches("bad-npmjs.org", "*.npmjs.org"));
        assert!(!domain_matches("fakenpmjs.org", "*.npmjs.org"));
    }

    #[test]
    fn proxy_starts_and_stops() {
        let proxy = DomainFilterProxy::start(vec!["example.com".to_string()], true).unwrap();
        assert!(proxy.port() > 0);
        // Nothing has been requested yet, so nothing can have been blocked.
        assert_eq!(proxy.blocked_count(), 0);
        drop(proxy);
    }

    #[test]
    fn parse_url_with_path() {
        let (h, p, path) = parse_http_url("http://example.com/foo/bar");
        assert_eq!(h, "example.com");
        assert_eq!(p, "80");
        assert_eq!(path, "/foo/bar");
    }

    #[test]
    fn parse_url_with_port() {
        let (h, p, path) = parse_http_url("http://example.com:8080/api");
        assert_eq!(h, "example.com");
        assert_eq!(p, "8080");
        assert_eq!(path, "/api");
    }

    #[test]
    fn parse_url_no_path() {
        let (h, p, path) = parse_http_url("http://example.com");
        assert_eq!(h, "example.com");
        assert_eq!(p, "80");
        assert_eq!(path, "/");
    }

    // --- Edge cases in domain matching ---

    #[test]
    fn empty_host_no_crash() {
        assert!(!domain_matches("", "example.com"));
        assert!(!domain_matches("", "*.example.com"));
    }

    #[test]
    fn empty_pattern_no_crash() {
        assert!(!domain_matches("example.com", ""));
    }

    #[test]
    fn subdomain_of_tld_not_confused() {
        // *.com should match sub.com
        assert!(domain_matches("untrusted.com", "*.com"));
        assert!(domain_matches("com", "*.com"));
    }

    #[test]
    fn host_with_trailing_dot() {
        // DNS allows a trailing dot (FQDN form). The matcher strips the port but
        // not a trailing dot, so whether "untrusted.test." matches
        // "untrusted.test" is deliberately left unpinned here. What must hold is
        // that a trailing dot never lets a *different* host through.
        assert!(!domain_matches("evil.test.", "untrusted.test"));
        assert!(!domain_matches("untrusted.test.evil.test.", "*.untrusted.test"));
    }

    #[test]
    fn wildcard_pattern_with_port() {
        assert!(domain_matches("sub.example.com:8080", "*.example.com"));
    }

    #[test]
    fn multiple_ports_in_host_no_crash() {
        // Malformed host: only the text before the first ':' is compared.
        assert!(domain_matches("untrusted.test:80:443", "untrusted.test"));
        assert!(!domain_matches("evil.test:80:untrusted.test", "untrusted.test"));
    }
}