Skip to main content

rust_web_server/proxy/
mod.rs

1//! Reverse proxy middleware with round-robin load balancing.
2//!
3//! `ReverseProxy` implements [`Middleware`] — wrap any application with it and
4//! all matching requests are forwarded to one of the configured backends over
5//! plain HTTP/1.1.  Failed backends are skipped and the next one is tried
6//! before returning `502 Bad Gateway`.
7//!
8//! # Example
9//!
10//! ```rust,no_run
11//! use rust_web_server::app::App;
12//! use rust_web_server::core::New;
13//! use rust_web_server::proxy::{LoadBalancing, ReverseProxy};
14//!
15//! // Proxy every request across two backends in round-robin order.
16//! let app = App::new()
17//!     .wrap(ReverseProxy::new(["http://backend-1:8080", "http://backend-2:8080"])
18//!         .strategy(LoadBalancing::RoundRobin));
19//!
20//! // Only proxy /api/* requests; everything else is handled locally.
21//! let app2 = App::new()
22//!     .wrap(ReverseProxy::new(["http://api-service:3000"])
23//!         .path_prefix("/api"));
24//! ```
25
26pub mod pool;
27
28#[cfg(test)]
29mod tests;
30
31use std::io::{Read, Write};
32use std::net::{TcpStream, ToSocketAddrs};
33use std::sync::atomic::{AtomicUsize, Ordering};
34use std::sync::Arc;
35use std::time::Duration;
36
37pub use pool::ConnPool;
38
39use crate::application::Application;
40use crate::core::New;
41use crate::middleware::Middleware;
42use crate::mime_type::MimeType;
43use crate::range::Range;
44use crate::request::Request;
45use crate::response::{Response, STATUS_CODE_REASON_PHRASE};
46use crate::server::ConnectionInfo;
47
/// Lowercase names of hop-by-hop headers that must not be forwarded to a
/// backend (RFC 7230 §6.1).
///
/// Callers compare against the lowercased header name, so every entry here
/// must stay lowercase.
const HOP_BY_HOP: &[&str] = &[
    "connection",
    "keep-alive",
    "proxy-authenticate",
    "proxy-authorization",
    "te",
    // `Trailer` is the actual header field name (RFC 7230 §4.4).  It must be
    // dropped together with `Transfer-Encoding`, because the forwarded
    // request is re-framed with `Content-Length` and carries no trailers.
    "trailer",
    // Historical spelling from RFC 2616 §13.5.1; kept so behaviour for
    // clients that send it is unchanged.
    "trailers",
    "transfer-encoding",
    "upgrade",
];
59
/// Load balancing strategy used by [`ReverseProxy`].
///
/// The type is a plain, field-less enum, so it is cheap to copy, compare and
/// print in diagnostics.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LoadBalancing {
    /// Distribute requests across backends in a cyclic order.
    RoundRobin,
}
65
66/// Reverse proxy middleware.
67///
68/// Forwards incoming requests to one of the configured backends over HTTP/1.1.
69/// On connection failure the next backend in the list is tried; when all
70/// backends have failed the middleware returns `502 Bad Gateway`.
71///
72/// Hop-by-hop headers are stripped before forwarding.  `X-Forwarded-For` and
73/// `Via` are added to every forwarded request.
74///
75/// Idle connections are pooled and reused across requests (up to
76/// [`ConnPool::new_default`] limits: 8 idle per backend, 60-second timeout).
77/// This eliminates per-request TCP handshake overhead and ephemeral-port
78/// exhaustion.  Use [`ReverseProxy::with_pool`] to share a pool across
79/// multiple proxy instances or to tune pool parameters.
80pub struct ReverseProxy {
81    backends: Vec<Backend>,
82    path_prefix: Option<String>,
83    connect_timeout: Duration,
84    read_timeout: Duration,
85    counter: AtomicUsize,
86    pool: Arc<ConnPool>,
87}
88
89impl ReverseProxy {
90    /// Create a proxy that distributes requests across `backends` in
91    /// round-robin order.  Each entry must be `"http://host:port"` or
92    /// `"host:port"` (port defaults to 80).
93    pub fn new<I, S>(backends: I) -> Self
94    where
95        I: IntoIterator<Item = S>,
96        S: AsRef<str>,
97    {
98        Self {
99            backends: backends
100                .into_iter()
101                .filter_map(|u| Backend::parse(u.as_ref()))
102                .collect(),
103            path_prefix: None,
104            connect_timeout: Duration::from_secs(5),
105            read_timeout: Duration::from_secs(30),
106            counter: AtomicUsize::new(0),
107            pool: Arc::new(ConnPool::new_default()),
108        }
109    }
110
111    /// Only proxy requests whose URI starts with `prefix`.
112    ///
113    /// Other requests are passed through to the next layer in the middleware
114    /// chain (or the inner application).
115    pub fn path_prefix(mut self, prefix: impl Into<String>) -> Self {
116        self.path_prefix = Some(prefix.into());
117        self
118    }
119
120    /// Override the load balancing strategy (currently only `RoundRobin`).
121    pub fn strategy(self, _strategy: LoadBalancing) -> Self {
122        self
123    }
124
125    /// Override the TCP connect timeout (default: 5 000 ms).
126    pub fn connect_timeout_ms(mut self, ms: u64) -> Self {
127        self.connect_timeout = Duration::from_millis(ms);
128        self
129    }
130
131    /// Override the response read timeout (default: 30 000 ms).
132    pub fn read_timeout_ms(mut self, ms: u64) -> Self {
133        self.read_timeout = Duration::from_millis(ms);
134        self
135    }
136
137    /// Attach a shared connection pool.
138    ///
139    /// Useful for sharing one pool across multiple `ReverseProxy` instances
140    /// or for tuning pool parameters (capacity, idle timeout).
141    pub fn with_pool(mut self, pool: Arc<ConnPool>) -> Self {
142        self.pool = pool;
143        self
144    }
145
146    /// Set the maximum number of idle connections per backend (default: 8).
147    pub fn max_idle_conns(mut self, n: usize) -> Self {
148        self.pool = Arc::new(ConnPool::new(n, Duration::from_secs(60)));
149        self
150    }
151
152    fn proxy(&self, request: &Request, connection: &ConnectionInfo) -> Result<Response, String> {
153        if self.backends.is_empty() {
154            return Err("no backends configured".to_string());
155        }
156        let n = self.backends.len();
157        let start = self.counter.fetch_add(1, Ordering::Relaxed);
158        for attempt in 0..n {
159            let idx = (start + attempt) % n;
160            match self.try_backend(request, connection, &self.backends[idx]) {
161                Ok(resp) => return Ok(resp),
162                Err(_) if attempt + 1 < n => continue,
163                Err(e) => return Err(e),
164            }
165        }
166        Err("all backends failed".to_string())
167    }
168
169    fn try_backend(
170        &self,
171        request: &Request,
172        connection: &ConnectionInfo,
173        backend: &Backend,
174    ) -> Result<Response, String> {
175        let key = format!("{}:{}", backend.host, backend.port);
176
177        // Try a pooled connection first; fall back to a fresh one.
178        let stream = if let Some(pooled) = self.pool.acquire(&key) {
179            pooled
180        } else {
181            let addr_str = key.as_str();
182            let sock_addr = addr_str
183                .to_socket_addrs()
184                .map_err(|e| format!("DNS lookup for {} failed: {}", addr_str, e))?
185                .next()
186                .ok_or_else(|| format!("no address resolved for {}", addr_str))?;
187            TcpStream::connect_timeout(&sock_addr, self.connect_timeout)
188                .map_err(|e| format!("connect to {} failed: {}", addr_str, e))?
189        };
190
191        stream.set_read_timeout(Some(self.read_timeout)).map_err(|e| e.to_string())?;
192        stream.set_write_timeout(Some(Duration::from_secs(10))).map_err(|e| e.to_string())?;
193
194        // keep_alive = true: send Connection: keep-alive so the server holds
195        // the connection open after responding.
196        let req_bytes = build_request(request, &backend.host, &connection.client.ip, true);
197        let mut stream = stream;
198        stream.write_all(&req_bytes).map_err(|e| format!("write to backend failed: {}", e))?;
199
200        let mut tmp = [0u8; 4096];
201        let (header_bytes, body_prefix) = read_headers_only(&mut stream, &mut tmp)?;
202        let header_lower =
203            std::str::from_utf8(&header_bytes).unwrap_or("").to_ascii_lowercase();
204
205        if should_stream_response(&header_lower) {
206            // Streaming path — pipe bytes straight to the client.
207            // The connection cannot be reused while the body is in flight.
208            let mut resp = parse_status_and_headers(&header_bytes)?;
209            resp.stream_pipe =
210                Some(Box::new(ConcatReader::new(body_prefix, stream)));
211            Ok(resp)
212        } else {
213            // Buffered path — read the full body, then optionally return the
214            // connection to the pool.
215            let (resp_bytes, reusable) =
216                read_response_from_partial(&mut stream, header_bytes, body_prefix, &mut tmp)?;
217            if reusable {
218                self.pool.release(&key, stream);
219            }
220            Response::parse(&resp_bytes)
221        }
222    }
223}
224
225impl Middleware for ReverseProxy {
226    fn handle(
227        &self,
228        request: &Request,
229        connection: &ConnectionInfo,
230        next: &dyn Application,
231    ) -> Result<Response, String> {
232        if let Some(prefix) = &self.path_prefix {
233            if !request.request_uri.starts_with(prefix.as_str()) {
234                return next.execute(request, connection);
235            }
236        }
237        match self.proxy(request, connection) {
238            Ok(resp) => Ok(resp),
239            Err(_) => Ok(bad_gateway()),
240        }
241    }
242}
243
244// ── helpers ───────────────────────────────────────────────────────────────────
245
246pub(crate) fn build_request(
247    request: &Request,
248    backend_host: &str,
249    client_ip: &str,
250    keep_alive: bool,
251) -> Vec<u8> {
252    let mut out: Vec<u8> = Vec::new();
253    let _ = write!(
254        out,
255        "{} {} HTTP/1.1\r\nHost: {}\r\n",
256        request.method, request.request_uri, backend_host
257    );
258    for h in &request.headers {
259        let lower = h.name.to_lowercase();
260        if HOP_BY_HOP.contains(&lower.as_str()) || lower == "host" {
261            continue;
262        }
263        let _ = write!(out, "{}: {}\r\n", h.name, h.value);
264    }
265    let _ = write!(out, "X-Forwarded-For: {}\r\n", client_ip);
266    let _ = write!(out, "Via: 1.1 rws\r\n");
267    if keep_alive {
268        let _ = write!(out, "Connection: keep-alive\r\n");
269    } else {
270        let _ = write!(out, "Connection: close\r\n");
271    }
272    if !request.body.is_empty() {
273        let _ = write!(out, "Content-Length: {}\r\n", request.body.len());
274    }
275    let _ = write!(out, "\r\n");
276    out.extend_from_slice(&request.body);
277    out
278}
279
/// Decode HTTP/1.1 chunked transfer-encoding from `stream`.
///
/// `buf[header_end..]` may already contain some body bytes that arrived in
/// the same read as the headers; decoding starts with those and pulls more
/// from `stream` (through the scratch buffer `tmp`) as needed.
///
/// Returns the fully decoded body.  Chunk extensions are ignored and trailer
/// fields after the last chunk are consumed and discarded, so no stray bytes
/// are left on the connection.
///
/// # Errors
///
/// Fails on read errors, on EOF before a chunk is complete, and on a chunk
/// size line that is not valid hexadecimal or is too large to represent.
fn decode_chunked(
    stream: &mut TcpStream,
    buf: &[u8],
    header_end: usize,
    tmp: &mut [u8],
) -> Result<Vec<u8>, String> {
    // Seed `raw` with any body bytes already buffered alongside the headers.
    let mut raw: Vec<u8> = buf[header_end..].to_vec();
    let mut decoded: Vec<u8> = Vec::new();

    loop {
        // Wait until we have at least one complete chunk-size line (ends with \r\n).
        let crlf = loop {
            if let Some(p) = raw.windows(2).position(|w| w == b"\r\n") {
                break p;
            }
            let n = stream.read(tmp).map_err(|e| e.to_string())?;
            if n == 0 {
                return Err("chunked: premature EOF reading chunk size".to_string());
            }
            raw.extend_from_slice(&tmp[..n]);
        };

        // Chunk size is hex, optionally followed by chunk-extensions (";…").
        let size_line = std::str::from_utf8(&raw[..crlf])
            .map_err(|_| "chunked: non-UTF-8 chunk size line".to_string())?;
        let size_str = size_line.split(';').next().unwrap_or("").trim();
        let chunk_size = usize::from_str_radix(size_str, 16)
            .map_err(|_| format!("chunked: invalid chunk size '{}'", size_str))?;
        // Chunk data is followed by CRLF.  A hostile size near `usize::MAX`
        // must produce an error rather than an arithmetic overflow.
        let framed_len = chunk_size
            .checked_add(2)
            .ok_or_else(|| format!("chunked: chunk size '{}' is too large", size_str))?;
        raw.drain(..crlf + 2); // consume "<size>\r\n"

        if chunk_size == 0 {
            // Last chunk.  What follows is the trailer section: zero or more
            // header lines terminated by an empty line.  Consume all of it;
            // EOF is tolerated because the body itself is already complete.
            loop {
                match raw.windows(2).position(|w| w == b"\r\n") {
                    Some(0) => break,
                    Some(line_end) => {
                        raw.drain(..line_end + 2); // discard one trailer field
                    }
                    None => {
                        let n = stream.read(tmp).map_err(|e| e.to_string())?;
                        if n == 0 {
                            break;
                        }
                        raw.extend_from_slice(&tmp[..n]);
                    }
                }
            }
            break;
        }

        // Read chunk data + trailing CRLF.
        while raw.len() < framed_len {
            let n = stream.read(tmp).map_err(|e| e.to_string())?;
            if n == 0 {
                return Err("chunked: premature EOF reading chunk body".to_string());
            }
            raw.extend_from_slice(&tmp[..n]);
        }
        decoded.extend_from_slice(&raw[..chunk_size]);
        raw.drain(..framed_len); // consume "<data>\r\n"
    }

    Ok(decoded)
}
341
/// Rewrite `buf` in-place: strip `Transfer-Encoding`, add `Content-Length`,
/// replace the old (undecoded) body with `decoded`.
///
/// `buf[..header_end]` is the original header block.  Header lines are
/// copied byte-for-byte, so header values containing non-UTF-8 bytes
/// survive.  Any `Content-Length` already present is dropped as well, so the
/// rewritten message carries exactly one length header, matching `decoded`.
fn rewrite_as_content_length(buf: &mut Vec<u8>, header_end: usize, decoded: &[u8]) {
    // True when `line` is a header field whose name equals `name`
    // (ASCII case-insensitive) followed directly by ':'.
    let names_header = |line: &[u8], name: &[u8]| -> bool {
        line.len() > name.len()
            && line[name.len()] == b':'
            && line[..name.len()].eq_ignore_ascii_case(name)
    };

    let head: Vec<u8> = buf[..header_end.min(buf.len())].to_vec();
    buf.clear();
    for raw_line in head.split(|&b| b == b'\n') {
        let line = raw_line.strip_suffix(b"\r").unwrap_or(raw_line);
        if line.is_empty()
            || names_header(line, b"transfer-encoding")
            || names_header(line, b"content-length")
        {
            continue;
        }
        buf.extend_from_slice(line);
        buf.extend_from_slice(b"\r\n");
    }
    // Writing into a `Vec<u8>` cannot fail.
    let _ = write!(buf, "Content-Length: {}\r\n\r\n", decoded.len());
    buf.extend_from_slice(decoded);
}
357
/// Read a whole response from a backend socket that will not be reused.
///
/// Thin `TcpStream` entry point over [`read_response_from`], for callers
/// that send `Connection: close` (e.g. `proxy_http1`, `proxy_https1`).
pub(crate) fn read_response(stream: &mut TcpStream) -> Result<Vec<u8>, String> {
    read_response_from(stream)
}

