Skip to main content

tiny_proxy/proxy/
handler.rs

1use anyhow::Error;
2use bytes::Bytes;
3use http_body_util::{BodyExt, Full};
4use hyper::body::Incoming;
5use hyper::header;
6use hyper::{Request, Response, StatusCode, Uri};
7use hyper_rustls::HttpsConnector;
8use hyper_util::client::legacy::connect::HttpConnector;
9use hyper_util::client::legacy::Client;
10use std::sync::Arc;
11use tokio::time::{timeout, Duration};
12use tracing::{error, info};
13
14#[cfg(feature = "logging")]
15use tracing::info_span;
16#[cfg(feature = "logging")]
17use tracing::Instrument;
18
19use crate::config::{extract_hostname, Config, SiteConfig};
20#[cfg(feature = "logging")]
21use crate::proxy::access_log::AccessLogGuard;
22use crate::proxy::access_log::{ensure_request_id, final_request_id};
23use crate::proxy::ActionResult;
24
25use crate::proxy::directives::{
26    handle_header, handle_method, handle_redirect, handle_respond, handle_reverse_proxy,
27    handle_strip_prefix, handle_uri_replace,
28};
29
30/// Unified response body type - can handle both streaming (Incoming) and buffered (Full<Bytes>)
31/// This allows us to support SSE streaming while maintaining a simple API
32type ResponseBody =
33    http_body_util::combinators::BoxBody<Bytes, Box<dyn std::error::Error + Send + Sync>>;
34
35/// Check if header is hop-by-hop (should not be proxied)
36///
37/// Hop-by-hop headers are defined in RFC 7230 Section 6.1
38/// These headers are meant for a single connection and should NOT be proxied
39/// Uses hyper::header constants for optimal performance (no allocations!)
40fn is_hop_header(name: &header::HeaderName) -> bool {
41    matches!(
42        name,
43        &header::CONNECTION
44            | &header::UPGRADE
45            | &header::TE
46            | &header::TRAILER
47            | &header::PROXY_AUTHENTICATE
48            | &header::PROXY_AUTHORIZATION
49    )
50}
51
52/// Process directives in order, applying modifications and returning final action.
53/// Supports recursive handling of handle_path blocks.
54///
55/// Note: `info!` logs here are correlated with request ID only when the `logging`
56/// feature is enabled (via the tracing span set in `proxy()`). Without `logging`,
57/// these logs appear without request context.
58pub fn process_directives(
59    directives: &[crate::config::Directive],
60    req: &mut Request<Incoming>,
61    current_path: &str,
62) -> Result<ActionResult, String> {
63    let mut modified_path = current_path.to_string();
64
65    for directive in directives {
66        match directive {
67            crate::config::Directive::Header { name, value } => {
68                if let Err(e) = handle_header(name, value.as_deref(), req) {
69                    info!("   Failed to apply header {}: {}", name, e);
70                }
71            }
72
73            crate::config::Directive::UriReplace { find, replace } => {
74                handle_uri_replace(find, replace, &mut modified_path);
75            }
76
77            crate::config::Directive::StripPrefix { prefix } => {
78                handle_strip_prefix(prefix, &mut modified_path);
79            }
80
81            crate::config::Directive::HandlePath {
82                pattern,
83                directives: nested_directives,
84            } => {
85                if let Some(remaining_path) = match_pattern(pattern, &modified_path) {
86                    info!("   Matched handle_path: {}", pattern);
87                    return process_directives(nested_directives, req, &remaining_path);
88                }
89            }
90
91            crate::config::Directive::Method {
92                methods,
93                directives: nested_directives,
94            } => {
95                if handle_method(methods, req) {
96                    info!("   Matched method directive");
97                    return process_directives(nested_directives, req, &modified_path);
98                }
99            }
100
101            crate::config::Directive::Redirect { status, url } => {
102                return Ok(handle_redirect(status, url));
103            }
104
105            crate::config::Directive::Respond { status, body } => {
106                return Ok(handle_respond(status, body));
107            }
108
109            crate::config::Directive::ReverseProxy {
110                to,
111                connect_timeout,
112                read_timeout,
113            } => {
114                return Ok(handle_reverse_proxy(
115                    to,
116                    &modified_path,
117                    *connect_timeout,
118                    *read_timeout,
119                ));
120            }
121        }
122    }
123
124    Err(format!(
125        "No action directive (respond or reverse_proxy) found in configuration for path: {}",
126        current_path
127    ))
128}
129
130/// Process a single request through the proxy
131///
132/// This implementation ALWAYS streams backend responses (nginx-style):
133/// - No buffering of response body
134/// - Direct streaming from backend to client
135/// - Works for both SSE and regular HTTP
136/// - Optimal performance and memory usage
137///
138/// For direct responses (Respond directive) and errors, buffering is used
139/// since these are small and generated by the proxy itself
140pub async fn proxy(
141    mut req: Request<Incoming>,
142    client: Client<HttpsConnector<HttpConnector>, Incoming>,
143    config: Arc<Config>,
144    remote_addr: std::net::SocketAddr,
145    is_tls: bool,
146) -> Result<Response<ResponseBody>, Error> {
147    // Generate or reuse request ID
148    let initial_request_id = ensure_request_id(&mut req);
149
150    // Extract request info before processing
151    #[cfg(feature = "logging")]
152    let method = req.method().clone().to_string();
153    let path = req.uri().path().to_string();
154    let host = req
155        .headers()
156        .get(hyper::header::HOST)
157        .and_then(|h| h.to_str().ok())
158        .unwrap_or("localhost")
159        .to_string();
160
161    #[cfg(feature = "logging")]
162    let span = info_span!("request", req_id = %initial_request_id);
163
164    #[allow(unused_variables)]
165    let future = async move {
166        #[cfg(feature = "logging")]
167        let mut log_guard = AccessLogGuard::new(
168            initial_request_id.clone(),
169            remote_addr,
170            method,
171            path.clone(),
172            host.clone(),
173        );
174
175        // Find site configuration by host
176        // Browsers send Host: example.com (no port) for default ports,
177        // but config keys may be "example.com:443". Try both.
178        let site_config = match find_site(&config, &host, is_tls) {
179            Some(config) => config,
180            None => {
181                error!("No configuration found for host: {}", host);
182                let (response, _body_len) = error_response_with_id(
183                    StatusCode::NOT_FOUND,
184                    &format!("No configuration found for host: {}", host),
185                    &initial_request_id,
186                );
187                #[cfg(feature = "logging")]
188                {
189                    log_guard.set_bytes_sent(_body_len);
190                    log_guard.finish(404);
191                }
192                return Ok(response);
193            }
194        };
195
196        // Process directives in correct order
197        let action_result = match process_directives(&site_config.directives, &mut req, &path) {
198            Ok(result) => result,
199            Err(e) => {
200                error!("Directive processing error: {}", e);
201                let final_id = final_request_id(&req, &initial_request_id);
202                #[cfg(feature = "logging")]
203                {
204                    log_guard.set_request_id(final_id.clone());
205                    tracing::Span::current().record("req_id", final_id.as_str());
206                }
207                let (response, _body_len) =
208                    error_response_with_id(StatusCode::INTERNAL_SERVER_ERROR, &e, &final_id);
209                #[cfg(feature = "logging")]
210                {
211                    log_guard.set_bytes_sent(_body_len);
212                    log_guard.finish(500);
213                }
214                return Ok(response);
215            }
216        };
217
218        // Read final request ID (directive may have overwritten X-Request-ID)
219        let request_id = final_request_id(&req, &initial_request_id);
220        #[cfg(feature = "logging")]
221        {
222            log_guard.set_request_id(request_id.clone());
223            // Update the tracing span with the final req_id
224            tracing::Span::current().record("req_id", request_id.as_str());
225        }
226
227        // Execute action
228        match action_result {
229            ActionResult::Redirect { status, url } => {
230                let status_code = StatusCode::from_u16(status).unwrap_or(StatusCode::FOUND);
231
232                let boxed: ResponseBody = Full::new(Bytes::new())
233                    .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
234                    .boxed();
235                let response = Response::builder()
236                    .status(status_code)
237                    .header("Location", &url)
238                    .header("X-Request-ID", &request_id)
239                    .body(boxed)?;
240                #[cfg(feature = "logging")]
241                {
242                    log_guard.set_bytes_sent(0);
243                    log_guard.finish(status_code.as_u16());
244                }
245                Ok(response)
246            }
247            ActionResult::Respond { status, body } => {
248                let status_code = StatusCode::from_u16(status).unwrap_or(StatusCode::OK);
249                let _body_len = body.len();
250
251                let boxed: ResponseBody = Full::new(Bytes::from(body))
252                    .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
253                    .boxed();
254                let response = Response::builder()
255                    .status(status_code)
256                    .header("X-Request-ID", &request_id)
257                    .body(boxed)?;
258                #[cfg(feature = "logging")]
259                {
260                    log_guard.set_bytes_sent(_body_len);
261                    log_guard.finish(status_code.as_u16());
262                }
263                Ok(response)
264            }
265            ActionResult::ReverseProxy {
266                backend_url,
267                path_to_send,
268                connect_timeout: _,
269                read_timeout,
270            } => {
271                // Add protocol if missing
272                let backend_with_proto =
273                    if backend_url.starts_with("http://") || backend_url.starts_with("https://") {
274                        backend_url
275                    } else {
276                        format!("http://{}", backend_url)
277                    };
278
279                // Use Uri::from_parts() instead of format!() + parse() - faster!
280                let mut parts = backend_with_proto.parse::<Uri>()?.into_parts();
281                parts.path_and_query = Some(path_to_send.parse()?);
282                let new_uri = Uri::from_parts(parts)?;
283
284                *req.uri_mut() = new_uri.clone();
285
286                // Save original host for X-Forwarded headers
287                let original_host_header = req.headers().get(hyper::header::HOST).cloned();
288
289                // Update Host header for backend
290                req.headers_mut().remove(hyper::header::HOST);
291                if let Some(authority) = new_uri.authority() {
292                    if let Ok(host_value) = authority.as_str().parse::<hyper::header::HeaderValue>()
293                    {
294                        req.headers_mut().insert(hyper::header::HOST, host_value);
295                    }
296                }
297
298                // Add X-Forwarded-* headers for backend visibility
299                if let Some(host_value) = original_host_header.clone() {
300                    req.headers_mut().insert("X-Forwarded-Host", host_value);
301                }
302
303                // X-Forwarded-Proto: based on whether the connection is TLS
304                req.headers_mut().insert(
305                    "X-Forwarded-Proto",
306                    hyper::header::HeaderValue::from_static(if is_tls { "https" } else { "http" }),
307                );
308
309                // X-Forwarded-For: real client IP
310                if let Ok(ip_value) =
311                    hyper::header::HeaderValue::from_str(&remote_addr.ip().to_string())
312                {
313                    req.headers_mut().insert("X-Forwarded-For", ip_value);
314                }
315
316                // Remove hop-by-hop headers
317                req.headers_mut().remove(header::CONNECTION);
318                req.headers_mut().remove("accept-encoding");
319
320                // Forward request to backend with configurable timeout (default 30s)
321                let backend_timeout = read_timeout.unwrap_or(30);
322                match timeout(Duration::from_secs(backend_timeout), client.request(req)).await {
323                    Ok(Ok(response)) => {
324                        let status = response.status();
325                        let headers = response.headers().clone();
326
327                        // Stream response body directly (no buffering)
328                        let mut builder = Response::builder().status(status);
329
330                        // Copy headers, filtering hop-by-hop
331                        for (name, value) in headers.iter() {
332                            if !is_hop_header(name) && name != header::CONTENT_LENGTH {
333                                builder = builder.header(name, value);
334                            }
335                        }
336
337                        let (_, incoming_body) = response.into_parts();
338                        let boxed: ResponseBody = incoming_body
339                            .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
340                            .boxed();
341
342                        let response = builder.header("X-Request-ID", &request_id).body(boxed)?;
343                        #[cfg(feature = "logging")]
344                        log_guard.finish(status.as_u16());
345                        Ok(response)
346                    }
347                    Ok(Err(e)) => {
348                        error!("Backend connection failed: {:?}", e);
349                        if e.is_connect() {
350                            error!("   Reason: Connection refused - backend unavailable");
351                        } else {
352                            error!("   Reason: Other connection error");
353                        }
354
355                        let (response, _body_len) = error_response_with_id(
356                            StatusCode::BAD_GATEWAY,
357                            "Backend service unavailable",
358                            &request_id,
359                        );
360                        #[cfg(feature = "logging")]
361                        {
362                            log_guard.set_bytes_sent(_body_len);
363                            log_guard.finish(502);
364                        }
365                        Ok(response)
366                    }
367                    Err(_) => {
368                        error!(
369                            "Backend request timed out after {} seconds",
370                            backend_timeout
371                        );
372
373                        let (response, _body_len) = error_response_with_id(
374                            StatusCode::GATEWAY_TIMEOUT,
375                            "Backend request timed out",
376                            &request_id,
377                        );
378                        #[cfg(feature = "logging")]
379                        {
380                            log_guard.set_bytes_sent(_body_len);
381                            log_guard.finish(504);
382                        }
383                        Ok(response)
384                    }
385                }
386            }
387        }
388    };
389
390    #[cfg(feature = "logging")]
391    let future = future.instrument(span);
392
393    future.await
394}
395
396/// Creates HTTP response with error and X-Request-ID header
397///
398/// Returns both the response and the body length (for access logging).
399fn error_response_with_id(
400    status: StatusCode,
401    message: &str,
402    request_id: &str,
403) -> (Response<ResponseBody>, usize) {
404    let body = format!(
405        r#"<!DOCTYPE html>
406        <html>
407        <head><title>{} {}</title></head>
408        <body>
409        <h1>{} {}</h1>
410        <p>{}</p>
411        <hr>
412        <p><em>Rust Proxy Server</em></p>
413        </body>
414        </html>"#,
415        status.as_u16(),
416        status.canonical_reason().unwrap_or("Error"),
417        status.as_u16(),
418        status.canonical_reason().unwrap_or("Error"),
419        message
420    );
421
422    let body_len = body.len();
423    let full = Full::new(Bytes::from(body));
424    let boxed: ResponseBody = full
425        .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
426        .boxed();
427
428    let mut builder = Response::builder()
429        .status(status)
430        .header("Content-Type", "text/html; charset=utf-8");
431
432    if let Ok(val) = hyper::header::HeaderValue::from_str(request_id) {
433        builder = builder.header("X-Request-ID", val);
434    }
435
436    let response = builder.body(boxed).unwrap_or_else(|_| {
437        Response::new(
438            Full::new(Bytes::from("Internal Server Error"))
439                .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
440                .boxed(),
441        )
442    });
443
444    (response, body_len)
445}
446
/// Match a request path against a `handle_path` pattern.
///
/// Two pattern forms are supported:
///
/// - **Prefix wildcard** (`"/api/*"`): matches when `path` starts with the part
///   before `/*` *on a path-segment boundary*. `"/api/users"` and `"/api"`
///   match; `"/apiary"` does not. The returned string is the part of `path`
///   after the prefix (`"/users"`, or `""` when the path equals the prefix).
/// - **Exact** (no trailing `/*`): matches only when `pattern == path`, and the
///   remaining path is reported as `"/"`.
///
/// Returns `Some(remaining_path)` on a match, `None` otherwise.
pub fn match_pattern(pattern: &str, path: &str) -> Option<String> {
    match pattern.strip_suffix("/*") {
        Some(prefix) => {
            let remaining = path.strip_prefix(prefix)?;
            // Without this check "/api/*" would also capture "/apiary" and
            // yield the nonsensical remaining path "ary".
            let on_segment_boundary =
                remaining.is_empty() || remaining.starts_with('/') || prefix.ends_with('/');
            on_segment_boundary.then(|| remaining.to_string())
        }
        None => (pattern == path).then(|| "/".to_string()),
    }
}
463
464/// Find a site configuration matching the given Host header value.
465///
466/// Browsers on default ports omit the port from the Host header:
467/// - HTTPS on 443 → `Host: example.com` (no `:443`)
468/// - HTTP on 80 → `Host: example.com` (no `:80`)
469///
470/// But config keys include the port: `"example.com:443"`, `"example.com:80"`.
471///
472/// This function tries multiple lookup strategies:
473/// 1. Exact match: `host` as-is
474/// 2. If `host` has no port → try `host:<default_port>` based on `is_tls`
475/// 3. If `host` has a port → also try just the hostname (in case config has no port)
476///
477/// # Limitations
478///
479/// For non-default TLS ports (e.g., 8443), browsers always include the port
480/// in the `Host` header (`Host: example.com:8443`), so strategy 1 (exact match)
481/// works fine. The fallback in strategy 2 only tries ports 443 (TLS) and 80 (HTTP).
482/// This means a non-browser client sending `Host: example.com` without a port to
483/// a TLS listener on :8443 will get a 404 — this is a protocol violation by the client.
484pub fn find_site<'a>(config: &'a Config, host: &str, is_tls: bool) -> Option<&'a SiteConfig> {
485    // 1. Exact match
486    if let Some(site) = config.sites.get(host) {
487        return Some(site);
488    }
489
490    // Determine if host already contains a port
491    // IPv6: [::1]:8080 — port is after the last ']' + ':'
492    // IPv4/hostname: example.com:443
493    let has_port = if host.starts_with('[') {
494        // IPv6 — look for port after ']'
495        if let Some(bracket_end) = host.find(']') {
496            host[bracket_end..].contains(':')
497        } else {
498            false
499        }
500    } else {
501        host.contains(':')
502    };
503
504    if !has_port {
505        // 2. Host has no port — try appending the default port
506        let default_port = if is_tls { 443 } else { 80 };
507        let candidate = format!("{}:{}", host, default_port);
508        if let Some(site) = config.sites.get(&candidate) {
509            return Some(site);
510        }
511
512        // 3. TLS on a non-standard port — match by SNI hostname if unambiguous
513        if is_tls {
514            let mut matches = config.sites.values().filter(|s| {
515                s.tls.is_some() && extract_hostname(&s.address).eq_ignore_ascii_case(host)
516            });
517            if let Some(site) = matches.next() {
518                if matches.next().is_none() {
519                    return Some(site);
520                }
521            }
522        }
523    } else {
524        // 4. Host has a port — try just the hostname (strip port)
525        let hostname = if host.starts_with('[') {
526            // IPv6 [::1]:port → ::1
527            let end = host.find(']').unwrap_or(host.len());
528            host[1..end].to_string()
529        } else {
530            host.rsplit(':').next_back().unwrap_or(host).to_string()
531        };
532        if let Some(site) = config.sites.get(&hostname) {
533            return Some(site);
534        }
535    }
536
537    None
538}
539
#[cfg(test)]
mod find_site_tests {
    use super::*;
    use std::collections::HashMap;

