Skip to main content

rustgate/
proxy.rs

1use crate::cert::CertificateAuthority;
2use crate::error::ProxyError;
3use crate::handler::{boxed_body, full_boxed_body, Buffered, BoxBody, Dropped, RequestHandler};
4use crate::logging::{LogId, UpstreamTarget};
5use crate::tls;
6use bytes::Bytes;
7use http_body_util::{BodyExt, Empty, Full};
8use hyper::client::conn::http1 as client_http1;
9use hyper::server::conn::http1 as server_http1;
10use hyper::service::service_fn;
11use hyper::upgrade::Upgraded;
12use hyper::{Method, Request, Response};
13use hyper_util::rt::TokioIo;
14use std::net::SocketAddr;
15use std::sync::Arc;
16use tokio::net::TcpStream;
17use tracing::{debug, error, info, warn};
18
/// Largest body, in bytes, that will be buffered for interception or logging
/// (10 MiB). Larger or unknown-length bodies are streamed through untouched.
const MAX_INTERCEPT_BODY: usize = 10 << 20;
21
22/// Check if a body should be intercepted based on Content-Length header.
23/// Returns true ONLY if Content-Length is explicitly present and within the limit.
24/// All other cases (chunked, close-delimited, unknown-length) skip interception
25/// to avoid consuming streaming bodies.
26fn should_intercept_body(headers: &hyper::HeaderMap) -> bool {
27    if let Some(cl) = headers.get(hyper::header::CONTENT_LENGTH) {
28        if let Ok(s) = cl.to_str() {
29            if let Ok(len) = s.parse::<usize>() {
30                return len <= MAX_INTERCEPT_BODY;
31            }
32        }
33    }
34    false
35}
36
37/// Collect a body into Bytes. Returns None on failure or size exceeded.
38async fn try_collect_body<B>(body: B) -> Option<Bytes>
39where
40    B: hyper::body::Body<Data = Bytes, Error = hyper::Error>,
41{
42    use http_body_util::Limited;
43    let limited = Limited::new(body, MAX_INTERCEPT_BODY);
44    BodyExt::collect(limited)
45        .await
46        .ok()
47        .map(|c| c.to_bytes())
48}
49
50/// Shared state passed to each connection handler.
51pub struct ProxyState {
52    pub ca: Arc<CertificateAuthority>,
53    pub mitm: bool,
54    pub intercept: bool,
55    pub log_traffic: bool,
56    pub handler: Arc<dyn RequestHandler>,
57}
58
59/// Flush a pending log entry by calling handle_response with a synthetic error response.
60fn flush_log_on_error(
61    handler: &Arc<dyn RequestHandler>,
62    log_id: Option<LogId>,
63    status: u16,
64) {
65    if let Some(id) = log_id {
66        let mut res = Response::builder()
67            .status(status)
68            .body(full_boxed_body(Bytes::new()))
69            .unwrap();
70        res.extensions_mut().insert(id);
71        handler.handle_response(&mut res);
72    }
73}
74
75/// Handle a single accepted TCP connection.
76pub async fn handle_connection(
77    stream: TcpStream,
78    addr: SocketAddr,
79    state: Arc<ProxyState>,
80) {
81    debug!("New connection from {addr}");
82
83    let io = TokioIo::new(stream);
84    let state = state.clone();
85
86    let service = service_fn(move |req: Request<hyper::body::Incoming>| {
87        let state = state.clone();
88        async move { handle_request(req, state).await }
89    });
90
91    if let Err(e) = server_http1::Builder::new()
92        .preserve_header_case(true)
93        .title_case_headers(true)
94        .serve_connection(io, service)
95        .with_upgrades()
96        .await
97    {
98        if !e.to_string().contains("early eof")
99            && !e.to_string().contains("connection closed")
100        {
101            error!("Connection error from {addr}: {e}");
102        }
103    }
104}
105
106/// Route a request: CONNECT goes to tunnel/MITM, everything else gets forwarded.
107async fn handle_request(
108    req: Request<hyper::body::Incoming>,
109    state: Arc<ProxyState>,
110) -> Result<Response<BoxBody>, hyper::Error> {
111    if req.method() == Method::CONNECT {
112        handle_connect(req, state).await
113    } else {
114        handle_forward(req, state).await
115    }
116}
117
118// ─── HTTP Forwarding ───────────────────────────────────────────────────────────
119
120/// Forward a plain HTTP request to the upstream server.
121async fn handle_forward(
122    req: Request<hyper::body::Incoming>,
123    state: Arc<ProxyState>,
124) -> Result<Response<BoxBody>, hyper::Error> {
125    let uri = req.uri().clone();
126    let host = match uri.host() {
127        Some(h) => h.to_string(),
128        None => {
129            warn!("Request with no host: {uri}");
130            return Ok(bad_request("Missing host in URI"));
131        }
132    };
133    let port = uri.port_u16().unwrap_or(80);
134    let addr = format!("{host}:{port}");
135
136    // Build the request to forward
137    let (mut parts, body) = req.into_parts();
138    let path = parts
139        .uri
140        .path_and_query()
141        .map(|pq| pq.as_str())
142        .unwrap_or("/");
143    parts.uri = match path.parse() {
144        Ok(uri) => uri,
145        Err(_) => {
146            warn!("Invalid path: {path}");
147            return Ok(bad_request("Invalid request URI"));
148        }
149    };
150
151    // Store upstream target for logging
152    parts.extensions.insert(UpstreamTarget {
153        scheme: "http".into(),
154        host: host.to_string(),
155        port,
156    });
157
158    // Check intercept eligibility BEFORE stripping hop-by-hop headers
159    let do_buffer = (state.intercept || state.log_traffic) && should_intercept_body(&parts.headers);
160
161    strip_hop_by_hop_headers(&mut parts.headers);
162
163    let mut forwarded_req = if do_buffer {
164        match try_collect_body(body).await {
165            Some(bytes) => {
166                let mut req = Request::from_parts(parts, full_boxed_body(bytes));
167                req.extensions_mut().insert(Buffered);
168                req
169            }
170            None => {
171                error!("Request body collection failed");
172                return Ok(bad_gateway("Request body read error"));
173            }
174        }
175    } else {
176        Request::from_parts(parts, boxed_body(body))
177    };
178
179    state.handler.handle_request(&mut forwarded_req);
180    let log_id = forwarded_req.extensions().get::<LogId>().cloned();
181
182    if forwarded_req.extensions().get::<Dropped>().is_some() {
183        return Ok(bad_gateway("Request dropped by interceptor"));
184    }
185
186    // Connect to upstream
187    let upstream = match TcpStream::connect(&addr).await {
188        Ok(s) => s,
189        Err(e) => {
190            error!("Failed to connect to {addr}: {e}");
191            flush_log_on_error(&state.handler, log_id, 502);
192            return Ok(bad_gateway(&format!("Failed to connect to {addr}")));
193        }
194    };
195
196    let io = TokioIo::new(upstream);
197    let (mut sender, conn) = match client_http1::handshake(io).await {
198        Ok(r) => r,
199        Err(e) => {
200            error!("Handshake with {addr} failed: {e}");
201            flush_log_on_error(&state.handler, log_id, 502);
202            return Ok(bad_gateway("Upstream handshake failed"));
203        }
204    };
205
206    tokio::spawn(async move {
207        if let Err(e) = conn.await {
208            error!("Upstream connection error: {e}");
209        }
210    });
211
212    match sender.send_request(forwarded_req).await {
213        Ok(res) => {
214            let (parts, body) = res.into_parts();
215            let mut response = if (state.intercept || state.log_traffic) && should_intercept_body(&parts.headers) {
216                match try_collect_body(body).await {
217                    Some(bytes) => {
218                        let mut res = Response::from_parts(parts, full_boxed_body(bytes));
219                        res.extensions_mut().insert(Buffered);
220                        res
221                    }
222                    None => {
223                        error!("Response body collection failed");
224                        flush_log_on_error(&state.handler, log_id, 502);
225                        return Ok(bad_gateway("Response body read error"));
226                    }
227                }
228            } else {
229                Response::from_parts(parts, boxed_body(body))
230            };
231            if let Some(id) = log_id.clone() { response.extensions_mut().insert(id); }
232            state.handler.handle_response(&mut response);
233            if response.extensions().get::<Dropped>().is_some() {
234                return Ok(interceptor_dropped_response());
235            }
236            Ok(response)
237        }
238        Err(e) => {
239            error!("Upstream request failed: {e}");
240            flush_log_on_error(&state.handler, log_id, 502);
241            Ok(bad_gateway("Upstream request failed"))
242        }
243    }
244}
245
246// ─── CONNECT Handling ──────────────────────────────────────────────────────────
247
248/// Handle a CONNECT request: either tunnel (passthrough) or MITM.
249async fn handle_connect(
250    req: Request<hyper::body::Incoming>,
251    state: Arc<ProxyState>,
252) -> Result<Response<BoxBody>, hyper::Error> {
253    let target = match req.uri().authority() {
254        Some(auth) => auth.to_string(),
255        None => {
256            warn!("CONNECT without authority");
257            return Ok(bad_request("CONNECT target missing"));
258        }
259    };
260
261    let (host, port) = parse_host_port(&target);
262    let addr = format!("{host}:{port}");
263
264    info!("CONNECT {target}");
265
266    if state.mitm {
267        // MITM mode: intercept the TLS connection
268        handle_mitm(req, host, addr, state).await
269    } else {
270        // Passthrough mode: just tunnel bytes
271        handle_tunnel(req, addr).await
272    }
273}
274
275/// Passthrough tunneling: bidirectional copy between client and upstream.
276async fn handle_tunnel(
277    req: Request<hyper::body::Incoming>,
278    addr: String,
279) -> Result<Response<BoxBody>, hyper::Error> {
280    tokio::spawn(async move {
281        match hyper::upgrade::on(req).await {
282            Ok(upgraded) => {
283                if let Err(e) = tunnel_bidirectional(upgraded, &addr).await {
284                    error!("Tunnel error to {addr}: {e}");
285                }
286            }
287            Err(e) => {
288                error!("Upgrade failed: {e}");
289            }
290        }
291    });
292
293    // Respond with 200 to tell the client the tunnel is established
294    Ok(Response::new(empty_body()))
295}
296
297/// Copy data bidirectionally between the upgraded client connection and upstream.
298async fn tunnel_bidirectional(
299    upgraded: Upgraded,
300    addr: &str,
301) -> crate::error::Result<()> {
302    let mut upstream = TcpStream::connect(addr).await?;
303
304    let mut client = TokioIo::new(upgraded);
305
306    let (client_to_server, server_to_client) =
307        tokio::io::copy_bidirectional(&mut client, &mut upstream).await?;
308
309    debug!(
310        "Tunnel closed: {addr} (client→server: {client_to_server}B, server→client: {server_to_client}B)"
311    );
312    Ok(())
313}
314
315/// MITM mode: terminate TLS with both ends, intercept HTTP traffic.
316async fn handle_mitm(
317    req: Request<hyper::body::Incoming>,
318    host: String,
319    addr: String,
320    state: Arc<ProxyState>,
321) -> Result<Response<BoxBody>, hyper::Error> {
322    let state = state.clone();
323
324    tokio::spawn(async move {
325        match hyper::upgrade::on(req).await {
326            Ok(upgraded) => {
327                if let Err(e) =
328                    mitm_intercept(upgraded, &host, &addr, state).await
329                {
330                    error!("MITM error for {host}: {e}");
331                }
332            }
333            Err(e) => {
334                error!("MITM upgrade failed: {e}");
335            }
336        }
337    });
338
339    Ok(Response::new(empty_body()))
340}
341
342/// Perform MITM interception on an upgraded connection.
343async fn mitm_intercept(
344    upgraded: Upgraded,
345    host: &str,
346    addr: &str,
347    state: Arc<ProxyState>,
348) -> crate::error::Result<()> {
349    // Create a TLS acceptor with a fake cert for this domain
350    let acceptor = tls::make_tls_acceptor(&state.ca, host).await?;
351
352    // Accept TLS from the client side
353    let client_io = TokioIo::new(upgraded);
354    let client_tls = acceptor
355        .accept(client_io)
356        .await
357        .map_err(|e| ProxyError::Other(format!("Client TLS accept failed: {e}")))?;
358
359    let client_tls = TokioIo::new(client_tls);
360
361    // Serve HTTP on the decrypted client stream
362    let host = host.to_string();
363    let addr = addr.to_string();
364
365    let service = service_fn(move |req: Request<hyper::body::Incoming>| {
366        let host = host.clone();
367        let addr = addr.clone();
368        let state = state.clone();
369        async move {
370            mitm_forward_request(req, &host, &addr, state).await
371        }
372    });
373
374    if let Err(e) = server_http1::Builder::new()
375        .preserve_header_case(true)
376        .title_case_headers(true)
377        .serve_connection(client_tls, service)
378        .await
379    {
380        if !e.to_string().contains("early eof")
381            && !e.to_string().contains("connection closed")
382        {
383            debug!("MITM connection closed: {e}");
384        }
385    }
386
387    Ok(())
388}
389
390/// Forward a request from the MITM-decrypted stream to the real upstream over TLS.
391async fn mitm_forward_request(
392    req: Request<hyper::body::Incoming>,
393    host: &str,
394    addr: &str,
395    state: Arc<ProxyState>,
396) -> Result<Response<BoxBody>, hyper::Error> {
397    let (mut parts, body) = req.into_parts();
398
399    parts.extensions.insert(UpstreamTarget {
400        scheme: "https".into(),
401        host: host.to_string(),
402        port: addr.rsplit_once(':').and_then(|(_, p)| p.parse().ok()).unwrap_or(443),
403    });
404
405    let do_buffer = (state.intercept || state.log_traffic) && should_intercept_body(&parts.headers);
406    strip_hop_by_hop_headers(&mut parts.headers);
407
408    let mut forwarded_req = if do_buffer {
409        match try_collect_body(body).await {
410            Some(bytes) => {
411                let mut req = Request::from_parts(parts, full_boxed_body(bytes));
412                req.extensions_mut().insert(Buffered);
413                req
414            }
415            None => {
416                error!("MITM request body collection failed");
417                return Ok(bad_gateway("Request body read error"));
418            }
419        }
420    } else {
421        Request::from_parts(parts, boxed_body(body))
422    };
423
424    state.handler.handle_request(&mut forwarded_req);
425    let log_id = forwarded_req.extensions().get::<LogId>().cloned();
426
427    if forwarded_req.extensions().get::<Dropped>().is_some() {
428        return Ok(bad_gateway("Request dropped by interceptor"));
429    }
430
431    // Connect to upstream over TLS
432    let upstream_tls = match tls::connect_tls_upstream(host, addr).await {
433        Ok(s) => s,
434        Err(e) => {
435            error!("Failed TLS connect to {addr}: {e}");
436            flush_log_on_error(&state.handler, log_id.clone(), 502);
437            return Ok(bad_gateway(&format!(
438                "Failed to connect to upstream: {e}"
439            )));
440        }
441    };
442
443    let io = TokioIo::new(upstream_tls);
444    let (mut sender, conn) = match client_http1::handshake(io).await {
445        Ok(r) => r,
446        Err(e) => {
447            error!("Upstream TLS handshake failed: {e}");
448            flush_log_on_error(&state.handler, log_id.clone(), 502);
449            return Ok(bad_gateway("Upstream TLS handshake failed"));
450        }
451    };
452
453    tokio::spawn(async move {
454        if let Err(e) = conn.await {
455            debug!("Upstream TLS connection closed: {e}");
456        }
457    });
458
459    match sender.send_request(forwarded_req).await {
460        Ok(res) => {
461            let (parts, body) = res.into_parts();
462            let mut response = if (state.intercept || state.log_traffic) && should_intercept_body(&parts.headers) {
463                match try_collect_body(body).await {
464                    Some(bytes) => {
465                        let mut res = Response::from_parts(parts, full_boxed_body(bytes));
466                        res.extensions_mut().insert(Buffered);
467                        res
468                    }
469                    None => {
470                        error!("MITM response body collection failed");
471                        flush_log_on_error(&state.handler, log_id, 502);
472                        return Ok(bad_gateway("Response body read error"));
473                    }
474                }
475            } else {
476                Response::from_parts(parts, boxed_body(body))
477            };
478            if let Some(id) = log_id.clone() { response.extensions_mut().insert(id); }
479            state.handler.handle_response(&mut response);
480            if response.extensions().get::<Dropped>().is_some() {
481                return Ok(interceptor_dropped_response());
482            }
483            Ok(response)
484        }
485        Err(e) => {
486            error!("Upstream TLS request failed: {e}");
487            flush_log_on_error(&state.handler, log_id, 502);
488            Ok(bad_gateway("Upstream request failed"))
489        }
490    }
491}
492
493// ─── Helpers ───────────────────────────────────────────────────────────────────
494
/// Lowercase names of hop-by-hop headers that must not be forwarded.
///
/// This is the RFC 2616 §13.5.1 list (including its `"trailers"` entry,
/// kept verbatim) plus the non-standard `Proxy-Connection`. Some clients
/// send `Proxy-Connection` to proxies in place of `Connection`, and it is
/// meaningless to the next hop.
const HOP_BY_HOP_HEADERS: &[&str] = &[
    "connection",
    "keep-alive",
    "proxy-authenticate",
    "proxy-authorization",
    "proxy-connection",
    "te",
    "trailers",
    "transfer-encoding",
    "upgrade",
];
506
/// Split a CONNECT target into `(host, port)`, defaulting the port to 443.
///
/// Accepted forms:
/// - `"example.com:8443"`, `"1.2.3.4:443"`: host and explicit port.
/// - `"[::1]:443"`: bracketed IPv6 literal; the brackets are removed from the host.
/// - `"example.com"`, `"[::1]"`: no port, so 443.
/// - `"::1"`, `"2001:db8::1"`: bare IPv6 literal, returned whole with port 443.
///   It is not split at its last colon.
///
/// If a bracketed target has a non-numeric port, the port falls back to 443.
/// If an unbracketed target's suffix after the last `:` is not a valid
/// `u16`, the whole target is returned as the host with port 443.
pub fn parse_host_port(target: &str) -> (String, u16) {
    const DEFAULT_PORT: u16 = 443;

    if let Some(bracketed) = target.strip_prefix('[') {
        // IPv6: [::1]:port
        if let Some((ip6, rest)) = bracketed.split_once(']') {
            let port = rest
                .strip_prefix(':')
                .and_then(|p| p.parse().ok())
                .unwrap_or(DEFAULT_PORT);
            return (ip6.to_string(), port);
        }
    }

    // A bare IPv6 literal contains several colons, none of which introduces a
    // port. Splitting it would turn "::1" into host "::" and port 1.
    if target.parse::<std::net::Ipv6Addr>().is_ok() {
        return (target.to_string(), DEFAULT_PORT);
    }

    // IPv4 / hostname: host:port
    if let Some((host, port_str)) = target.rsplit_once(':') {
        if let Ok(port) = port_str.parse::<u16>() {
            return (host.to_string(), port);
        }
    }
    (target.to_string(), DEFAULT_PORT)
}
528
529fn strip_hop_by_hop_headers(headers: &mut hyper::HeaderMap) {
530    // Also remove headers listed in the Connection header value
531    if let Some(conn_val) = headers.get("connection").cloned() {
532        if let Ok(val) = conn_val.to_str() {
533            for name in val.split(',') {
534                let name = name.trim();
535                if !name.is_empty() {
536                    headers.remove(name);
537                }
538            }
539        }
540    }
541
542    for name in HOP_BY_HOP_HEADERS {
543        headers.remove(*name);
544    }
545}
546
547fn empty_body() -> BoxBody {
548    Empty::<Bytes>::new()
549        .map_err(|never| match never {})
550        .boxed()
551}
552
553fn bad_request(msg: &str) -> Response<BoxBody> {
554    Response::builder()
555        .status(400)
556        .body(full_body(msg))
557        .unwrap()
558}
559
560fn bad_gateway(msg: &str) -> Response<BoxBody> {
561    Response::builder()
562        .status(502)
563        .body(full_body(msg))
564        .unwrap()
565}
566
567/// Non-retryable response for interceptor-dropped responses.
568/// Uses 444 (No Response, nginx convention) + Connection: close to signal
569/// that the response was locally suppressed and the client should NOT retry.
570/// The upstream request was already executed.
571fn interceptor_dropped_response() -> Response<BoxBody> {
572    Response::builder()
573        .status(444)
574        .header("Connection", "close")
575        .header("X-RustGate-Interceptor", "response-dropped")
576        .body(full_body(
577            "Response dropped by interceptor. The upstream request was already executed. Do not retry.",
578        ))
579        .unwrap()
580}
581
582fn full_body(msg: &str) -> BoxBody {
583    Full::new(Bytes::from(msg.to_string()))
584        .map_err(|never| match never {})
585        .boxed()
586}