/// Read one raw HTTP response (headers and body) from any reader.
///
/// Reading proceeds in two phases:
///
/// 1. Collect bytes until the `\r\n\r\n` header terminator is seen.  If the
///    reader ends first, whatever was received is returned as-is; an
///    entirely empty response is an error.
/// 2. Read the body: until at least `Content-Length` body bytes are buffered
///    when that header parses as a number, otherwise until EOF.  EOF before
///    the declared length is tolerated and yields the partial data.
///
/// The returned bytes are unmodified; no transfer-decoding is performed.
pub(crate) fn read_response_from<R: Read>(stream: &mut R) -> Result<Vec<u8>, String> {
    let mut response: Vec<u8> = Vec::with_capacity(8192);
    let mut chunk = [0u8; 4096];

    // Phase 1: locate the end of the header block.
    let body_start = loop {
        match stream.read(&mut chunk).map_err(|e| e.to_string())? {
            0 if response.is_empty() => {
                return Err("backend closed connection without sending a response".to_string());
            }
            0 => return Ok(response),
            got => response.extend_from_slice(&chunk[..got]),
        }
        if let Some(at) = response.windows(4).position(|w| w == b"\r\n\r\n") {
            break at + 4;
        }
    };

    // First `Content-Length` header whose value parses as a number, if any.
    let mut declared_len: Option<usize> = None;
    for line in std::str::from_utf8(&response[..body_start]).unwrap_or("").lines() {
        if !line.to_lowercase().starts_with("content-length:") {
            continue;
        }
        let parsed = line
            .split_once(':')
            .and_then(|(_, value)| value.trim().parse::<usize>().ok());
        if parsed.is_some() {
            declared_len = parsed;
            break;
        }
    }

    // Phase 2: a declared length bounds the read; without one, read to EOF.
    let wanted_total = declared_len.map(|len| body_start + len);
    loop {
        if let Some(total) = wanted_total {
            if response.len() >= total {
                break;
            }
        }
        let got = stream.read(&mut chunk).map_err(|e| e.to_string())?;
        if got == 0 {
            break;
        }
        response.extend_from_slice(&chunk[..got]);
    }

    Ok(response)
}
414
415/// Forward a single HTTP/1.1 request to `host:port` and return the response.
416///
417/// This is the shared low-level building block used by [`crate::canary`] and
418/// [`crate::ingress`] so they don't have to duplicate the TCP + request/response
419/// marshalling code.
420pub(crate) fn proxy_http1(
421    request: &Request,
422    client_ip: &str,
423    host: &str,
424    port: u16,
425    connect_timeout: Duration,
426    read_timeout: Duration,
427) -> Result<Response, String> {
428    use std::net::ToSocketAddrs;
429    let addr_str = format!("{}:{}", host, port);
430    let sock_addr = addr_str
431        .to_socket_addrs()
432        .map_err(|e| format!("DNS lookup for {} failed: {}", addr_str, e))?
433        .next()
434        .ok_or_else(|| format!("no address resolved for {}", addr_str))?;
435    let stream = TcpStream::connect_timeout(&sock_addr, connect_timeout)
436        .map_err(|e| format!("connect to {} failed: {}", addr_str, e))?;
437    stream.set_read_timeout(Some(read_timeout)).map_err(|e| e.to_string())?;
438    stream.set_write_timeout(Some(Duration::from_secs(10))).map_err(|e| e.to_string())?;
439    let req_bytes = build_request(request, host, client_ip, false);
440    let mut stream = stream;
441    stream.write_all(&req_bytes).map_err(|e| format!("write to backend failed: {}", e))?;
442    let resp_bytes = read_response(&mut stream)?;
443    Response::parse(&resp_bytes)
444}
445
/// Forward a single HTTPS/1.1 request to `host:port` over TLS and return the
/// response. Requires the `http-client` or `http2` feature (both bring in
/// `rustls` + `webpki-roots`).
///
/// Every address `host` resolves to is tried in order (each attempt bounded
/// by `connect_timeout`).  The server certificate is verified against the
/// WebPKI root store using `host` as the expected server name.  The request
/// is sent with `Connection: close`.
///
/// # Errors
///
/// Returns a descriptive message when resolution, connecting, TLS setup,
/// writing, reading or parsing the response fails.
#[cfg(any(feature = "http-client", feature = "http2"))]
pub(crate) fn proxy_https1(
    request: &Request,
    client_ip: &str,
    host: &str,
    port: u16,
    connect_timeout: Duration,
    read_timeout: Duration,
) -> Result<Response, String> {
    use rustls::pki_types::ServerName;
    use rustls::ClientConfig;
    use std::net::ToSocketAddrs;
    use std::sync::Arc;

    let addr_str = format!("{}:{}", host, port);
    let candidates = addr_str
        .to_socket_addrs()
        .map_err(|e| format!("DNS lookup for {} failed: {}", addr_str, e))?;

    // Connect to the first candidate that accepts; remember the last error
    // so the failure message names the real cause.
    let mut connected: Option<TcpStream> = None;
    let mut last_err: Option<std::io::Error> = None;
    for candidate in candidates {
        match TcpStream::connect_timeout(&candidate, connect_timeout) {
            Ok(stream) => {
                connected = Some(stream);
                break;
            }
            Err(e) => last_err = Some(e),
        }
    }
    let stream = match (connected, last_err) {
        (Some(stream), _) => stream,
        (None, Some(e)) => return Err(format!("connect to {} failed: {}", addr_str, e)),
        (None, None) => return Err(format!("no address resolved for {}", addr_str)),
    };
    stream.set_read_timeout(Some(read_timeout)).map_err(|e| e.to_string())?;
    stream.set_write_timeout(Some(Duration::from_secs(10))).map_err(|e| e.to_string())?;

    let root_store =
        rustls::RootCertStore::from_iter(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
    let config = Arc::new(
        ClientConfig::builder()
            .with_root_certificates(root_store)
            .with_no_client_auth(),
    );
    let server_name = ServerName::try_from(host.to_string())
        .map_err(|e| format!("invalid upstream hostname '{}': {}", host, e))?;
    let conn = rustls::ClientConnection::new(config, server_name).map_err(|e| e.to_string())?;
    let mut tls = rustls::StreamOwned::new(conn, stream);

    let req_bytes = build_request(request, host, client_ip, false);
    tls.write_all(&req_bytes)
        .map_err(|e| format!("write to upstream failed: {}", e))?;

    let resp_bytes = read_response_from(&mut tls)?;
    Response::parse(&resp_bytes)
}
494
495fn bad_gateway() -> Response {
496    let cr = Range::get_content_range(
497        b"502 Bad Gateway".to_vec(),
498        MimeType::TEXT_PLAIN.to_string(),
499    );
500    let mut r = Response::new();
501    r.status_code = *STATUS_CODE_REASON_PHRASE.n502_bad_gateway.status_code;
502    r.reason_phrase = STATUS_CODE_REASON_PHRASE
503        .n502_bad_gateway
504        .reason_phrase
505        .to_string();
506    r.content_range_list = vec![cr];
507    r
508}
509
510// ── Streaming proxy helpers ───────────────────────────────────────────────────
511
/// Size limit, in bytes, above which a response with a known
/// `Content-Length` is streamed to the client instead of being buffered.
const STREAM_THRESHOLD: usize = 1 << 20; // 1 MiB (1 048 576 bytes)
514
/// Reader that yields the bytes of `prefix` first and then continues with
/// `inner`.
///
/// It replays body bytes that were already pulled off the socket together
/// with the HTTP headers, so a consumer sees one continuous body stream.
pub(crate) struct ConcatReader<R: Read + Send> {
    /// Bytes to hand out before touching `inner`.
    prefix: Vec<u8>,
    /// How many bytes of `prefix` have been handed out so far.
    prefix_pos: usize,
    /// Source for everything after the prefix.
    inner: R,
}

impl<R: Read + Send> ConcatReader<R> {
    /// Create a reader positioned at the start of `prefix`.
    fn new(prefix: Vec<u8>, inner: R) -> Self {
        Self {
            prefix,
            prefix_pos: 0,
            inner,
        }
    }
}

impl<R: Read + Send> Read for ConcatReader<R> {
    /// Serve from the unread part of the prefix while any remains; a single
    /// call never mixes prefix bytes with bytes from `inner`.
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        let pending = &self.prefix[self.prefix_pos..];
        if pending.is_empty() {
            return self.inner.read(buf);
        }
        let take = pending.len().min(buf.len());
        buf[..take].copy_from_slice(&pending[..take]);
        self.prefix_pos += take;
        Ok(take)
    }
}
541
/// Read from `stream` until the end of the HTTP response headers
/// (`\r\n\r\n`) has been received.
///
/// `tmp` is the scratch buffer used for each socket read.  Returns
/// `(header_bytes, body_prefix)`: `header_bytes` ends with the terminator
/// and `body_prefix` holds whatever body bytes were received in the same
/// reads (possibly empty).
///
/// # Errors
///
/// Fails on a read error or when the backend closes the connection before
/// the terminator arrives.
fn read_headers_only(stream: &mut TcpStream, tmp: &mut [u8]) -> Result<(Vec<u8>, Vec<u8>), String> {
    let mut head: Vec<u8> = Vec::with_capacity(4096);
    let body_start = loop {
        let got = stream.read(tmp).map_err(|e| e.to_string())?;
        if got == 0 {
            return Err("backend closed connection before headers were complete".to_string());
        }
        head.extend_from_slice(&tmp[..got]);
        if let Some(at) = head.windows(4).position(|w| w == b"\r\n\r\n") {
            break at + 4;
        }
    };
    // Everything past the terminator already belongs to the body.
    let body_prefix = head.split_off(body_start);
    Ok((head, body_prefix))
}
561
562/// Returns `true` when the response should be streamed rather than buffered.
563///
564/// Streams when any of the following hold:
565/// - `Content-Type: text/event-stream` — SSE
566/// - `Transfer-Encoding: chunked` — AI token streams, etc.
567/// - `Content-Length` exceeds 1 MB — large file downloads
568pub(crate) fn should_stream_response(header_lower: &str) -> bool {
569    let is_sse = header_lower.lines().any(|l| {
570        l.starts_with("content-type:") && l.contains("text/event-stream")
571    });
572    let is_chunked = header_lower.lines().any(|l| {
573        l.starts_with("transfer-encoding:") && l.contains("chunked")
574    });
575    let content_length: Option<usize> = header_lower.lines().find_map(|l| {
576        l.strip_prefix("content-length:")?.trim().parse().ok()
577    });
578    let is_large = content_length.map_or(false, |n| n > STREAM_THRESHOLD);
579    is_sse || is_chunked || is_large
580}
581
582/// Parse the status line and headers from raw header bytes (ending at `\r\n\r\n`).
583fn parse_status_and_headers(header_bytes: &[u8]) -> Result<Response, String> {
584    let s = std::str::from_utf8(header_bytes)
585        .map_err(|e| format!("non-UTF-8 response headers: {}", e))?;
586    let mut lines = s.lines();
587    let status_line = lines.next().ok_or("empty backend response")?;
588    let mut parts = status_line.splitn(3, ' ');
589    let http_version = parts.next().unwrap_or("HTTP/1.1").to_string();
590    let status_code: i16 = parts
591        .next()
592        .unwrap_or("502")
593        .parse()
594        .map_err(|_| format!("invalid status code in '{}'", status_line))?;
595    let reason_phrase = parts.next().unwrap_or("").trim_end_matches('\r').to_string();
596    let mut headers = Vec::new();
597    for line in lines {
598        let line = line.trim_end_matches('\r');
599        if line.is_empty() { break; }
600        if let Some(colon) = line.find(':') {
601            headers.push(crate::header::Header {
602                name: line[..colon].trim().to_string(),
603                value: line[colon + 1..].trim().to_string(),
604            });
605        }
606    }
607    Ok(Response {
608        http_version,
609        status_code,
610        reason_phrase,
611        headers,
612        content_range_list: vec![],
613        stream_file: None,
614        stream_pipe: None,
615    })
616}
617
618/// Read the remaining body after headers have already been read.
619///
620/// `header_bytes` ends with `\r\n\r\n`; `body_prefix` holds any body bytes
621/// that arrived in the same TCP read.  Handles all three body mechanisms
622/// (chunked, content-length, read-to-EOF).  Returns `(full_response_bytes, can_reuse)`.
623fn read_response_from_partial(
624    stream: &mut TcpStream,
625    header_bytes: Vec<u8>,
626    body_prefix: Vec<u8>,
627    tmp: &mut [u8],
628) -> Result<(Vec<u8>, bool), String> {
629    let header_end = header_bytes.len();
630    let mut buf = header_bytes;
631    buf.extend_from_slice(&body_prefix);
632
633    let header_lower =
634        std::str::from_utf8(&buf[..header_end]).unwrap_or("").to_ascii_lowercase();
635    let connection_close =
636        header_lower.lines().any(|l| l.starts_with("connection:") && l.contains("close"));
637    let is_chunked = header_lower
638        .lines()
639        .any(|l| l.starts_with("transfer-encoding:") && l.contains("chunked"));
640    let content_length: Option<usize> = header_lower.lines().find_map(|l| {
641        l.strip_prefix("content-length:")?.trim().parse().ok()
642    });
643
644    if is_chunked {
645        let decoded = decode_chunked(stream, &buf, header_end, tmp)?;
646        rewrite_as_content_length(&mut buf, header_end, &decoded);
647        Ok((buf, !connection_close))
648    } else if let Some(len) = content_length {
649        while buf.len() < header_end + len {
650            let n = stream.read(tmp).map_err(|e| e.to_string())?;
651            if n == 0 { break; }
652            buf.extend_from_slice(&tmp[..n]);
653        }
654        Ok((buf, !connection_close))
655    } else {
656        loop {
657            let n = stream.read(tmp).map_err(|e| e.to_string())?;
658            if n == 0 { break; }
659            buf.extend_from_slice(&tmp[..n]);
660        }
661        Ok((buf, false))
662    }
663}
664
665// ── Backend URL parsing ───────────────────────────────────────────────────────
666
/// One upstream server parsed from a backend URL.
struct Backend {
    /// Host name or address, without scheme, port or path.
    host: String,
    /// TCP port: explicit in the URL or the scheme's default.
    port: u16,
    /// Whether the upstream connection should use TLS.
    /// Set when the URL scheme is `https://`, `h2s://`, or `grpcs://`.
    #[cfg_attr(not(feature = "http2"), allow(dead_code))]
    tls: bool,
}

