// tiny_proxy/proxy/handler.rs

1use anyhow::Error;
2use bytes::Bytes;
3use http_body_util::{BodyExt, Full};
4use hyper::body::Incoming;
5use hyper::header;
6use hyper::{Request, Response, StatusCode, Uri};
7use hyper_rustls::HttpsConnector;
8use hyper_util::client::legacy::connect::HttpConnector;
9use hyper_util::client::legacy::Client;
10use std::sync::Arc;
11use tokio::time::{timeout, Duration};
12use tracing::{error, info};
13
14#[cfg(feature = "logging")]
15use tracing::info_span;
16#[cfg(feature = "logging")]
17use tracing::Instrument;
18
19use crate::config::{extract_hostname, Config, SiteConfig};
20#[cfg(feature = "logging")]
21use crate::proxy::access_log::AccessLogGuard;
22use crate::proxy::access_log::{ensure_request_id, final_request_id};
23use crate::proxy::ActionResult;
24
25use crate::proxy::directives::{
26    apply_header_up, handle_header, handle_method, handle_redirect, handle_respond,
27    handle_reverse_proxy, handle_strip_prefix, handle_uri_replace,
28};
29
30/// Unified response body type - can handle both streaming (`Incoming`) and buffered (`Full<Bytes>`)
31/// This allows us to support SSE streaming while maintaining a simple API
32type ResponseBody =
33    http_body_util::combinators::BoxBody<Bytes, Box<dyn std::error::Error + Send + Sync>>;
34
35/// Check if header is hop-by-hop (should not be proxied)
36///
37/// Hop-by-hop headers are defined in RFC 7230 Section 6.1
38/// These headers are meant for a single connection and should NOT be proxied
39/// Uses hyper::header constants for optimal performance (no allocations!)
40fn is_hop_header(name: &header::HeaderName) -> bool {
41    matches!(
42        name,
43        &header::CONNECTION
44            | &header::UPGRADE
45            | &header::TE
46            | &header::TRAILER
47            | &header::PROXY_AUTHENTICATE
48            | &header::PROXY_AUTHORIZATION
49    )
50}
51
52/// Process directives in order, applying modifications and returning final action.
53/// Supports recursive handling of handle_path blocks.
54///
55/// Note: `info!` logs here are correlated with request ID only when the `logging`
56/// feature is enabled (via the tracing span set in `proxy()`). Without `logging`,
57/// these logs appear without request context.
58pub fn process_directives(
59    directives: &[crate::config::Directive],
60    req: &mut Request<Incoming>,
61    current_path: &str,
62) -> Result<ActionResult, String> {
63    let mut modified_path = current_path.to_string();
64
65    for directive in directives {
66        match directive {
67            crate::config::Directive::Header { name, value } => {
68                if let Err(e) = handle_header(name, value.as_deref(), req) {
69                    info!("   Failed to apply header {}: {}", name, e);
70                }
71            }
72
73            crate::config::Directive::UriReplace { find, replace } => {
74                handle_uri_replace(find, replace, &mut modified_path);
75            }
76
77            crate::config::Directive::StripPrefix { prefix } => {
78                handle_strip_prefix(prefix, &mut modified_path);
79            }
80
81            crate::config::Directive::HandlePath {
82                pattern,
83                directives: nested_directives,
84            } => {
85                if let Some(remaining_path) = match_pattern(pattern, &modified_path) {
86                    info!("   Matched handle_path: {}", pattern);
87                    return process_directives(nested_directives, req, &remaining_path);
88                }
89            }
90
91            crate::config::Directive::Method {
92                methods,
93                directives: nested_directives,
94            } => {
95                if handle_method(methods, req) {
96                    info!("   Matched method directive");
97                    return process_directives(nested_directives, req, &modified_path);
98                }
99            }
100
101            crate::config::Directive::Redirect { status, url } => {
102                return Ok(handle_redirect(status, url));
103            }
104
105            crate::config::Directive::Respond { status, body } => {
106                return Ok(handle_respond(status, body));
107            }
108
109            crate::config::Directive::ReverseProxy {
110                to,
111                connect_timeout,
112                read_timeout,
113                header_up,
114            } => {
115                return Ok(handle_reverse_proxy(
116                    to,
117                    &modified_path,
118                    *connect_timeout,
119                    *read_timeout,
120                    header_up.clone(),
121                ));
122            }
123        }
124    }
125
126    Err(format!(
127        "No action directive (respond or reverse_proxy) found in configuration for path: {}",
128        current_path
129    ))
130}
131
132/// Process a single request through the proxy
133///
134/// This implementation ALWAYS streams backend responses (nginx-style):
135/// - No buffering of response body
136/// - Direct streaming from backend to client
137/// - Works for both SSE and regular HTTP
138/// - Optimal performance and memory usage
139///
140/// For direct responses (Respond directive) and errors, buffering is used
141/// since these are small and generated by the proxy itself
142pub async fn proxy(
143    mut req: Request<Incoming>,
144    client: Client<HttpsConnector<HttpConnector>, Incoming>,
145    config: Arc<Config>,
146    remote_addr: std::net::SocketAddr,
147    is_tls: bool,
148) -> Result<Response<ResponseBody>, Error> {
149    // Generate or reuse request ID
150    let initial_request_id = ensure_request_id(&mut req);
151
152    #[cfg(feature = "logging")]
153    let span = info_span!("request", req_id = %initial_request_id);
154
155    let future = async move {
156        let path = req.uri().path().to_string();
157        let host = req
158            .headers()
159            .get(hyper::header::HOST)
160            .and_then(|h| h.to_str().ok())
161            .unwrap_or("localhost");
162
163        #[cfg(feature = "metrics")]
164        let mut metrics_guard =
165            crate::metrics::MetricsGuard::new(req.method().to_string(), host.to_string());
166
167        #[cfg(feature = "logging")]
168        let mut log_guard = AccessLogGuard::new(
169            initial_request_id.clone(),
170            remote_addr,
171            req.method().to_string(),
172            path.clone(),
173            host.to_string(),
174        );
175
176        // Find site configuration by host
177        let site_config = match find_site(&config, host, is_tls) {
178            Some(config) => config,
179            None => {
180                error!("No configuration found for host: {}", host);
181                let (response, _body_len) = error_response_with_id(
182                    StatusCode::NOT_FOUND,
183                    &format!("No configuration found for host: {}", host),
184                    &initial_request_id,
185                );
186                #[cfg(feature = "logging")]
187                {
188                    log_guard.set_bytes_sent(_body_len);
189                    log_guard.finish(404);
190                }
191                #[cfg(feature = "metrics")]
192                metrics_guard.record(404);
193                return Ok(response);
194            }
195        };
196
197        // Process directives in correct order
198        let action_result = match process_directives(&site_config.directives, &mut req, &path) {
199            Ok(result) => result,
200            Err(e) => {
201                error!("Directive processing error: {}", e);
202                let final_id = final_request_id(&req, &initial_request_id);
203                #[cfg(feature = "logging")]
204                {
205                    log_guard.set_request_id(final_id.clone());
206                    tracing::Span::current().record("req_id", final_id.as_str());
207                }
208                let (response, _body_len) =
209                    error_response_with_id(StatusCode::INTERNAL_SERVER_ERROR, &e, &final_id);
210                #[cfg(feature = "logging")]
211                {
212                    log_guard.set_bytes_sent(_body_len);
213                    log_guard.finish(500);
214                }
215                #[cfg(feature = "metrics")]
216                metrics_guard.record(500);
217                return Ok(response);
218            }
219        };
220
221        // Read final request ID (directive may have overwritten X-Request-ID)
222        let request_id = final_request_id(&req, &initial_request_id);
223        #[cfg(feature = "logging")]
224        {
225            log_guard.set_request_id(request_id.clone());
226            // Update the tracing span with the final req_id
227            tracing::Span::current().record("req_id", request_id.as_str());
228        }
229
230        // Execute action
231        match action_result {
232            ActionResult::Redirect { status, url } => {
233                let status_code = StatusCode::from_u16(status).unwrap_or(StatusCode::FOUND);
234
235                let boxed: ResponseBody = Full::new(Bytes::new())
236                    .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
237                    .boxed();
238                let response = Response::builder()
239                    .status(status_code)
240                    .header("Location", &url)
241                    .header("X-Request-ID", &request_id)
242                    .body(boxed)?;
243                #[cfg(feature = "logging")]
244                {
245                    log_guard.set_bytes_sent(0);
246                    log_guard.finish(status_code.as_u16());
247                }
248                #[cfg(feature = "metrics")]
249                metrics_guard.record(status_code.as_u16());
250                Ok(response)
251            }
252            ActionResult::Respond { status, body } => {
253                let status_code = StatusCode::from_u16(status).unwrap_or(StatusCode::OK);
254                let _body_len = body.len();
255
256                let boxed: ResponseBody = Full::new(Bytes::from(body))
257                    .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
258                    .boxed();
259                let response = Response::builder()
260                    .status(status_code)
261                    .header("X-Request-ID", &request_id)
262                    .body(boxed)?;
263                #[cfg(feature = "logging")]
264                {
265                    log_guard.set_bytes_sent(_body_len);
266                    log_guard.finish(status_code.as_u16());
267                }
268                #[cfg(feature = "metrics")]
269                metrics_guard.record(status_code.as_u16());
270                Ok(response)
271            }
272            ActionResult::ReverseProxy {
273                backend_url,
274                path_to_send,
275                connect_timeout: _,
276                read_timeout,
277                header_up,
278            } => {
279                // Add protocol if missing
280                let backend_with_proto =
281                    if backend_url.starts_with("http://") || backend_url.starts_with("https://") {
282                        backend_url
283                    } else {
284                        format!("http://{}", backend_url)
285                    };
286
287                // Use Uri::from_parts() instead of format!() + parse() - faster!
288                let mut parts = backend_with_proto.parse::<Uri>()?.into_parts();
289                parts.path_and_query = Some(path_to_send.parse()?);
290                let new_uri = Uri::from_parts(parts)?;
291
292                // Capture the original request URI (path + query) before we overwrite it —
293                // needed for the {request.uri} placeholder in header_up.
294                let original_request_uri = req
295                    .uri()
296                    .path_and_query()
297                    .map(|pq| pq.as_str().to_string())
298                    .unwrap_or_default();
299
300                *req.uri_mut() = new_uri.clone();
301
302                // upstream_host is the authority of the backend URL, used for {upstream_host}.
303                let upstream_host = new_uri
304                    .authority()
305                    .map(|a| a.as_str().to_string())
306                    .unwrap_or_default();
307
308                // remote_ip: prefer X-Forwarded-For / X-Real-IP, fall back to the socket peer.
309                let remote_ip = crate::auth::headers::extract_remote_ip(&req)
310                    .unwrap_or_else(|| remote_addr.ip().to_string());
311
312                // Save original host for X-Forwarded headers
313                let original_host_header = req.headers().get(hyper::header::HOST).cloned();
314
315                // Update Host header for backend
316                req.headers_mut().remove(hyper::header::HOST);
317                if let Some(authority) = new_uri.authority() {
318                    if let Ok(host_value) = authority.as_str().parse::<hyper::header::HeaderValue>()
319                    {
320                        req.headers_mut().insert(hyper::header::HOST, host_value);
321                    }
322                }
323
324                // Add X-Forwarded-* headers for backend visibility
325                if let Some(host_value) = original_host_header.clone() {
326                    req.headers_mut().insert("X-Forwarded-Host", host_value);
327                }
328
329                // X-Forwarded-Proto: based on whether the connection is TLS
330                req.headers_mut().insert(
331                    "X-Forwarded-Proto",
332                    hyper::header::HeaderValue::from_static(if is_tls { "https" } else { "http" }),
333                );
334
335                // X-Forwarded-For: real client IP
336                if let Ok(ip_value) =
337                    hyper::header::HeaderValue::from_str(&remote_addr.ip().to_string())
338                {
339                    req.headers_mut().insert("X-Forwarded-For", ip_value);
340                }
341
342                // Remove hop-by-hop headers
343                req.headers_mut().remove(header::CONNECTION);
344                req.headers_mut().remove("accept-encoding");
345
346                apply_header_up(
347                    &header_up,
348                    &mut req,
349                    &upstream_host,
350                    &original_request_uri,
351                    &remote_ip,
352                );
353
354                // Forward request to backend with configurable timeout (default 30s)
355                let backend_timeout = read_timeout.unwrap_or(30);
356                match timeout(Duration::from_secs(backend_timeout), client.request(req)).await {
357                    Ok(Ok(response)) => {
358                        let status = response.status();
359                        let headers = response.headers().clone();
360
361                        // Stream response body directly (no buffering)
362                        let mut builder = Response::builder().status(status);
363
364                        // Copy headers, filtering hop-by-hop
365                        for (name, value) in headers.iter() {
366                            if !is_hop_header(name) && name != header::CONTENT_LENGTH {
367                                builder = builder.header(name, value);
368                            }
369                        }
370
371                        let (_, incoming_body) = response.into_parts();
372                        let boxed: ResponseBody = incoming_body
373                            .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
374                            .boxed();
375
376                        let response = builder.header("X-Request-ID", &request_id).body(boxed)?;
377                        #[cfg(feature = "logging")]
378                        log_guard.finish(status.as_u16());
379                        #[cfg(feature = "metrics")]
380                        metrics_guard.record(status.as_u16());
381                        Ok(response)
382                    }
383                    Ok(Err(e)) => {
384                        error!("Backend connection failed: {:?}", e);
385                        if e.is_connect() {
386                            error!("   Reason: Connection refused - backend unavailable");
387                        } else {
388                            error!("   Reason: Other connection error");
389                        }
390
391                        let (response, _body_len) = error_response_with_id(
392                            StatusCode::BAD_GATEWAY,
393                            "Backend service unavailable",
394                            &request_id,
395                        );
396                        #[cfg(feature = "logging")]
397                        {
398                            log_guard.set_bytes_sent(_body_len);
399                            log_guard.finish(502);
400                        }
401                        #[cfg(feature = "metrics")]
402                        metrics_guard.record(502);
403                        Ok(response)
404                    }
405                    Err(_) => {
406                        error!(
407                            "Backend request timed out after {} seconds",
408                            backend_timeout
409                        );
410
411                        let (response, _body_len) = error_response_with_id(
412                            StatusCode::GATEWAY_TIMEOUT,
413                            "Backend request timed out",
414                            &request_id,
415                        );
416                        #[cfg(feature = "logging")]
417                        {
418                            log_guard.set_bytes_sent(_body_len);
419                            log_guard.finish(504);
420                        }
421                        #[cfg(feature = "metrics")]
422                        metrics_guard.record(504);
423                        Ok(response)
424                    }
425                }
426            }
427        }
428    };
429
430    #[cfg(feature = "logging")]
431    let future = future.instrument(span);
432
433    future.await
434}
435
436/// Creates HTTP response with error and X-Request-ID header
437///
438/// Returns both the response and the body length (for access logging).
439fn error_response_with_id(
440    status: StatusCode,
441    message: &str,
442    request_id: &str,
443) -> (Response<ResponseBody>, usize) {
444    let body = format!(
445        r#"<!DOCTYPE html>
446        <html>
447        <head><title>{} {}</title></head>
448        <body>
449        <h1>{} {}</h1>
450        <p>{}</p>
451        <hr>
452        <p><em>Rust Proxy Server</em></p>
453        </body>
454        </html>"#,
455        status.as_u16(),
456        status.canonical_reason().unwrap_or("Error"),
457        status.as_u16(),
458        status.canonical_reason().unwrap_or("Error"),
459        message
460    );
461
462    let body_len = body.len();
463    let full = Full::new(Bytes::from(body));
464    let boxed: ResponseBody = full
465        .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
466        .boxed();
467
468    let mut builder = Response::builder()
469        .status(status)
470        .header("Content-Type", "text/html; charset=utf-8");
471
472    if let Ok(val) = hyper::header::HeaderValue::from_str(request_id) {
473        builder = builder.header("X-Request-ID", val);
474    }
475
476    let response = builder.body(boxed).unwrap_or_else(|_| {
477        Response::new(
478            Full::new(Bytes::from("Internal Server Error"))
479                .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
480                .boxed(),
481        )
482    });
483
484    (response, body_len)
485}
486
/// Matches `path` against a `handle_path` pattern and returns the remaining path.
///
/// Two pattern forms are supported:
/// - `"<prefix>/*"` matches `<prefix>` itself and anything below it
///   (`<prefix>/...`). The returned remainder is the path with `<prefix>` removed;
///   it always starts with `/`, and is `"/"` when nothing is left.
/// - Any other pattern must equal `path` exactly; the remainder is then `"/"`.
///
/// The prefix is matched on a segment boundary, so `"/api/*"` matches `/api` and
/// `/api/users` but not `/apiary`.
///
/// Returns `None` when the pattern does not match.
pub fn match_pattern(pattern: &str, path: &str) -> Option<String> {
    match pattern.strip_suffix("/*") {
        Some(prefix) => match path.strip_prefix(prefix) {
            // Path is exactly the prefix: normalise the empty remainder to the root.
            Some("") => Some("/".to_string()),
            // Only accept the match when the prefix ends on a segment boundary.
            Some(rest) if rest.starts_with('/') => Some(rest.to_string()),
            _ => None,
        },
        None if pattern == path => Some("/".to_string()),
        None => None,
    }
}
503
504/// Find a site configuration matching the given Host header value.
505///
506/// Browsers on default ports omit the port from the Host header:
507/// - HTTPS on 443 → `Host: example.com` (no `:443`)
508/// - HTTP on 80 → `Host: example.com` (no `:80`)
509///
510/// But config keys include the port: `"example.com:443"`, `"example.com:80"`.
511///
512/// Look up a site by Host header value.
513///
514/// Tries, in order:
515/// 1. Exact match on the raw host string.
516/// 2. If host has no port → append default port (443 for TLS, 80 for HTTP) and retry.
517/// 3. For TLS, match by SNI hostname if exactly one site matches.
518/// 4. If host has a port → strip it and try the bare hostname.
519///
520/// # Limitations
521///
522/// For non-default TLS ports (e.g., 8443), browsers include the port in `Host`
523/// (`Host: example.com:8443`), so exact match works. A client sending `Host: example.com`
524/// without a port to a TLS listener on :8443 gets 404 — that violates normal HTTP usage.
525pub fn find_site<'a>(config: &'a Config, host: &str, is_tls: bool) -> Option<&'a SiteConfig> {
526    // 1. Exact match
527    if let Some(site) = config.sites.get(host) {
528        return Some(site);
529    }
530
531    // Determine if host already contains a port
532    // IPv6: [::1]:8080 — port is after the last ']' + ':'
533    // IPv4/hostname: example.com:443
534    let has_port = if host.starts_with('[') {
535        // IPv6 — look for port after ']'
536        if let Some(bracket_end) = host.find(']') {
537            host[bracket_end..].contains(':')
538        } else {
539            false
540        }
541    } else {
542        host.contains(':')
543    };
544
545    if !has_port {
546        // 2. Host has no port — try appending the default port
547        let default_port = if is_tls { 443 } else { 80 };
548        let candidate = format!("{}:{}", host, default_port);
549        if let Some(site) = config.sites.get(&candidate) {
550            return Some(site);
551        }
552
553        // 3. TLS on a non-standard port — match by SNI hostname if unambiguous
554        if is_tls {
555            let mut matches = config.sites.values().filter(|s| {
556                s.tls.is_some() && extract_hostname(&s.address).eq_ignore_ascii_case(host)
557            });
558            if let Some(site) = matches.next() {
559                if matches.next().is_none() {
560                    return Some(site);
561                }
562            }
563        }
564    } else {
565        // 4. Host has a port — try just the hostname (strip port)
566        let hostname = if host.starts_with('[') {
567            // IPv6 [::1]:port → ::1
568            let end = host.find(']').unwrap_or(host.len());
569            &host[1..end]
570        } else {
571            // example.com:443 → example.com
572            host.rsplit_once(':').map(|(name, _)| name).unwrap_or(host)
573        };
574        if let Some(site) = config.sites.get(hostname) {
575            return Some(site);
576        }
577    }
578
579    None
580}
581
#[cfg(test)]
mod find_site_tests {
    use super::*;
    use std::collections::HashMap;

