Skip to main content

rust_web_server/proxy/
mod.rs

1//! Reverse proxy middleware with round-robin load balancing.
2//!
3//! `ReverseProxy` implements [`Middleware`] — wrap any application with it and
4//! all matching requests are forwarded to one of the configured backends over
5//! plain HTTP/1.1.  Failed backends are skipped and the next one is tried
6//! before returning `502 Bad Gateway`.
7//!
8//! # Example
9//!
10//! ```rust,no_run
11//! use rust_web_server::app::App;
12//! use rust_web_server::core::New;
13//! use rust_web_server::proxy::{LoadBalancing, ReverseProxy};
14//!
15//! // Proxy every request across two backends in round-robin order.
16//! let app = App::new()
17//!     .wrap(ReverseProxy::new(["http://backend-1:8080", "http://backend-2:8080"])
18//!         .strategy(LoadBalancing::RoundRobin));
19//!
20//! // Only proxy /api/* requests; everything else is handled locally.
21//! let app2 = App::new()
22//!     .wrap(ReverseProxy::new(["http://api-service:3000"])
23//!         .path_prefix("/api"));
24//! ```
25
26pub mod pool;
27
28#[cfg(test)]
29mod tests;
30
31use std::io::{Read, Write};
32use std::net::{TcpStream, ToSocketAddrs};
33use std::sync::atomic::{AtomicUsize, Ordering};
34use std::sync::Arc;
35use std::time::Duration;
36
37pub use pool::ConnPool;
38
39use crate::application::Application;
40use crate::core::New;
41use crate::middleware::Middleware;
42use crate::mime_type::MimeType;
43use crate::range::Range;
44use crate::request::Request;
45use crate::response::{Response, STATUS_CODE_REASON_PHRASE};
46use crate::server::ConnectionInfo;
47
// Hop-by-hop headers that must not be forwarded (RFC 7230 §6.1).
//
// Entries are lowercase because they are compared against a lowercased
// header name.  Both `trailer` (the actual field name) and `trailers` (the
// spelling used by the historical RFC 2616 §13.5.1 list) are present, as is
// the non-standard but widely sent `proxy-connection`.
const HOP_BY_HOP: &[&str] = &[
    "connection",
    "keep-alive",
    "proxy-authenticate",
    "proxy-authorization",
    "proxy-connection",
    "te",
    "trailer",
    "trailers",
    "transfer-encoding",
    "upgrade",
];
59
/// Load balancing strategy used by [`ReverseProxy`].
///
/// The enum is a plain, field-less selector, so it is cheap to copy and
/// compare.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LoadBalancing {
    /// Distribute requests across backends in a cyclic order.
    RoundRobin,
}
65
/// Reverse proxy middleware.
///
/// Forwards incoming requests to one of the configured backends over HTTP/1.1.
/// On connection failure the next backend in the list is tried; when all
/// backends have failed the middleware returns `502 Bad Gateway`.
///
/// Hop-by-hop headers are stripped before forwarding.  `X-Forwarded-For` and
/// `Via` are added to every forwarded request.
///
/// Idle connections are pooled and reused across requests (up to
/// [`ConnPool::new_default`] limits: 8 idle per backend, 60-second timeout).
/// This eliminates per-request TCP handshake overhead and ephemeral-port
/// exhaustion.  Use [`ReverseProxy::with_pool`] to share a pool across
/// multiple proxy instances or to tune pool parameters.
pub struct ReverseProxy {
    // Parsed upstream targets, in the order they were passed to `new`.
    backends: Vec<Backend>,
    // When set, only requests whose URI starts with this prefix are proxied.
    path_prefix: Option<String>,
    // Upper bound for establishing a new TCP connection to a backend.
    connect_timeout: Duration,
    // Read timeout applied to the backend socket before each exchange.
    read_timeout: Duration,
    // Monotonic request counter used to pick the first backend to try.
    counter: AtomicUsize,
    // Idle-connection pool, keyed by "host:port".
    pool: Arc<ConnPool>,
}
88
89impl ReverseProxy {
90    /// Create a proxy that distributes requests across `backends` in
91    /// round-robin order.  Each entry must be `"http://host:port"` or
92    /// `"host:port"` (port defaults to 80).
93    pub fn new<I, S>(backends: I) -> Self
94    where
95        I: IntoIterator<Item = S>,
96        S: AsRef<str>,
97    {
98        Self {
99            backends: backends
100                .into_iter()
101                .filter_map(|u| Backend::parse(u.as_ref()))
102                .collect(),
103            path_prefix: None,
104            connect_timeout: Duration::from_secs(5),
105            read_timeout: Duration::from_secs(30),
106            counter: AtomicUsize::new(0),
107            pool: Arc::new(ConnPool::new_default()),
108        }
109    }
110
111    /// Only proxy requests whose URI starts with `prefix`.
112    ///
113    /// Other requests are passed through to the next layer in the middleware
114    /// chain (or the inner application).
115    pub fn path_prefix(mut self, prefix: impl Into<String>) -> Self {
116        self.path_prefix = Some(prefix.into());
117        self
118    }
119
120    /// Override the load balancing strategy (currently only `RoundRobin`).
121    pub fn strategy(self, _strategy: LoadBalancing) -> Self {
122        self
123    }
124
125    /// Override the TCP connect timeout (default: 5 000 ms).
126    pub fn connect_timeout_ms(mut self, ms: u64) -> Self {
127        self.connect_timeout = Duration::from_millis(ms);
128        self
129    }
130
131    /// Override the response read timeout (default: 30 000 ms).
132    pub fn read_timeout_ms(mut self, ms: u64) -> Self {
133        self.read_timeout = Duration::from_millis(ms);
134        self
135    }
136
137    /// Attach a shared connection pool.
138    ///
139    /// Useful for sharing one pool across multiple `ReverseProxy` instances
140    /// or for tuning pool parameters (capacity, idle timeout).
141    pub fn with_pool(mut self, pool: Arc<ConnPool>) -> Self {
142        self.pool = pool;
143        self
144    }
145
146    /// Set the maximum number of idle connections per backend (default: 8).
147    pub fn max_idle_conns(mut self, n: usize) -> Self {
148        self.pool = Arc::new(ConnPool::new(n, Duration::from_secs(60)));
149        self
150    }
151
152    fn proxy(&self, request: &Request, connection: &ConnectionInfo) -> Result<Response, String> {
153        if self.backends.is_empty() {
154            return Err("no backends configured".to_string());
155        }
156        let n = self.backends.len();
157        let start = self.counter.fetch_add(1, Ordering::Relaxed);
158        for attempt in 0..n {
159            let idx = (start + attempt) % n;
160            match self.try_backend(request, connection, &self.backends[idx]) {
161                Ok(resp) => return Ok(resp),
162                Err(_) if attempt + 1 < n => continue,
163                Err(e) => return Err(e),
164            }
165        }
166        Err("all backends failed".to_string())
167    }
168
169    fn try_backend(
170        &self,
171        request: &Request,
172        connection: &ConnectionInfo,
173        backend: &Backend,
174    ) -> Result<Response, String> {
175        let key = format!("{}:{}", backend.host, backend.port);
176
177        // Try a pooled connection first; fall back to a fresh one.
178        let stream = if let Some(pooled) = self.pool.acquire(&key) {
179            pooled
180        } else {
181            let addr_str = key.as_str();
182            let sock_addr = addr_str
183                .to_socket_addrs()
184                .map_err(|e| format!("DNS lookup for {} failed: {}", addr_str, e))?
185                .next()
186                .ok_or_else(|| format!("no address resolved for {}", addr_str))?;
187            TcpStream::connect_timeout(&sock_addr, self.connect_timeout)
188                .map_err(|e| format!("connect to {} failed: {}", addr_str, e))?
189        };
190
191        stream.set_read_timeout(Some(self.read_timeout)).map_err(|e| e.to_string())?;
192        stream.set_write_timeout(Some(Duration::from_secs(10))).map_err(|e| e.to_string())?;
193
194        // keep_alive = true: send Connection: keep-alive so the server holds
195        // the connection open after responding.
196        let req_bytes = build_request(request, &backend.host, &connection.client.ip, true);
197        let mut stream = stream;
198        stream.write_all(&req_bytes).map_err(|e| format!("write to backend failed: {}", e))?;
199
200        let mut tmp = [0u8; 4096];
201        let (header_bytes, body_prefix) = read_headers_only(&mut stream, &mut tmp)?;
202        let header_lower =
203            std::str::from_utf8(&header_bytes).unwrap_or("").to_ascii_lowercase();
204
205        if should_stream_response(&header_lower) {
206            // Streaming path — pipe bytes straight to the client.
207            // The connection cannot be reused while the body is in flight.
208            let mut resp = parse_status_and_headers(&header_bytes)?;
209            resp.stream_pipe =
210                Some(Box::new(ConcatReader::new(body_prefix, stream)));
211            Ok(resp)
212        } else {
213            // Buffered path — read the full body, then optionally return the
214            // connection to the pool.
215            let (resp_bytes, reusable) =
216                read_response_from_partial(&mut stream, header_bytes, body_prefix, &mut tmp)?;
217            if reusable {
218                self.pool.release(&key, stream);
219            }
220            Response::parse(&resp_bytes)
221        }
222    }
223}
224
225impl Middleware for ReverseProxy {
226    fn handle(
227        &self,
228        request: &Request,
229        connection: &ConnectionInfo,
230        next: &dyn Application,
231    ) -> Result<Response, String> {
232        if let Some(prefix) = &self.path_prefix {
233            if !request.request_uri.starts_with(prefix.as_str()) {
234                return next.execute(request, connection);
235            }
236        }
237        match self.proxy(request, connection) {
238            Ok(resp) => Ok(resp),
239            Err(_) => Ok(bad_gateway()),
240        }
241    }
242}
243
244// ── helpers ───────────────────────────────────────────────────────────────────
245
246pub(crate) fn build_request(
247    request: &Request,
248    backend_host: &str,
249    client_ip: &str,
250    keep_alive: bool,
251) -> Vec<u8> {
252    let mut out: Vec<u8> = Vec::new();
253    let _ = write!(
254        out,
255        "{} {} HTTP/1.1\r\nHost: {}\r\n",
256        request.method, request.request_uri, backend_host
257    );
258    for h in &request.headers {
259        let lower = h.name.to_lowercase();
260        if HOP_BY_HOP.contains(&lower.as_str()) || lower == "host" {
261            continue;
262        }
263        let _ = write!(out, "{}: {}\r\n", h.name, h.value);
264    }
265    let _ = write!(out, "X-Forwarded-For: {}\r\n", client_ip);
266    let _ = write!(out, "Via: 1.1 rws\r\n");
267    if keep_alive {
268        let _ = write!(out, "Connection: keep-alive\r\n");
269    } else {
270        let _ = write!(out, "Connection: close\r\n");
271    }
272    if !request.body.is_empty() {
273        let _ = write!(out, "Content-Length: {}\r\n", request.body.len());
274    }
275    let _ = write!(out, "\r\n");
276    out.extend_from_slice(&request.body);
277    out
278}
279
280/// Read a full HTTP/1.1 response from `stream`.
281///
282/// Supports three body-length mechanisms:
283/// - `Content-Length: N` — reads exactly N bytes; stream reusable if backend allows.
284/// - `Transfer-Encoding: chunked` — decodes all chunks and rewrites the
285///   response as a `Content-Length` response; stream reusable if backend allows.
286/// - Neither — reads until EOF; stream is never reusable (connection closes).
287///
288/// Returns `(response_bytes, can_reuse)`.  When `can_reuse` is `true` and the
289/// caller has a [`ConnPool`], it should call [`ConnPool::release`].
290pub(crate) fn read_response_poolable(stream: &mut TcpStream) -> Result<(Vec<u8>, bool), String> {
291    let mut buf: Vec<u8> = Vec::with_capacity(8192);
292    let mut tmp = [0u8; 4096];
293
294    // Read until headers end (\r\n\r\n).
295    let header_end = loop {
296        let n = stream.read(&mut tmp).map_err(|e| e.to_string())?;
297        if n == 0 {
298            return if buf.is_empty() {
299                Err("backend closed connection without sending a response".to_string())
300            } else {
301                Ok((buf, false))
302            };
303        }
304        buf.extend_from_slice(&tmp[..n]);
305        if let Some(pos) = buf.windows(4).position(|w| w == b"\r\n\r\n") {
306            break pos + 4;
307        }
308    };
309
310    let header_str_lower =
311        std::str::from_utf8(&buf[..header_end]).unwrap_or("").to_ascii_lowercase();
312
313    let connection_close =
314        header_str_lower.lines().any(|l| l.starts_with("connection:") && l.contains("close"));
315
316    let is_chunked = header_str_lower
317        .lines()
318        .any(|l| l.starts_with("transfer-encoding:") && l.contains("chunked"));
319
320    let content_length: Option<usize> = header_str_lower.lines().find_map(|l| {
321        l.strip_prefix("content-length:")?
322            .trim()
323            .parse()
324            .ok()
325    });
326
327    if is_chunked {
328        let decoded = decode_chunked(stream, &buf, header_end, &mut tmp)?;
329        rewrite_as_content_length(&mut buf, header_end, &decoded);
330        Ok((buf, !connection_close))
331    } else if let Some(len) = content_length {
332        while buf.len() < header_end + len {
333            let n = stream.read(&mut tmp).map_err(|e| e.to_string())?;
334            if n == 0 {
335                break;
336            }
337            buf.extend_from_slice(&tmp[..n]);
338        }
339        Ok((buf, !connection_close))
340    } else {
341        // No body-length indicator — read until EOF; connection cannot be reused.
342        loop {
343            let n = stream.read(&mut tmp).map_err(|e| e.to_string())?;
344            if n == 0 {
345                break;
346            }
347            buf.extend_from_slice(&tmp[..n]);
348        }
349        Ok((buf, false))
350    }
351}
352
/// Decode HTTP/1.1 chunked transfer-encoding from `stream`.
///
/// `buf[header_end..]` may already contain some body bytes that arrived in
/// the same read as the headers.  `tmp` is scratch space for socket reads.
/// Returns the fully decoded body.
///
/// After the terminating zero-size chunk, any trailer lines are read and
/// discarded up to the final empty line, so that a keep-alive connection is
/// left positioned at the start of the next response.
///
/// # Errors
///
/// Returns an error when a read fails, when EOF arrives in the middle of a
/// chunk, or when a chunk-size line is not valid hexadecimal or is too large
/// to address.
fn decode_chunked<R: Read>(
    stream: &mut R,
    buf: &[u8],
    header_end: usize,
    tmp: &mut [u8],
) -> Result<Vec<u8>, String> {
    // Seed `raw` with any body bytes already buffered alongside the headers.
    let mut raw: Vec<u8> = buf.get(header_end..).unwrap_or(&[]).to_vec();
    let mut decoded: Vec<u8> = Vec::new();

    loop {
        // Wait until we have at least one complete chunk-size line (ends with \r\n).
        let crlf = loop {
            if let Some(p) = raw.windows(2).position(|w| w == b"\r\n") {
                break p;
            }
            let n = stream.read(tmp).map_err(|e| e.to_string())?;
            if n == 0 {
                return Err("chunked: premature EOF reading chunk size".to_string());
            }
            raw.extend_from_slice(&tmp[..n]);
        };

        // Chunk size is hex, optionally followed by chunk-extensions (";…").
        let size_line = std::str::from_utf8(&raw[..crlf])
            .map_err(|_| "chunked: non-UTF-8 chunk size line".to_string())?;
        let size_str = size_line.split(';').next().unwrap_or("").trim();
        let chunk_size = usize::from_str_radix(size_str, 16)
            .map_err(|_| format!("chunked: invalid chunk size '{}'", size_str))?;
        // Data plus its trailing CRLF; a size near `usize::MAX` must be
        // rejected rather than allowed to overflow.
        let chunk_span = chunk_size
            .checked_add(2)
            .ok_or_else(|| format!("chunked: chunk size '{}' is too large", size_str))?;
        raw.drain(..crlf + 2); // consume "<size>\r\n"

        if chunk_size == 0 {
            // Last chunk — skip optional trailer lines until the empty line
            // that ends the message.  EOF here is tolerated: the body itself
            // is already complete.
            loop {
                match raw.windows(2).position(|w| w == b"\r\n") {
                    Some(0) => break,
                    Some(p) => {
                        raw.drain(..p + 2); // discard one trailer line
                    }
                    None => {
                        let n = stream.read(tmp).map_err(|e| e.to_string())?;
                        if n == 0 {
                            break;
                        }
                        raw.extend_from_slice(&tmp[..n]);
                    }
                }
            }
            break;
        }

        // Read chunk data + trailing CRLF.
        while raw.len() < chunk_span {
            let n = stream.read(tmp).map_err(|e| e.to_string())?;
            if n == 0 {
                return Err("chunked: premature EOF reading chunk body".to_string());
            }
            raw.extend_from_slice(&tmp[..n]);
        }
        decoded.extend_from_slice(&raw[..chunk_size]);
        raw.drain(..chunk_span); // consume "<data>\r\n"
    }

    Ok(decoded)
}
414
/// Rewrite `buf` in-place: strip `Transfer-Encoding`, add `Content-Length`,
/// replace the old (undecoded) body with `decoded`.
///
/// `buf[..header_end]` is the header block (status line plus header lines).
/// The header block is processed as raw bytes, so header values that are not
/// valid UTF-8 are preserved verbatim.  Any `Content-Length` already present
/// is dropped as well, so the result carries exactly one, matching
/// `decoded.len()`.
fn rewrite_as_content_length(buf: &mut Vec<u8>, header_end: usize, decoded: &[u8]) {
    // Case-insensitive "line starts with `name`" on raw bytes.
    fn starts_with_name(line: &[u8], name: &[u8]) -> bool {
        line.len() >= name.len() && line[..name.len()].eq_ignore_ascii_case(name)
    }

    let head_len = header_end.min(buf.len());
    let mut out: Vec<u8> = Vec::with_capacity(head_len + 32 + decoded.len());

    for raw_line in buf[..head_len].split(|&b| b == b'\n') {
        // Lines are CRLF-terminated; drop the CR left behind by the split.
        let line = match raw_line.split_last() {
            Some((&b'\r', rest)) => rest,
            _ => raw_line,
        };
        if line.is_empty()
            || starts_with_name(line, b"transfer-encoding:")
            || starts_with_name(line, b"content-length:")
        {
            continue;
        }
        out.extend_from_slice(line);
        out.extend_from_slice(b"\r\n");
    }

    // Writing into a `Vec<u8>` cannot fail.
    let _ = write!(out, "Content-Length: {}\r\n\r\n", decoded.len());
    out.extend_from_slice(decoded);
    *buf = out;
}
430
/// Non-pooled version of the response reader for plain `TcpStream`s.
///
/// Thin wrapper that delegates to [`read_response_from`]; it reports no
/// reusability flag.  In this file it is used by `proxy_http1`, which sends
/// `Connection: close`.  (`proxy_https1` calls [`read_response_from`]
/// directly because its stream is a TLS wrapper, not a `TcpStream`.)
pub(crate) fn read_response(stream: &mut TcpStream) -> Result<Vec<u8>, String> {
    read_response_from(stream)
}
436
/// Read one HTTP response from any `Read` source and return its raw bytes.
///
/// The header block is read first.  If a parseable `Content-Length` header is
/// present, reading continues until at least that many body bytes have been
/// collected (or EOF); otherwise everything up to EOF is collected.
///
/// EOF before the end of the headers yields whatever was received, or an
/// error when nothing was received at all.
pub(crate) fn read_response_from<R: Read>(stream: &mut R) -> Result<Vec<u8>, String> {
    let mut response: Vec<u8> = Vec::with_capacity(8192);
    let mut chunk = [0u8; 4096];

    // Phase 1: accumulate bytes until the blank line that ends the headers.
    let body_start = loop {
        let got = stream.read(&mut chunk).map_err(|e| e.to_string())?;
        if got == 0 {
            if response.is_empty() {
                return Err("backend closed connection without sending a response".to_string());
            }
            return Ok(response);
        }
        response.extend_from_slice(&chunk[..got]);
        if let Some(at) = response.windows(4).position(|w| w == b"\r\n\r\n") {
            break at + 4;
        }
    };

    // Phase 2: first `Content-Length` header whose value parses, if any.
    let mut declared_len: Option<usize> = None;
    for line in std::str::from_utf8(&response[..body_start]).unwrap_or("").lines() {
        if !line.to_lowercase().starts_with("content-length:") {
            continue;
        }
        let parsed = line
            .splitn(2, ':')
            .nth(1)
            .and_then(|value| value.trim().parse::<usize>().ok());
        if parsed.is_some() {
            declared_len = parsed;
            break;
        }
    }

    // Phase 3: pull the body — bounded by the declared length, else to EOF.
    let target = declared_len.map(|len| body_start + len);
    loop {
        if let Some(end) = target {
            if response.len() >= end {
                break;
            }
        }
        let got = stream.read(&mut chunk).map_err(|e| e.to_string())?;
        if got == 0 {
            break;
        }
        response.extend_from_slice(&chunk[..got]);
    }

    Ok(response)
}
487
488/// Forward a single HTTP/1.1 request to `host:port` and return the response.
489///
490/// This is the shared low-level building block used by [`crate::canary`] and
491/// [`crate::ingress`] so they don't have to duplicate the TCP + request/response
492/// marshalling code.
493pub(crate) fn proxy_http1(
494    request: &Request,
495    client_ip: &str,
496    host: &str,
497    port: u16,
498    connect_timeout: Duration,
499    read_timeout: Duration,
500) -> Result<Response, String> {
501    use std::net::ToSocketAddrs;
502    let addr_str = format!("{}:{}", host, port);
503    let sock_addr = addr_str
504        .to_socket_addrs()
505        .map_err(|e| format!("DNS lookup for {} failed: {}", addr_str, e))?
506        .next()
507        .ok_or_else(|| format!("no address resolved for {}", addr_str))?;
508    let stream = TcpStream::connect_timeout(&sock_addr, connect_timeout)
509        .map_err(|e| format!("connect to {} failed: {}", addr_str, e))?;
510    stream.set_read_timeout(Some(read_timeout)).map_err(|e| e.to_string())?;
511    stream.set_write_timeout(Some(Duration::from_secs(10))).map_err(|e| e.to_string())?;
512    let req_bytes = build_request(request, host, client_ip, false);
513    let mut stream = stream;
514    stream.write_all(&req_bytes).map_err(|e| format!("write to backend failed: {}", e))?;
515    let resp_bytes = read_response(&mut stream)?;
516    Response::parse(&resp_bytes)
517}
518
/// Forward a single HTTPS/1.1 request to `host:port` over TLS and return the
/// response. Requires the `http-client` or `http2` feature (both bring in
/// `rustls` + `webpki-roots`).
///
/// A new TCP connection and TLS session are created per call; the server
/// certificate is checked against the `webpki-roots` trust anchors with
/// `host` as the expected server name.  The request is sent with
/// `Connection: close`.
#[cfg(any(feature = "http-client", feature = "http2"))]
pub(crate) fn proxy_https1(
    request: &Request,
    client_ip: &str,
    host: &str,
    port: u16,
    connect_timeout: Duration,
    read_timeout: Duration,
) -> Result<Response, String> {
    use rustls::pki_types::ServerName;
    use rustls::ClientConfig;
    use std::net::ToSocketAddrs;
    use std::sync::Arc;

    let target = format!("{}:{}", host, port);
    let mut resolved = target
        .to_socket_addrs()
        .map_err(|e| format!("DNS lookup for {} failed: {}", target, e))?;
    let peer = resolved
        .next()
        .ok_or_else(|| format!("no address resolved for {}", target))?;

    // Plain TCP first; timeouts are set on the socket before TLS wraps it.
    let tcp = TcpStream::connect_timeout(&peer, connect_timeout)
        .map_err(|e| format!("connect to {} failed: {}", target, e))?;
    tcp.set_read_timeout(Some(read_timeout))
        .map_err(|e| e.to_string())?;
    tcp.set_write_timeout(Some(Duration::from_secs(10)))
        .map_err(|e| e.to_string())?;

    // Client-side TLS configuration: WebPKI roots, no client certificate.
    let trust_anchors =
        rustls::RootCertStore::from_iter(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
    let tls_config = Arc::new(
        ClientConfig::builder()
            .with_root_certificates(trust_anchors)
            .with_no_client_auth(),
    );
    let sni = ServerName::try_from(host.to_string())
        .map_err(|e| format!("invalid upstream hostname '{}': {}", host, e))?;
    let session =
        rustls::ClientConnection::new(tls_config, sni).map_err(|e| e.to_string())?;
    let mut tls = rustls::StreamOwned::new(session, tcp);

    let wire = build_request(request, host, client_ip, false);
    tls.write_all(&wire)
        .map_err(|e| format!("write to upstream failed: {}", e))?;

    let raw = read_response_from(&mut tls)?;
    Response::parse(&raw)
}
567
568fn bad_gateway() -> Response {
569    let cr = Range::get_content_range(
570        b"502 Bad Gateway".to_vec(),
571        MimeType::TEXT_PLAIN.to_string(),
572    );
573    let mut r = Response::new();
574    r.status_code = *STATUS_CODE_REASON_PHRASE.n502_bad_gateway.status_code;
575    r.reason_phrase = STATUS_CODE_REASON_PHRASE
576        .n502_bad_gateway
577        .reason_phrase
578        .to_string();
579    r.content_range_list = vec![cr];
580    r
581}
582
583// ── Streaming proxy helpers ───────────────────────────────────────────────────
584
/// Responses whose declared `Content-Length` is strictly greater than this
/// many bytes are streamed to the client instead of being buffered in memory.
const STREAM_THRESHOLD: usize = 1024 * 1024; // 1 MB
587
/// A reader that first replays an in-memory prefix and then continues with
/// the wrapped reader. Used to hand back body bytes that were pulled off the
/// socket together with the HTTP headers.
pub(crate) struct ConcatReader<R: Read + Send> {
    /// Bytes to serve before touching `inner`.
    prefix: Vec<u8>,
    /// How many bytes of `prefix` have already been served.
    prefix_pos: usize,
    /// Source for everything after the prefix.
    inner: R,
}

impl<R: Read + Send> ConcatReader<R> {
    /// Wrap `inner` so that `prefix` is read out first.
    fn new(prefix: Vec<u8>, inner: R) -> Self {
        Self {
            prefix,
            prefix_pos: 0,
            inner,
        }
    }
}

impl<R: Read + Send> Read for ConcatReader<R> {
    /// Serve from the prefix while any of it is left; a single call never
    /// mixes prefix bytes with bytes from the inner reader.
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        let pending = &self.prefix[self.prefix_pos..];
        if pending.is_empty() {
            return self.inner.read(buf);
        }
        let take = pending.len().min(buf.len());
        buf[..take].copy_from_slice(&pending[..take]);
        self.prefix_pos += take;
        Ok(take)
    }
}
614
/// Read the HTTP response header block (through the terminating `\r\n\r\n`).
///
/// `tmp` is scratch space for individual socket reads.  Returns
/// `(header_bytes, body_prefix)`: the header block including its terminator,
/// and whatever body bytes were received in the same reads.
///
/// Fails when a read errors or when the backend closes the connection before
/// the terminator is seen.
fn read_headers_only(stream: &mut TcpStream, tmp: &mut [u8]) -> Result<(Vec<u8>, Vec<u8>), String> {
    let mut head: Vec<u8> = Vec::with_capacity(4096);
    let split_at = loop {
        let got = stream.read(tmp).map_err(|e| e.to_string())?;
        if got == 0 {
            return Err("backend closed connection before headers were complete".to_string());
        }
        head.extend_from_slice(&tmp[..got]);
        if let Some(idx) = head.windows(4).position(|w| w == b"\r\n\r\n") {
            break idx + 4;
        }
    };
    // Everything past the terminator already belongs to the body.
    let body_prefix = head.split_off(split_at);
    Ok((head, body_prefix))
}
634
635/// Returns `true` when the response should be streamed rather than buffered.
636///
637/// Streams when any of the following hold:
638/// - `Content-Type: text/event-stream` — SSE
639/// - `Transfer-Encoding: chunked` — AI token streams, etc.
640/// - `Content-Length` exceeds 1 MB — large file downloads
641pub(crate) fn should_stream_response(header_lower: &str) -> bool {
642    let is_sse = header_lower.lines().any(|l| {
643        l.starts_with("content-type:") && l.contains("text/event-stream")
644    });
645    let is_chunked = header_lower.lines().any(|l| {
646        l.starts_with("transfer-encoding:") && l.contains("chunked")
647    });
648    let content_length: Option<usize> = header_lower.lines().find_map(|l| {
649        l.strip_prefix("content-length:")?.trim().parse().ok()
650    });
651    let is_large = content_length.map_or(false, |n| n > STREAM_THRESHOLD);
652    is_sse || is_chunked || is_large
653}
654
655/// Parse the status line and headers from raw header bytes (ending at `\r\n\r\n`).
656fn parse_status_and_headers(header_bytes: &[u8]) -> Result<Response, String> {
657    let s = std::str::from_utf8(header_bytes)
658        .map_err(|e| format!("non-UTF-8 response headers: {}", e))?;
659    let mut lines = s.lines();
660    let status_line = lines.next().ok_or("empty backend response")?;
661    let mut parts = status_line.splitn(3, ' ');
662    let http_version = parts.next().unwrap_or("HTTP/1.1").to_string();
663    let status_code: i16 = parts
664        .next()
665        .unwrap_or("502")
666        .parse()
667        .map_err(|_| format!("invalid status code in '{}'", status_line))?;
668    let reason_phrase = parts.next().unwrap_or("").trim_end_matches('\r').to_string();
669    let mut headers = Vec::new();
670    for line in lines {
671        let line = line.trim_end_matches('\r');
672        if line.is_empty() { break; }
673        if let Some(colon) = line.find(':') {
674            headers.push(crate::header::Header {
675                name: line[..colon].trim().to_string(),
676                value: line[colon + 1..].trim().to_string(),
677            });
678        }
679    }
680    Ok(Response {
681        http_version,
682        status_code,
683        reason_phrase,
684        headers,
685        content_range_list: vec![],
686        stream_file: None,
687        stream_pipe: None,
688    })
689}
690
691/// Read the remaining body after headers have already been read.
692///
693/// `header_bytes` ends with `\r\n\r\n`; `body_prefix` holds any body bytes
694/// that arrived in the same TCP read.  Handles all three body mechanisms
695/// (chunked, content-length, read-to-EOF).  Returns `(full_response_bytes, can_reuse)`.
696fn read_response_from_partial(
697    stream: &mut TcpStream,
698    header_bytes: Vec<u8>,
699    body_prefix: Vec<u8>,
700    tmp: &mut [u8],
701) -> Result<(Vec<u8>, bool), String> {
702    let header_end = header_bytes.len();
703    let mut buf = header_bytes;
704    buf.extend_from_slice(&body_prefix);
705
706    let header_lower =
707        std::str::from_utf8(&buf[..header_end]).unwrap_or("").to_ascii_lowercase();
708    let connection_close =
709        header_lower.lines().any(|l| l.starts_with("connection:") && l.contains("close"));
710    let is_chunked = header_lower
711        .lines()
712        .any(|l| l.starts_with("transfer-encoding:") && l.contains("chunked"));
713    let content_length: Option<usize> = header_lower.lines().find_map(|l| {
714        l.strip_prefix("content-length:")?.trim().parse().ok()
715    });
716
717    if is_chunked {
718        let decoded = decode_chunked(stream, &buf, header_end, tmp)?;
719        rewrite_as_content_length(&mut buf, header_end, &decoded);
720        Ok((buf, !connection_close))
721    } else if let Some(len) = content_length {
722        while buf.len() < header_end + len {
723            let n = stream.read(tmp).map_err(|e| e.to_string())?;
724            if n == 0 { break; }
725            buf.extend_from_slice(&tmp[..n]);
726        }
727        Ok((buf, !connection_close))
728    } else {
729        loop {
730            let n = stream.read(tmp).map_err(|e| e.to_string())?;
731            if n == 0 { break; }
732            buf.extend_from_slice(&tmp[..n]);
733        }
734        Ok((buf, false))
735    }
736}
737
738// ── Backend URL parsing ───────────────────────────────────────────────────────
739
/// One upstream target: where to connect and whether TLS is expected.
struct Backend {
    host: String,
    port: u16,
    /// Whether the upstream connection should use TLS.
    /// Set when the URL scheme is `https://`, `h2s://`, or `grpcs://`.
    #[cfg_attr(not(feature = "http2"), allow(dead_code))]
    tls: bool,
}