impl Backend {
    /// Parse a backend URL of the form `[scheme://]host[:port][/path…]`.
    ///
    /// Recognised schemes select TLS and the default port (443 for
    /// `https`, `h2s`, `grpcs`; 80 for `http`, `h2`, `grpc`); input without
    /// a recognised scheme is treated as plain `host[:port]` with default
    /// port 80.  Any path is discarded.  When the text after the last `:`
    /// is not a valid port, the whole authority is kept as the host and the
    /// default port is used.
    ///
    /// Returns `None` when the resulting host is empty.
    fn parse(url: &str) -> Option<Self> {
        // (scheme prefix, uses TLS, default port)
        const SCHEMES: [(&str, bool, u16); 6] = [
            ("https://", true, 443),
            ("h2s://", true, 443),
            ("grpcs://", true, 443),
            ("http://", false, 80),
            ("h2://", false, 80),
            ("grpc://", false, 80),
        ];

        let (rest, tls, default_port) = SCHEMES
            .iter()
            .find_map(|&(prefix, tls, port)| {
                url.strip_prefix(prefix).map(|rest| (rest, tls, port))
            })
            .unwrap_or((url, false, 80));

        // Drop any path component.
        let authority = rest.split('/').next().unwrap_or(rest);

        let (host, port) = match authority.rsplit_once(':') {
            Some((name, digits)) => match digits.parse::<u16>() {
                Ok(port) => (name, port),
                Err(_) => (authority, default_port),
            },
            None => (authority, default_port),
        };

        if host.is_empty() {
            return None;
        }
        Some(Backend {
            host: host.to_string(),
            port,
            tls,
        })
    }
}
711
712// ── HTTP/2 reverse proxy ──────────────────────────────────────────────────────
713
/// Reverse proxy that forwards requests to HTTP/2 backends.
///
/// Configuration (backends, path prefix, timeouts, round-robin counter) is
/// stored in a wrapped [`ReverseProxy`]; upstream connections always use
/// HTTP/2 (`h2`).  Requires the `http2` Cargo feature.
///
/// gRPC traffic (`Content-Type: application/grpc*`) is handled
/// transparently — gRPC DATA frames are forwarded as-is because gRPC is
/// layered directly on HTTP/2.
#[cfg(feature = "http2")]
pub struct H2ReverseProxy {
    /// Holds the shared proxy configuration.
    inner: ReverseProxy,
}