    /// Builds a `Config` holding one directive-less site per `(address, has_tls)`
    /// pair, keyed by its address. TLS sites get placeholder certificate paths.
    fn make_config(sites: Vec<(&str, bool)>) -> Config {
        let map: HashMap<_, _> = sites
            .into_iter()
            .map(|(addr, has_tls)| {
                let tls = has_tls.then(|| crate::config::TlsConfig {
                    cert_path: "/fake/cert.pem".to_string(),
                    key_path: "/fake/key.pem".to_string(),
                });
                let site = crate::config::SiteConfig {
                    address: addr.to_string(),
                    directives: vec![],
                    tls,
                };
                (addr.to_string(), site)
            })
            .collect();
        Config { sites: map }
    }

    /// Host including the port matches the config key verbatim.
    #[test]
    fn test_exact_match() {
        let config = make_config(vec![("example.com:443", true)]);
        let found = find_site(&config, "example.com:443", true);
        assert!(found.is_some());
    }

    /// Browser sends `Host: example.com` (no `:443`) on HTTPS.
    #[test]
    fn test_tls_host_without_port_finds_443() {
        let config = make_config(vec![("example.com:443", true)]);
        let found = find_site(&config, "example.com", true);
        assert!(
            found.is_some(),
            "Should find example.com:443 when Host has no port and is_tls=true"
        );
    }