impl Backend {
    /// Parse a backend specification such as `"http://host:8080"`,
    /// `"h2s://host"` or bare `"host:port"`.
    ///
    /// Any path after the authority is ignored.  A missing or non-numeric
    /// port falls back to the scheme default (443 for TLS schemes, otherwise
    /// 80); in the non-numeric case the text after the colon stays part of
    /// the host.  Returns `None` when the host is empty.
    fn parse(url: &str) -> Option<Self> {
        // (scheme prefix, uses TLS, default port)
        const SCHEMES: [(&str, bool, u16); 6] = [
            ("https://", true, 443),
            ("h2s://", true, 443),
            ("grpcs://", true, 443),
            ("http://", false, 80),
            ("h2://", false, 80),
            ("grpc://", false, 80),
        ];

        let (rest, tls, default_port) = SCHEMES
            .iter()
            .find_map(|&(scheme, tls, port)| url.strip_prefix(scheme).map(|r| (r, tls, port)))
            .unwrap_or((url, false, 80));

        // Drop any path component.
        let authority = rest.split('/').next().unwrap_or(rest);

        let (host, port) = match authority.rsplit_once(':') {
            Some((name, port_text)) => match port_text.parse::<u16>() {
                Ok(explicit) => (name, explicit),
                Err(_) => (authority, default_port),
            },
            None => (authority, default_port),
        };

        if host.is_empty() {
            return None;
        }
        Some(Backend {
            host: host.to_string(),
            port,
            tls,
        })
    }
}
784
785// ── HTTP/2 reverse proxy ──────────────────────────────────────────────────────
786
/// Reverse proxy that forwards requests to HTTP/2 backends.
///
/// Wraps [`ReverseProxy`] and forces HTTP/2 (`h2`) for all upstream connections.
/// Requires the `http2` Cargo feature.
///
/// This proxy also transparently handles gRPC traffic
/// (`Content-Type: application/grpc*`) — gRPC DATA frames are forwarded
/// as-is because gRPC is layered directly on HTTP/2.
#[cfg(feature = "http2")]
pub struct H2ReverseProxy {
    // Holds the shared configuration: backends, path prefix, timeouts and
    // the round-robin counter.
    inner: ReverseProxy,
}
799
#[cfg(feature = "http2")]
impl H2ReverseProxy {
    /// Create a proxy distributing requests across `backends` in round-robin order.
    ///
    /// Each backend entry can be:
    /// - `"host:port"` — plain TCP (HTTP/2 cleartext)
    /// - `"h2://host:port"` — plain TCP (explicit scheme)
    /// - `"h2s://host:port"` — TLS (HTTP/2 over HTTPS; port defaults to 443)
    /// - `"https://host:port"` — TLS (same as `h2s://`)
    ///
    /// TLS backends require the `http2` Cargo feature (includes `rustls` +
    /// `webpki-roots`).  Certificate verification uses the WebPKI trust store.
    pub fn new<I, S>(backends: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: AsRef<str>,
    {
        Self {
            inner: ReverseProxy::new(backends),
        }
    }

    /// Only proxy requests whose URI starts with `prefix`; pass others through.
    pub fn path_prefix(self, prefix: impl Into<String>) -> Self {
        Self {
            inner: self.inner.path_prefix(prefix),
        }
    }

    /// Override the TCP connect timeout (default: 5 s).
    pub fn connect_timeout_ms(self, ms: u64) -> Self {
        Self {
            inner: self.inner.connect_timeout_ms(ms),
        }
    }

    /// Override the response read timeout (default: 30 s).
    pub fn read_timeout_ms(self, ms: u64) -> Self {
        Self {
            inner: self.inner.read_timeout_ms(ms),
        }
    }
}
840
#[cfg(feature = "http2")]
impl crate::middleware::Middleware for H2ReverseProxy {
    /// Forward `request` to one of the configured backends over HTTP/2.
    ///
    /// * Requests whose URI does not start with the configured path prefix
    ///   (when one is set) are passed to `next` unchanged.
    /// * Otherwise the backends are tried in order, starting at the position
    ///   given by the shared round-robin counter; the first successful
    ///   response is returned.
    /// * If there are no backends, or every backend fails, the result of
    ///   `bad_gateway()` is returned.
    ///
    /// This method itself never returns `Err`; an `Err` can only come from
    /// `next.execute` on the pass-through path.
    fn handle(
        &self,
        request: &crate::request::Request,
        connection: &crate::server::ConnectionInfo,
        next: &dyn crate::application::Application,
    ) -> Result<crate::response::Response, String> {
        let proxy = &self.inner;

        if let Some(prefix) = &proxy.path_prefix {
            if !request.request_uri.starts_with(prefix.as_str()) {
                return next.execute(request, connection);
            }
        }

        let n = proxy.backends.len();
        if n == 0 {
            return Ok(bad_gateway());
        }

        // Reduce the counter modulo `n` *before* adding the attempt offset:
        // the counter wraps at `usize::MAX`, and adding to the raw value could
        // overflow (a panic in debug builds) right before the wrap-around.
        let first = proxy.counter.fetch_add(1, Ordering::Relaxed) % n;

        for attempt in 0..n {
            let backend = &proxy.backends[(first + attempt) % n];
            // A failed backend is simply skipped; the loop ends on its own
            // once every backend has been tried.
            if let Ok(resp) = try_backend_h2(
                request,
                &connection.client.ip,
                backend,
                proxy.connect_timeout,
                proxy.read_timeout,
            ) {
                return Ok(resp);
            }
        }

        Ok(bad_gateway())
    }
}
871
/// Synchronously forward `request` to a single HTTP/2 `backend`.
///
/// Bridges the blocking middleware API to the async h2 client by running
/// [`forward_h2_async`] on the current tokio runtime.
///
/// * `connect_timeout` — passed on to [`forward_h2_async`].
/// * `read_timeout` — together with `connect_timeout` it forms the overall
///   deadline for the whole exchange, so a backend that accepts the
///   connection but never answers cannot block the calling thread forever.
///
/// # Errors
///
/// Returns `Err` when no tokio runtime is active on the calling thread, when
/// the overall deadline elapses, or when [`forward_h2_async`] fails.
///
/// NOTE(review): tokio documents that `block_in_place` panics when called from
/// a `current_thread` runtime — confirm the server always runs this on a
/// multi-threaded runtime.
#[cfg(feature = "http2")]
fn try_backend_h2(
    request: &Request,
    client_ip: &str,
    backend: &Backend,
    connect_timeout: Duration,
    read_timeout: Duration,
) -> Result<Response, String> {
    use tokio::runtime::Handle;

    let handle = Handle::try_current()
        .map_err(|_| "no async runtime for H2 proxy; falling back to 502".to_string())?;

    // Saturating so that very large configured timeouts cannot overflow.
    let budget = connect_timeout.saturating_add(read_timeout);

    tokio::task::block_in_place(|| {
        handle.block_on(async {
            let forward = forward_h2_async(request, client_ip, backend, connect_timeout);
            match tokio::time::timeout(budget, forward).await {
                Ok(result) => result,
                Err(_) => Err(format!(
                    "h2 proxy: request to {}:{} timed out",
                    backend.host, backend.port
                )),
            }
        })
    })
}
890
/// Open a connection to `backend` and forward `request` over it via HTTP/2.
///
/// Steps:
/// 1. TCP connect to `host:port`, bounded by `connect_timeout`.
/// 2. For TLS backends, perform a rustls handshake that advertises `h2` via
///    ALPN and verifies the certificate against the `webpki_roots` trust
///    anchors.  The handshake is bounded by whatever is left of
///    `connect_timeout` after step 1, so a peer that accepts the TCP
///    connection but stalls the handshake cannot hang the request.
/// 3. Hand the (plain or TLS) stream to [`send_h2_request`].
///
/// # Errors
///
/// Returns a descriptive `Err(String)` on connect failure/timeout, an invalid
/// upstream hostname, TLS handshake failure/timeout, or any error reported by
/// [`send_h2_request`].
#[cfg(feature = "http2")]
async fn forward_h2_async(
    request: &Request,
    client_ip: &str,
    backend: &Backend,
    connect_timeout: Duration,
) -> Result<Response, String> {
    let addr = format!("{}:{}", backend.host, backend.port);
    // Start of the connection-establishment budget shared by TCP + TLS.
    let started = tokio::time::Instant::now();

    let tcp = tokio::time::timeout(
        connect_timeout,
        tokio::net::TcpStream::connect(&addr),
    )
    .await
    .map_err(|_| format!("h2 proxy: connect to {} timed out", addr))?
    .map_err(|e| format!("h2 proxy: connect to {} failed: {}", addr, e))?;

    if backend.tls {
        use rustls::pki_types::ServerName;
        use rustls::ClientConfig;
        use std::sync::Arc;
        use tokio_rustls::TlsConnector;

        let root_store = rustls::RootCertStore::from_iter(
            webpki_roots::TLS_SERVER_ROOTS.iter().cloned(),
        );
        let mut config = ClientConfig::builder()
            .with_root_certificates(root_store)
            .with_no_client_auth();
        // Advertise h2 via ALPN so the server selects HTTP/2.
        config.alpn_protocols = vec![b"h2".to_vec()];
        let connector = TlsConnector::from(Arc::new(config));
        let server_name = ServerName::try_from(backend.host.as_str())
            .map_err(|e| format!("invalid upstream hostname '{}': {}", backend.host, e))?
            .to_owned();

        // Spend only what is left of the connect budget on the handshake.
        let remaining = connect_timeout.saturating_sub(started.elapsed());
        let tls_stream = tokio::time::timeout(remaining, connector.connect(server_name, tcp))
            .await
            .map_err(|_| format!("h2 proxy: TLS handshake with {} timed out", addr))?
            .map_err(|e| format!("h2 proxy: TLS handshake with {} failed: {}", addr, e))?;
        send_h2_request(request, client_ip, backend, tls_stream).await
    } else {
        send_h2_request(request, client_ip, backend, tcp).await
    }
}
934
/// Drive the h2 client handshake + request/response over any async I/O stream.
///
/// Accepts both plain `TcpStream` and `TlsStream<TcpStream>` — anything that
/// satisfies `AsyncRead + AsyncWrite + Unpin + Send + 'static`.
///
/// Request side:
/// * the target URI is `scheme://host:port` + the original request URI;
/// * hop-by-hop headers and the client's `Host` header are dropped, except
///   `TE: trailers`, which HTTP/2 explicitly permits and gRPC clients send;
/// * `x-forwarded-for` (client IP) and `via` are appended;
/// * the whole request body is sent as a single DATA frame.
///
/// Response side:
/// * the body is buffered completely, releasing flow-control capacity after
///   every chunk so that bodies larger than the receive window do not stall;
/// * connection-level headers are dropped, all other headers whose value is
///   valid visible ASCII are copied onto the returned [`Response`].
///
/// NOTE(review): upstream trailers (e.g. gRPC's `grpc-status`) are not read
/// here and therefore not forwarded — confirm whether callers need them.
///
/// # Errors
///
/// Returns `Err(String)` when the h2 handshake fails, the URI or method cannot
/// be represented, a header is rejected by the request builder, or the stream
/// reports an error while sending or receiving.
#[cfg(feature = "http2")]
async fn send_h2_request<T>(
    request: &Request,
    client_ip: &str,
    backend: &Backend,
    stream: T,
) -> Result<Response, String>
where
    T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static,
{
    use bytes::Bytes;
    use http as hc;

    let addr = format!("{}:{}", backend.host, backend.port);

    let (send_req, conn) = h2::client::handshake(stream)
        .await
        .map_err(|e| format!("h2 proxy: handshake with {} failed: {}", addr, e))?;

    // The connection future has to be polled for any stream on it to make
    // progress, so it runs as its own task; its final result is not needed.
    tokio::spawn(async move {
        let _ = conn.await;
    });

    let scheme = if backend.tls { "https" } else { "http" };
    let uri_str = format!("{}://{}{}", scheme, addr, request.request_uri);
    let uri: hc::Uri = uri_str.parse().map_err(|e: hc::uri::InvalidUri| e.to_string())?;
    let method = hc::Method::from_bytes(request.method.as_bytes()).map_err(|e| e.to_string())?;

    let mut builder = hc::Request::builder().method(method).uri(uri);
    builder = builder.header("host", &backend.host);
    for h in &request.headers {
        let lower = h.name.to_lowercase();
        if lower == "te" {
            // HTTP/2 allows TE only with the value "trailers"; keep exactly
            // that form and drop anything else.
            if h.value.trim().eq_ignore_ascii_case("trailers") {
                builder = builder.header("te", "trailers");
            }
            continue;
        }
        if HOP_BY_HOP.contains(&lower.as_str()) || lower == "host" {
            continue;
        }
        builder = builder.header(&h.name, &h.value);
    }
    builder = builder.header("x-forwarded-for", client_ip);
    builder = builder.header("via", "2 rws");

    let body_bytes = Bytes::from(request.body.clone());
    // With no body the HEADERS frame itself closes the request stream.
    let end_of_stream = body_bytes.is_empty();
    let http_req = builder.body(()).map_err(|e| e.to_string())?;

    let mut send_req = send_req.ready().await.map_err(|e| e.to_string())?;
    let (resp_future, mut req_body) = send_req
        .send_request(http_req, end_of_stream)
        .map_err(|e| e.to_string())?;
    if !end_of_stream {
        req_body.send_data(body_bytes, true).map_err(|e| e.to_string())?;
    }

    let resp = resp_future.await.map_err(|e| e.to_string())?;
    let (parts, mut body) = resp.into_parts();

    let content_type = parts
        .headers
        .get("content-type")
        .and_then(|v| v.to_str().ok())
        .unwrap_or("application/octet-stream")
        .to_string();

    let mut body_bytes: Vec<u8> = Vec::new();
    while let Some(chunk) = body.data().await {
        let chunk = chunk.map_err(|e| e.to_string())?;
        body_bytes.extend_from_slice(&chunk);
        // Hand the consumed bytes back to h2's flow control; otherwise the
        // receive window is never replenished and the peer stops sending once
        // it is exhausted.  This can only fail when releasing more than was
        // received, which cannot happen here, so the result is ignored.
        let _ = body.flow_control().release_capacity(chunk.len());
    }

    let mut response = Response::new();
    response.status_code = parts.status.as_u16() as i16;
    response.reason_phrase = parts.status.canonical_reason().unwrap_or("").to_string();

    // Connection-level headers that must not be copied onto the response.
    const H2_HOP: &[&str] = &[
        "connection", "keep-alive", "transfer-encoding", "upgrade", "proxy-connection", "te",
    ];
    for (name, value) in &parts.headers {
        let lower = name.as_str().to_lowercase();
        if H2_HOP.contains(&lower.as_str()) {
            continue;
        }
        // Header values that are not visible ASCII are skipped.
        if let Ok(v) = value.to_str() {
            response.headers.push(crate::header::Header {
                name: name.as_str().to_string(),
                value: v.to_string(),
            });
        }
    }

    if !body_bytes.is_empty() {
        response.content_range_list = vec![Range::get_content_range(body_bytes, content_type)];
    }

    Ok(response)
}
1032
1033// ── gRPC proxy ────────────────────────────────────────────────────────────────
1034
/// gRPC reverse proxy middleware.
///
/// Recognises requests with `Content-Type: application/grpc*` and forwards them
/// to a backend over HTTP/2, leaving all other requests to the next layer.
/// All forwarding is delegated to a wrapped [`H2ReverseProxy`].
///
/// Requires the `http2` Cargo feature.
///
/// # Example
///
/// ```rust,no_run
/// use rust_web_server::app::App;
/// use rust_web_server::core::New;
/// use rust_web_server::proxy::GrpcProxy;
///
/// let app = App::new()
///     .wrap(GrpcProxy::new(["grpc-service:50051"]));
/// ```
#[cfg(feature = "http2")]
pub struct GrpcProxy {
    /// HTTP/2 proxy that performs the actual forwarding.
    inner: H2ReverseProxy,
}

#[cfg(feature = "http2")]
impl GrpcProxy {
    /// Create a proxy distributing gRPC connections across `backends` in round-robin order.
    ///
    /// Each backend entry can be:
    /// - `"host:port"` — plain TCP (gRPC cleartext)
    /// - `"grpc://host:port"` — plain TCP (explicit scheme)
    /// - `"grpcs://host:port"` — TLS (gRPC over TLS; port defaults to 443)
    /// - `"https://host:port"` — TLS (same as `grpcs://`)
    ///
    /// TLS backends require the `http2` Cargo feature.
    pub fn new<I, S>(backends: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: AsRef<str>,
    {
        GrpcProxy { inner: H2ReverseProxy::new(backends) }
    }

    /// Only proxy requests whose URI starts with `prefix`; pass others through.
    pub fn path_prefix(mut self, prefix: impl Into<String>) -> Self {
        self.inner = self.inner.path_prefix(prefix);
        self
    }

    /// Override the TCP connect timeout, in milliseconds.
    ///
    /// Delegates to [`H2ReverseProxy::connect_timeout_ms`].
    pub fn connect_timeout_ms(mut self, ms: u64) -> Self {
        self.inner = self.inner.connect_timeout_ms(ms);
        self
    }

    /// Override the response read timeout, in milliseconds.
    ///
    /// Delegates to [`H2ReverseProxy::read_timeout_ms`].
    pub fn read_timeout_ms(mut self, ms: u64) -> Self {
        self.inner = self.inner.read_timeout_ms(ms);
        self
    }
}
1082
#[cfg(feature = "http2")]
impl crate::middleware::Middleware for GrpcProxy {
    /// Route gRPC requests to the wrapped HTTP/2 proxy and everything else to `next`.
    ///
    /// A request counts as gRPC when its `Content-Type` value begins with
    /// `application/grpc` (so `application/grpc+proto` etc. match as well).
    /// Media types are case-insensitive, so the comparison ignores ASCII case
    /// and leading whitespace.  A missing `Content-Type` header is treated as
    /// "not gRPC".
    fn handle(
        &self,
        request: &crate::request::Request,
        connection: &crate::server::ConnectionInfo,
        next: &dyn crate::application::Application,
    ) -> Result<crate::response::Response, String> {
        const GRPC_MEDIA_TYPE: &str = "application/grpc";

        let content_type = request
            .get_header("content-type".to_string())
            .map(|h| h.value.as_str())
            .unwrap_or("");

        // `get(..len)` yields `None` for values that are too short (or would be
        // cut inside a multi-byte character), which simply means "no match".
        let is_grpc = content_type
            .trim_start()
            .get(..GRPC_MEDIA_TYPE.len())
            .map_or(false, |prefix| prefix.eq_ignore_ascii_case(GRPC_MEDIA_TYPE));

        if is_grpc {
            self.inner.handle(request, connection, next)
        } else {
            next.execute(request, connection)
        }
    }
}
1102
1103// ── Backend::parse unit tests ─────────────────────────────────────────────────
1104
#[cfg(test)]
mod backend_parse_tests {
    use super::Backend;