#[cfg(feature = "http2")]
impl H2ReverseProxy {
    /// Create a proxy distributing requests across `backends` in round-robin order.
    ///
    /// Each backend entry can be:
    /// - `"host:port"` — plain TCP (HTTP/2 cleartext)
    /// - `"h2://host:port"` — plain TCP (explicit scheme)
    /// - `"h2s://host:port"` — TLS (HTTP/2 over HTTPS; port defaults to 443)
    /// - `"https://host:port"` — TLS (same as `h2s://`)
    ///
    /// TLS backends require the `http2` Cargo feature (includes `rustls` +
    /// `webpki-roots`).  Certificate verification uses the WebPKI trust store.
    pub fn new<I, S>(backends: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: AsRef<str>,
    {
        Self {
            inner: ReverseProxy::new(backends),
        }
    }

    /// Only proxy requests whose URI starts with `prefix`; pass others through.
    pub fn path_prefix(self, prefix: impl Into<String>) -> Self {
        Self {
            inner: self.inner.path_prefix(prefix),
        }
    }

    /// Override the TCP connect timeout (default: 5 s).
    pub fn connect_timeout_ms(self, ms: u64) -> Self {
        Self {
            inner: self.inner.connect_timeout_ms(ms),
        }
    }

    /// Override the response read timeout (default: 30 s).
    pub fn read_timeout_ms(self, ms: u64) -> Self {
        Self {
            inner: self.inner.read_timeout_ms(ms),
        }
    }
}
767
#[cfg(feature = "http2")]
impl crate::middleware::Middleware for H2ReverseProxy {
    /// Proxy the request over HTTP/2, or delegate to `next` when a path
    /// prefix is configured and the request URI does not start with it.
    ///
    /// Backends are tried in round-robin order starting from the shared
    /// counter; the first successful response is returned.  When there are
    /// no backends or every backend fails, the result is a
    /// `502 Bad Gateway` response (never an `Err`).
    fn handle(
        &self,
        request: &crate::request::Request,
        connection: &crate::server::ConnectionInfo,
        next: &dyn crate::application::Application,
    ) -> Result<crate::response::Response, String> {
        if let Some(prefix) = &self.inner.path_prefix {
            if !request.request_uri.starts_with(prefix.as_str()) {
                return next.execute(request, connection);
            }
        }
        let n = self.inner.backends.len();
        if n == 0 {
            return Ok(bad_gateway());
        }
        // `fetch_add` wraps around on overflow.  Reducing modulo `n` first
        // keeps `start + attempt` below `2 * n`, so the addition can never
        // overflow, however many requests have been served.
        let start = self.inner.counter.fetch_add(1, Ordering::Relaxed) % n;
        for attempt in 0..n {
            let backend = &self.inner.backends[(start + attempt) % n];
            let outcome = try_backend_h2(
                request,
                &connection.client.ip,
                backend,
                self.inner.connect_timeout,
                self.inner.read_timeout,
            );
            if let Ok(resp) = outcome {
                return Ok(resp);
            }
            // Errors are dropped on purpose: the next backend is tried and
            // the caller only ever sees a response or a 502.
        }
        Ok(bad_gateway())
    }
}
798
/// Run one HTTP/2 forwarding attempt against `backend` from synchronous
/// code.
///
/// The async work in `forward_h2_async` is driven on the Tokio runtime the
/// calling thread belongs to, via `block_in_place` + `block_on`.  When the
/// thread is not inside a Tokio runtime an error is returned instead.
///
/// `_read_timeout` is accepted but not used by this function.
#[cfg(feature = "http2")]
fn try_backend_h2(
    request: &Request,
    client_ip: &str,
    backend: &Backend,
    connect_timeout: Duration,
    _read_timeout: Duration,
) -> Result<Response, String> {
    let runtime = tokio::runtime::Handle::try_current()
        .map_err(|_| "no async runtime for H2 proxy; falling back to 502".to_string())?;
    tokio::task::block_in_place(|| {
        runtime.block_on(forward_h2_async(request, client_ip, backend, connect_timeout))
    })
}
817
/// Connect to `backend` and forward `request` over HTTP/2.
///
/// The TCP connect is bounded by `connect_timeout`.  For TLS backends a
/// `rustls` handshake is performed against the WebPKI root store, offering
/// `h2` through ALPN; plain backends use the TCP stream directly.  The
/// request/response exchange itself is delegated to `send_h2_request`.
///
/// # Errors
///
/// Returns a descriptive message when the connect times out or fails, when
/// the host is not a valid TLS server name, when the TLS handshake fails, or
/// when `send_h2_request` reports an error.
#[cfg(feature = "http2")]
async fn forward_h2_async(
    request: &Request,
    client_ip: &str,
    backend: &Backend,
    connect_timeout: Duration,
) -> Result<Response, String> {
    let addr = format!("{}:{}", backend.host, backend.port);
    let connecting = tokio::net::TcpStream::connect(&addr);
    let tcp = match tokio::time::timeout(connect_timeout, connecting).await {
        Ok(Ok(stream)) => stream,
        Ok(Err(e)) => return Err(format!("h2 proxy: connect to {} failed: {}", addr, e)),
        Err(_) => return Err(format!("h2 proxy: connect to {} timed out", addr)),
    };

    if !backend.tls {
        return send_h2_request(request, client_ip, backend, tcp).await;
    }

    let roots = rustls::RootCertStore::from_iter(
        webpki_roots::TLS_SERVER_ROOTS.iter().cloned(),
    );
    let mut tls_config = rustls::ClientConfig::builder()
        .with_root_certificates(roots)
        .with_no_client_auth();
    // Advertise h2 via ALPN so the server selects HTTP/2.
    tls_config.alpn_protocols = vec![b"h2".to_vec()];
    let connector = tokio_rustls::TlsConnector::from(std::sync::Arc::new(tls_config));

    let server_name = rustls::pki_types::ServerName::try_from(backend.host.as_str())
        .map_err(|e| format!("invalid upstream hostname '{}': {}", backend.host, e))?
        .to_owned();
    let tls_stream = connector
        .connect(server_name, tcp)
        .await
        .map_err(|e| format!("h2 proxy: TLS handshake with {} failed: {}", addr, e))?;

    send_h2_request(request, client_ip, backend, tls_stream).await
}
861
/// Drive the h2 client handshake + request/response over any async I/O stream.
///
/// Accepts both plain `TcpStream` and `TlsStream<TcpStream>` — anything that
/// satisfies `AsyncRead + AsyncWrite + Unpin + Send + 'static`.
///
/// Request side: the `host` header is set to the backend host, every header
/// listed in `HOP_BY_HOP` (and the client's own `Host`) is dropped, and
/// `x-forwarded-for` / `via` are appended. The request body, if any, is sent
/// as a single DATA frame.
///
/// Response side: the full body is buffered in memory, hop-by-hop headers are
/// filtered out, and header values that are not valid visible ASCII are
/// skipped. Only DATA frames are collected; response trailers are not read.
///
/// NOTE(review): gRPC conveys its status in trailers — confirm whether they
/// must be surfaced to the client for `GrpcProxy` use.
///
/// # Errors
///
/// Returns the stringified error when the h2 handshake fails, the request
/// cannot be built (invalid URI, method, or header), or the stream fails while
/// sending or receiving.
#[cfg(feature = "http2")]
async fn send_h2_request<T>(
    request: &Request,
    client_ip: &str,
    backend: &Backend,
    stream: T,
) -> Result<Response, String>
where
    T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static,
{
    use bytes::Bytes;
    use http as hc;

    let addr = format!("{}:{}", backend.host, backend.port);

    let (send_req, conn) = h2::client::handshake(stream)
        .await
        .map_err(|e| format!("h2 proxy: handshake with {} failed: {}", addr, e))?;

    // The connection future must be polled for any stream to make progress;
    // run it in the background and ignore its final result.
    tokio::spawn(async move {
        let _ = conn.await;
    });

    let scheme = if backend.tls { "https" } else { "http" };
    let uri_str = format!("{}://{}{}", scheme, addr, request.request_uri);
    let uri: hc::Uri = uri_str.parse().map_err(|e: hc::uri::InvalidUri| e.to_string())?;
    let method = hc::Method::from_bytes(request.method.as_bytes()).map_err(|e| e.to_string())?;

    let mut builder = hc::Request::builder().method(method).uri(uri);
    builder = builder.header("host", &backend.host);
    for h in &request.headers {
        let lower = h.name.to_lowercase();
        if HOP_BY_HOP.contains(&lower.as_str()) || lower == "host" {
            continue;
        }
        builder = builder.header(&h.name, &h.value);
    }
    builder = builder.header("x-forwarded-for", client_ip);
    builder = builder.header("via", "2 rws");

    let request_body = Bytes::from(request.body.clone());
    let end_of_stream = request_body.is_empty();
    let http_req = builder.body(()).map_err(|e| e.to_string())?;

    let mut send_req = send_req.ready().await.map_err(|e| e.to_string())?;
    let (resp_future, mut req_body) = send_req
        .send_request(http_req, end_of_stream)
        .map_err(|e| e.to_string())?;
    if !end_of_stream {
        req_body.send_data(request_body, true).map_err(|e| e.to_string())?;
    }

    let resp = resp_future.await.map_err(|e| e.to_string())?;
    let (parts, mut body) = resp.into_parts();

    let content_type = parts
        .headers
        .get("content-type")
        .and_then(|v| v.to_str().ok())
        .unwrap_or("application/octet-stream")
        .to_string();

    let mut payload: Vec<u8> = Vec::new();
    while let Some(chunk) = body.data().await {
        let chunk = chunk.map_err(|e| e.to_string())?;
        // Hand the consumed bytes back to the HTTP/2 flow-control window.
        // Without this the peer's send window is never replenished and any
        // response larger than the initial window stalls. Best effort: the
        // data is already in hand, so a failure here must not discard it.
        let _ = body.flow_control().release_capacity(chunk.len());
        payload.extend_from_slice(&chunk);
    }

    let mut response = Response::new();
    response.status_code = parts.status.as_u16() as i16;
    response.reason_phrase = parts.status.canonical_reason().unwrap_or("").to_string();

    const H2_HOP: &[&str] = &[
        "connection", "keep-alive", "transfer-encoding", "upgrade", "proxy-connection", "te",
    ];
    for (name, value) in &parts.headers {
        // `HeaderName::as_str` is already lower-case, so compare directly.
        if H2_HOP.contains(&name.as_str()) {
            continue;
        }
        if let Ok(v) = value.to_str() {
            response.headers.push(crate::header::Header {
                name: name.as_str().to_string(),
                value: v.to_string(),
            });
        }
    }

    if !payload.is_empty() {
        response.content_range_list = vec![Range::get_content_range(payload, content_type)];
    }

    Ok(response)
}
959
960// ── gRPC proxy ────────────────────────────────────────────────────────────────
961
/// Middleware that reverse-proxies gRPC traffic.
///
/// A request is treated as gRPC when its `Content-Type` begins with
/// `application/grpc`; such requests are forwarded to a backend over HTTP/2.
/// Everything else is passed on to the next layer untouched.
///
/// Only available when the `http2` Cargo feature is enabled.
///
/// # Example
///
/// ```rust,no_run
/// use rust_web_server::app::App;
/// use rust_web_server::core::New;
/// use rust_web_server::proxy::GrpcProxy;
///
/// let app = App::new()
///     .wrap(GrpcProxy::new(["grpc-service:50051"]));
/// ```
#[cfg(feature = "http2")]
pub struct GrpcProxy {
    /// HTTP/2 proxy that performs the actual forwarding.
    inner: H2ReverseProxy,
}
983
#[cfg(feature = "http2")]
impl GrpcProxy {
    /// Build a gRPC proxy that spreads requests over `backends` round-robin.
    ///
    /// Accepted backend spellings:
    /// - `"host:port"` — plain TCP (gRPC cleartext)
    /// - `"grpc://host:port"` — plain TCP with an explicit scheme
    /// - `"grpcs://host:port"` — gRPC over TLS; the port defaults to 443
    /// - `"https://host:port"` — TLS, equivalent to `grpcs://`
    ///
    /// TLS backends require the `http2` Cargo feature.
    pub fn new<I, S>(backends: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: AsRef<str>,
    {
        let inner = H2ReverseProxy::new(backends);
        Self { inner }
    }

    /// Restrict proxying to requests whose URI begins with `prefix`; all
    /// other requests are passed through to the next layer.
    pub fn path_prefix(self, prefix: impl Into<String>) -> Self {
        Self {
            inner: self.inner.path_prefix(prefix),
        }
    }
}
1009
#[cfg(feature = "http2")]
impl crate::middleware::Middleware for GrpcProxy {
    /// Route gRPC requests to the HTTP/2 proxy and everything else to `next`.
    ///
    /// A request counts as gRPC when its `Content-Type` header value starts
    /// with `application/grpc`. The comparison ignores ASCII case because
    /// media types are case-insensitive. A missing header is treated as
    /// "not gRPC".
    fn handle(
        &self,
        request: &crate::request::Request,
        connection: &crate::server::ConnectionInfo,
        next: &dyn crate::application::Application,
    ) -> Result<crate::response::Response, String> {
        const GRPC_PREFIX: &str = "application/grpc";

        let ct = request
            .get_header("content-type".to_string())
            .map(|h| h.value.as_str())
            .unwrap_or("");

        // `str::get` yields `None` when the value is shorter than the prefix
        // (or the cut would split a multi-byte character), so this never panics.
        let is_grpc = ct
            .get(..GRPC_PREFIX.len())
            .map_or(false, |head| head.eq_ignore_ascii_case(GRPC_PREFIX));

        if is_grpc {
            self.inner.handle(request, connection, next)
        } else {
            next.execute(request, connection)
        }
    }
}
1029
1030// ── Backend::parse unit tests ─────────────────────────────────────────────────
1031
#[cfg(test)]
mod backend_parse_tests {
    use super::Backend;