    /// Build a `Config` from `(address, has_tls)` pairs. Each address is used
    /// both as the map key and as the site's `address`; TLS sites get
    /// placeholder certificate paths that are never read by `find_site`.
    fn make_config(sites: Vec<(&str, bool)>) -> Config {
        let sites: HashMap<_, _> = sites
            .into_iter()
            .map(|(addr, has_tls)| {
                let tls = has_tls.then(|| crate::config::TlsConfig {
                    cert_path: "/fake/cert.pem".to_string(),
                    key_path: "/fake/key.pem".to_string(),
                });
                let site = crate::config::SiteConfig {
                    address: addr.to_string(),
                    directives: vec![],
                    tls,
                };
                (addr.to_string(), site)
            })
            .collect();
        Config { sites }
    }

    /// Host header identical to the config key.
    #[test]
    fn test_exact_match() {
        let cfg = make_config(vec![("example.com:443", true)]);
        let found = find_site(&cfg, "example.com:443", true);
        assert!(found.is_some());
    }

    /// Browser sends `Host: example.com` (no `:443`) on HTTPS.
    #[test]
    fn test_tls_host_without_port_finds_443() {
        let cfg = make_config(vec![("example.com:443", true)]);
        let found = find_site(&cfg, "example.com", true);
        assert!(
            found.is_some(),
            "Should find example.com:443 when Host has no port and is_tls=true"
        );
    }