    /// Flatten a parsed backend into `(host, port, tls)` so expectations can
    /// be written as plain tuples.
    fn parse(url: &str) -> Option<(String, u16, bool)> {
        let backend = Backend::parse(url)?;
        Some((backend.host, backend.port, backend.tls))
    }

    /// Assert that `url` parses to exactly the given host, port and TLS flag.
    #[track_caller]
    fn expect_backend(url: &str, host: &str, port: u16, tls: bool) {
        assert_eq!(Some((host.to_string(), port, tls)), parse(url));
    }

    // ── plain-text backends ──────────────────────────────────────────────────

    #[test]
    fn bare_host_port() {
        expect_backend("api.example.com:8080", "api.example.com", 8080, false);
    }

    #[test]
    fn http_scheme() {
        expect_backend("http://backend:3000", "backend", 3000, false);
    }

    #[test]
    fn h2_scheme_plain() {
        expect_backend("h2://svc:50051", "svc", 50051, false);
    }

    #[test]
    fn grpc_scheme_plain() {
        expect_backend("grpc://svc:50051", "svc", 50051, false);
    }

    // ── TLS backends ─────────────────────────────────────────────────────────

    #[test]
    fn https_scheme_sets_tls_and_default_port() {
        expect_backend("https://api.example.com", "api.example.com", 443, true);
    }

    #[test]
    fn https_scheme_explicit_port() {
        expect_backend("https://api.example.com:8443", "api.example.com", 8443, true);
    }

    #[test]
    fn h2s_scheme_sets_tls() {
        expect_backend("h2s://svc", "svc", 443, true);
    }

    #[test]
    fn h2s_scheme_explicit_port() {
        expect_backend("h2s://svc:8443", "svc", 8443, true);
    }

    #[test]
    fn grpcs_scheme_sets_tls() {
        expect_backend("grpcs://grpc-svc", "grpc-svc", 443, true);
    }

    #[test]
    fn grpcs_scheme_explicit_port() {
        expect_backend("grpcs://grpc-svc:50052", "grpc-svc", 50052, true);
    }

    // ── edge cases ───────────────────────────────────────────────────────────

    #[test]
    fn empty_host_returns_none() {
        assert_eq!(None, parse("https://"));
    }

    #[test]
    fn bare_host_no_port_defaults_to_80() {
        expect_backend("myhost", "myhost", 80, false);
    }
}