    /// `(host, port, tls)` view of a parsed backend, or `None` if parsing failed.
    type Parsed = Option<(String, u16, bool)>;

    /// Parse `url` and flatten the resulting backend for easy comparison.
    fn parse(url: &str) -> Parsed {
        let backend = Backend::parse(url)?;
        Some((backend.host, backend.port, backend.tls))
    }

    /// Expected result for a cleartext backend.
    fn plain(host: &str, port: u16) -> Parsed {
        Some((host.to_string(), port, false))
    }

    /// Expected result for a TLS backend.
    fn secure(host: &str, port: u16) -> Parsed {
        Some((host.to_string(), port, true))
    }

    #[test]
    fn bare_host_port() {
        assert_eq!(plain("api.example.com", 8080), parse("api.example.com:8080"));
    }

    #[test]
    fn http_scheme() {
        assert_eq!(plain("backend", 3000), parse("http://backend:3000"));
    }

    #[test]
    fn h2_scheme_plain() {
        assert_eq!(plain("svc", 50051), parse("h2://svc:50051"));
    }

    #[test]
    fn grpc_scheme_plain() {
        assert_eq!(plain("svc", 50051), parse("grpc://svc:50051"));
    }

    #[test]
    fn https_scheme_sets_tls_and_default_port() {
        assert_eq!(secure("api.example.com", 443), parse("https://api.example.com"));
    }

    #[test]
    fn https_scheme_explicit_port() {
        assert_eq!(secure("api.example.com", 8443), parse("https://api.example.com:8443"));
    }

    #[test]
    fn h2s_scheme_sets_tls() {
        assert_eq!(secure("svc", 443), parse("h2s://svc"));
    }

    #[test]
    fn h2s_scheme_explicit_port() {
        assert_eq!(secure("svc", 8443), parse("h2s://svc:8443"));
    }

    #[test]
    fn grpcs_scheme_sets_tls() {
        assert_eq!(secure("grpc-svc", 443), parse("grpcs://grpc-svc"));
    }

    #[test]
    fn grpcs_scheme_explicit_port() {
        assert_eq!(secure("grpc-svc", 50052), parse("grpcs://grpc-svc:50052"));
    }

    #[test]
    fn empty_host_returns_none() {
        let expected: Parsed = None;
        assert_eq!(expected, parse("https://"));
    }

    #[test]
    fn bare_host_no_port_defaults_to_80() {
        assert_eq!(plain("myhost", 80), parse("myhost"));
    }
}