    /// Browser sends `Host: example.com` (no `:80`) on HTTP.
    #[test]
    fn test_http_host_without_port_finds_80() {
        let cfg = make_config(vec![("example.com:80", false)]);
        let found = find_site(&cfg, "example.com", false);
        assert!(
            found.is_some(),
            "Should find example.com:80 when Host has no port and is_tls=false"
        );
    }

    /// `Host: example.com` over TLS must NOT resolve to the plain `:80` site.
    #[test]
    fn test_tls_host_without_port_no_match_on_80() {
        let cfg = make_config(vec![("example.com:80", false)]);
        let found = find_site(&cfg, "example.com", true);
        assert!(found.is_none(), "TLS on port 443 should not find :80 site");
    }

    /// Config key has no port while the Host header carries one.
    #[test]
    fn test_host_with_port_strips_port_fallback() {
        let cfg = make_config(vec![("example.com", false)]);
        let found = find_site(&cfg, "example.com:8080", false);
        assert!(
            found.is_some(),
            "Should strip port from Host and find config without port"
        );
    }

    /// A single TLS site on a non-standard port is found by hostname alone.
    #[test]
    fn test_tls_host_without_port_finds_non_standard_port() {
        let cfg = make_config(vec![("alpha.local:8443", true)]);
        let found = find_site(&cfg, "alpha.local", true);
        assert!(
            found.is_some(),
            "Should find alpha.local:8443 when Host has no port on TLS"
        );
    }

    /// Unrelated hostname yields no site.
    #[test]
    fn test_no_match() {
        let cfg = make_config(vec![("other.com:443", true)]);
        let found = find_site(&cfg, "example.com", true);
        assert!(found.is_none());
    }
}