    /// Browser sends `Host: example.com` (no `:80`) on HTTP.
    #[test]
    fn test_http_host_without_port_finds_80() {
        let config = make_config(vec![("example.com:80", false)]);
        let found = find_site(&config, "example.com", false);
        assert!(
            found.is_some(),
            "Should find example.com:80 when Host has no port and is_tls=false"
        );
    }

    /// `Host: example.com` over TLS must NOT resolve to the `:80` site.
    #[test]
    fn test_tls_host_without_port_no_match_on_80() {
        let config = make_config(vec![("example.com:80", false)]);
        let found = find_site(&config, "example.com", true);
        assert!(
            found.is_none(),
            "TLS on port 443 should not find :80 site"
        );
    }

    /// Config key is the bare `example.com`; Host carries `:8080`.
    #[test]
    fn test_host_with_port_strips_port_fallback() {
        let config = make_config(vec![("example.com", false)]);
        let found = find_site(&config, "example.com:8080", false);
        assert!(
            found.is_some(),
            "Should strip port from Host and find config without port"
        );
    }

    /// A single TLS site on a non-default port is found by hostname alone.
    #[test]
    fn test_tls_host_without_port_finds_non_standard_port() {
        let config = make_config(vec![("alpha.local:8443", true)]);
        let found = find_site(&config, "alpha.local", true);
        assert!(
            found.is_some(),
            "Should find alpha.local:8443 when Host has no port on TLS"
        );
    }

    /// An unrelated hostname resolves to nothing.
    #[test]
    fn test_no_match() {
        let config = make_config(vec![("other.com:443", true)]);
        let found = find_site(&config, "example.com", true);
        assert!(found.is_none());
    }
}
669}