Skip to main content

rust_web_server/proxy/
mod.rs

1//! Reverse proxy middleware with round-robin load balancing.
2//!
3//! `ReverseProxy` implements [`Middleware`] — wrap any application with it and
4//! all matching requests are forwarded to one of the configured backends over
5//! plain HTTP/1.1.  Failed backends are skipped and the next one is tried
6//! before returning `502 Bad Gateway`.
7//!
8//! # Example
9//!
10//! ```rust,no_run
11//! use rust_web_server::app::App;
12//! use rust_web_server::core::New;
13//! use rust_web_server::proxy::{LoadBalancing, ReverseProxy};
14//!
15//! // Proxy every request across two backends in round-robin order.
16//! let app = App::new()
17//!     .wrap(ReverseProxy::new(["http://backend-1:8080", "http://backend-2:8080"])
18//!         .strategy(LoadBalancing::RoundRobin));
19//!
20//! // Only proxy /api/* requests; everything else is handled locally.
21//! let app2 = App::new()
22//!     .wrap(ReverseProxy::new(["http://api-service:3000"])
23//!         .path_prefix("/api"));
24//! ```
25
26#[cfg(test)]
27mod tests;
28
29use std::io::{Read, Write};
30use std::net::{TcpStream, ToSocketAddrs};
31use std::sync::atomic::{AtomicUsize, Ordering};
32use std::time::Duration;
33
34
35use crate::application::Application;
36use crate::core::New;
37use crate::middleware::Middleware;
38use crate::mime_type::MimeType;
39use crate::range::Range;
40use crate::request::Request;
41use crate::response::{Response, STATUS_CODE_REASON_PHRASE};
42use crate::server::ConnectionInfo;
43
// Lower-case names of hop-by-hop headers that must not be forwarded
// (RFC 7230 §6.1).  Lookups lower-case the header name before comparing, so
// every entry here must be lower-case.
//
// `trailer` is the actual header field name; `trailers` (the spelling used in
// the RFC 2616 list) is kept so nothing that was stripped before starts being
// forwarded.  `proxy-connection` is non-standard but widely sent by clients
// and is connection-specific as well.
const HOP_BY_HOP: &[&str] = &[
    "connection",
    "keep-alive",
    "proxy-authenticate",
    "proxy-authorization",
    "proxy-connection",
    "te",
    "trailer",
    "trailers",
    "transfer-encoding",
    "upgrade",
];
55
/// Load balancing strategy used by [`ReverseProxy`].
///
/// The type is a plain, field-less enum, so it is cheap to copy, compare and
/// print.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LoadBalancing {
    /// Distribute requests across backends in a cyclic order.
    RoundRobin,
}
61
62/// Reverse proxy middleware.
63///
64/// Forwards incoming requests to one of the configured backends over HTTP/1.1.
65/// On connection failure the next backend in the list is tried; when all
66/// backends have failed the middleware returns `502 Bad Gateway`.
67///
68/// Hop-by-hop headers are stripped before forwarding.  `X-Forwarded-For` and
69/// `Via` are added to every forwarded request.
70///
71/// # Limitations
72///
73/// * Only plain HTTP backends are supported (no TLS to the upstream).
74/// * Chunked transfer encoding from the backend is forwarded as-is; callers
75///   that need decoded bodies should set `Content-Length` on the upstream.
76pub struct ReverseProxy {
77    backends: Vec<Backend>,
78    path_prefix: Option<String>,
79    connect_timeout: Duration,
80    read_timeout: Duration,
81    counter: AtomicUsize,
82}
83
84impl ReverseProxy {
85    /// Create a proxy that distributes requests across `backends` in
86    /// round-robin order.  Each entry must be `"http://host:port"` or
87    /// `"host:port"` (port defaults to 80).
88    pub fn new<I, S>(backends: I) -> Self
89    where
90        I: IntoIterator<Item = S>,
91        S: AsRef<str>,
92    {
93        Self {
94            backends: backends
95                .into_iter()
96                .filter_map(|u| Backend::parse(u.as_ref()))
97                .collect(),
98            path_prefix: None,
99            connect_timeout: Duration::from_secs(5),
100            read_timeout: Duration::from_secs(30),
101            counter: AtomicUsize::new(0),
102        }
103    }
104
105    /// Only proxy requests whose URI starts with `prefix`.
106    ///
107    /// Other requests are passed through to the next layer in the middleware
108    /// chain (or the inner application).
109    pub fn path_prefix(mut self, prefix: impl Into<String>) -> Self {
110        self.path_prefix = Some(prefix.into());
111        self
112    }
113
114    /// Override the load balancing strategy (currently only `RoundRobin`).
115    pub fn strategy(self, _strategy: LoadBalancing) -> Self {
116        self
117    }
118
119    /// Override the TCP connect timeout (default: 5 000 ms).
120    pub fn connect_timeout_ms(mut self, ms: u64) -> Self {
121        self.connect_timeout = Duration::from_millis(ms);
122        self
123    }
124
125    /// Override the response read timeout (default: 30 000 ms).
126    pub fn read_timeout_ms(mut self, ms: u64) -> Self {
127        self.read_timeout = Duration::from_millis(ms);
128        self
129    }
130
131    fn proxy(&self, request: &Request, connection: &ConnectionInfo) -> Result<Response, String> {
132        if self.backends.is_empty() {
133            return Err("no backends configured".to_string());
134        }
135        let n = self.backends.len();
136        let start = self.counter.fetch_add(1, Ordering::Relaxed);
137        for attempt in 0..n {
138            let idx = (start + attempt) % n;
139            match self.try_backend(request, connection, &self.backends[idx]) {
140                Ok(resp) => return Ok(resp),
141                Err(_) if attempt + 1 < n => continue,
142                Err(e) => return Err(e),
143            }
144        }
145        Err("all backends failed".to_string())
146    }
147
148    fn try_backend(
149        &self,
150        request: &Request,
151        connection: &ConnectionInfo,
152        backend: &Backend,
153    ) -> Result<Response, String> {
154        let addr_str = format!("{}:{}", backend.host, backend.port);
155        let sock_addr = addr_str
156            .to_socket_addrs()
157            .map_err(|e| format!("DNS lookup for {} failed: {}", addr_str, e))?
158            .next()
159            .ok_or_else(|| format!("no address resolved for {}", addr_str))?;
160
161        let stream = TcpStream::connect_timeout(&sock_addr, self.connect_timeout)
162            .map_err(|e| format!("connect to {} failed: {}", addr_str, e))?;
163        stream
164            .set_read_timeout(Some(self.read_timeout))
165            .map_err(|e| e.to_string())?;
166        stream
167            .set_write_timeout(Some(Duration::from_secs(10)))
168            .map_err(|e| e.to_string())?;
169
170        let req_bytes = build_request(request, &backend.host, &connection.client.ip);
171        let mut stream = stream;
172        stream
173            .write_all(&req_bytes)
174            .map_err(|e| format!("write to backend failed: {}", e))?;
175
176        let resp_bytes = read_response(&mut stream)?;
177        Response::parse(&resp_bytes)
178    }
179}
180
181impl Middleware for ReverseProxy {
182    fn handle(
183        &self,
184        request: &Request,
185        connection: &ConnectionInfo,
186        next: &dyn Application,
187    ) -> Result<Response, String> {
188        if let Some(prefix) = &self.path_prefix {
189            if !request.request_uri.starts_with(prefix.as_str()) {
190                return next.execute(request, connection);
191            }
192        }
193        match self.proxy(request, connection) {
194            Ok(resp) => Ok(resp),
195            Err(_) => Ok(bad_gateway()),
196        }
197    }
198}
199
200// ── helpers ───────────────────────────────────────────────────────────────────
201
202fn build_request(request: &Request, backend_host: &str, client_ip: &str) -> Vec<u8> {
203    let mut out: Vec<u8> = Vec::new();
204    let _ = write!(
205        out,
206        "{} {} HTTP/1.1\r\nHost: {}\r\n",
207        request.method, request.request_uri, backend_host
208    );
209    for h in &request.headers {
210        let lower = h.name.to_lowercase();
211        if HOP_BY_HOP.contains(&lower.as_str()) || lower == "host" {
212            continue;
213        }
214        let _ = write!(out, "{}: {}\r\n", h.name, h.value);
215    }
216    let _ = write!(out, "X-Forwarded-For: {}\r\n", client_ip);
217    let _ = write!(out, "Via: 1.1 rws\r\n");
218    let _ = write!(out, "Connection: close\r\n");
219    if !request.body.is_empty() {
220        let _ = write!(out, "Content-Length: {}\r\n", request.body.len());
221    }
222    let _ = write!(out, "\r\n");
223    out.extend_from_slice(&request.body);
224    out
225}
226
/// Read one complete HTTP/1.x response from `stream` and return its raw bytes
/// (status line, headers and body).
///
/// The header block is read until the first `\r\n\r\n`.  If the headers carry
/// a parsable `Content-Length`, reading continues until that many body bytes
/// have arrived or the peer closes the connection; without one, everything up
/// to end-of-stream is read.
///
/// If the stream ends before the header terminator is seen, whatever was
/// received is returned as-is.
///
/// # Errors
///
/// Returns an error when a read fails, or when the stream ends before a
/// single byte was received.
fn read_response<R: Read>(stream: &mut R) -> Result<Vec<u8>, String> {
    const HEADER_TERMINATOR: &[u8] = b"\r\n\r\n";

    let mut buf: Vec<u8> = Vec::with_capacity(8192);
    let mut tmp = [0u8; 4096];

    // Read until the header block ends (\r\n\r\n)
    let header_end = loop {
        let n = stream.read(&mut tmp).map_err(|e| e.to_string())?;
        if n == 0 {
            return if buf.is_empty() {
                Err("backend closed connection without sending a response".to_string())
            } else {
                Ok(buf)
            };
        }
        // A terminator that was not found on the previous pass must include
        // at least one new byte, so it can start no earlier than three bytes
        // before the old end of the buffer.  Resuming there avoids rescanning
        // the whole buffer after every read.
        let scan_from = buf.len().saturating_sub(HEADER_TERMINATOR.len() - 1);
        buf.extend_from_slice(&tmp[..n]);
        if let Some(pos) = buf[scan_from..]
            .windows(HEADER_TERMINATOR.len())
            .position(|w| w == HEADER_TERMINATOR)
        {
            break scan_from + pos + HEADER_TERMINATOR.len();
        }
    };

    // Parse Content-Length from headers.  A header block that is not valid
    // UTF-8 is treated as having no Content-Length.
    let content_length = std::str::from_utf8(&buf[..header_end])
        .unwrap_or("")
        .lines()
        .find_map(|line| {
            let (name, value) = line.split_once(':')?;
            if name.eq_ignore_ascii_case("content-length") {
                value.trim().parse::<usize>().ok()
            } else {
                None
            }
        });

    // Saturate instead of overflowing on an absurdly large Content-Length;
    // the loop then simply runs until the peer closes the connection.
    let body_end = content_length.map(|len| header_end.saturating_add(len));
    while body_end.map_or(true, |end| buf.len() < end) {
        let n = stream.read(&mut tmp).map_err(|e| e.to_string())?;
        if n == 0 {
            break;
        }
        buf.extend_from_slice(&tmp[..n]);
    }

    Ok(buf)
}
279
280fn bad_gateway() -> Response {
281    let cr = Range::get_content_range(
282        b"502 Bad Gateway".to_vec(),
283        MimeType::TEXT_PLAIN.to_string(),
284    );
285    let mut r = Response::new();
286    r.status_code = *STATUS_CODE_REASON_PHRASE.n502_bad_gateway.status_code;
287    r.reason_phrase = STATUS_CODE_REASON_PHRASE
288        .n502_bad_gateway
289        .reason_phrase
290        .to_string();
291    r.content_range_list = vec![cr];
292    r
293}
294
295// ── Backend URL parsing ───────────────────────────────────────────────────────
296
/// Address of one upstream server.
struct Backend {
    /// Host name or IP literal (an IPv6 literal keeps its square brackets).
    host: String,
    /// TCP port to connect to.
    port: u16,
}

impl Backend {
    /// Port used when the backend string does not name one.
    const DEFAULT_PORT: u16 = 80;

    /// Parse a backend specification such as `"http://host:8080"`,
    /// `"h2://host:9090"` or `"host:8080"`.
    ///
    /// * Surrounding whitespace is ignored.
    /// * A leading `https://`, `http://` or `h2://` is stripped; the scheme
    ///   has no influence on the result (in particular `https://` does not
    ///   change the default port).
    /// * Anything from the first `/`, `?` or `#` onwards is discarded.
    /// * A missing or empty port (`"host"`, `"host:"`) defaults to 80.
    /// * A bracketed IPv6 literal may be given with or without a port
    ///   (`"[::1]"`, `"[::1]:8080"`).
    ///
    /// Returns `None` when the host is empty or when a port is present but is
    /// not a valid `u16`.
    fn parse(url: &str) -> Option<Self> {
        let url = url.trim();
        let rest = url
            .strip_prefix("https://")
            .or_else(|| url.strip_prefix("http://"))
            .or_else(|| url.strip_prefix("h2://"))
            .unwrap_or(url);
        // Keep only the authority: drop any path, query or fragment.
        let authority = rest
            .split(|c| matches!(c, '/' | '?' | '#'))
            .next()
            .unwrap_or(rest);

        let (host, port) = match authority.rsplit_once(':') {
            // The last colon sits inside a bracketed IPv6 literal, so it is
            // part of the address rather than a port separator.
            Some((_, tail)) if tail.ends_with(']') => (authority, Self::DEFAULT_PORT),
            // "host:" — the separator is present but the port is empty.
            Some((host, "")) => (host, Self::DEFAULT_PORT),
            // An explicit port must be valid; folding a bad one into the host
            // name would only yield an address that can never resolve.
            Some((host, port)) => (host, port.parse::<u16>().ok()?),
            None => (authority, Self::DEFAULT_PORT),
        };

        if host.is_empty() {
            return None;
        }
        Some(Backend {
            host: host.to_string(),
            port,
        })
    }
}
327
328// ── HTTP/2 reverse proxy ──────────────────────────────────────────────────────
329
/// Reverse proxy middleware for HTTP/2 backends.
///
/// Holds a [`ReverseProxy`] for its configuration (backend list, path prefix,
/// timeouts and round-robin counter) but talks HTTP/2 (`h2`) to every
/// upstream.  Only available with the `http2` Cargo feature.
///
/// gRPC traffic (`Content-Type: application/grpc*`) passes through
/// transparently — gRPC DATA frames are forwarded as-is because gRPC is
/// HTTP/2.  HTTP/2 trailers (which gRPC uses for `grpc-status` and
/// `grpc-message`) are not yet propagated.
///
/// # Example
///
/// ```rust,no_run
/// use rust_web_server::app::App;
/// use rust_web_server::core::New;
/// use rust_web_server::proxy::H2ReverseProxy;
///
/// let app = App::new()
///     .wrap(H2ReverseProxy::new(["grpc-service:9090"])
///         .path_prefix("/svc.MyService"));
/// ```
#[cfg(feature = "http2")]
pub struct H2ReverseProxy {
    /// Configuration carrier; its HTTP/1.1 forwarding path is not used here.
    inner: ReverseProxy,
}

#[cfg(feature = "http2")]
impl H2ReverseProxy {
    /// Build a proxy that rotates over `backends` in round-robin order.
    /// Each entry must be `"host:port"` or `"h2://host:port"`.
    pub fn new<I, S>(backends: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: AsRef<str>,
    {
        Self {
            inner: ReverseProxy::new(backends),
        }
    }

    /// Restrict proxying to requests whose URI starts with `prefix`; all
    /// other requests are passed through.
    pub fn path_prefix(self, prefix: impl Into<String>) -> Self {
        Self {
            inner: self.inner.path_prefix(prefix),
        }
    }

    /// Set the TCP connect timeout in milliseconds (default: 5 s).
    pub fn connect_timeout_ms(self, ms: u64) -> Self {
        Self {
            inner: self.inner.connect_timeout_ms(ms),
        }
    }

    /// Set the response read timeout in milliseconds (default: 30 s).
    pub fn read_timeout_ms(self, ms: u64) -> Self {
        Self {
            inner: self.inner.read_timeout_ms(ms),
        }
    }
}
388
#[cfg(feature = "http2")]
impl crate::middleware::Middleware for H2ReverseProxy {
    /// Proxy the request over HTTP/2, or hand it to `next` when a path prefix
    /// is configured and the request URI does not start with it.
    ///
    /// Backends are tried in round-robin order starting at the next position
    /// of the shared counter.  If the backend list is empty or every attempt
    /// fails, the response produced by `bad_gateway()` is returned; this
    /// method itself never returns `Err` for a proxied request.
    fn handle(
        &self,
        request: &crate::request::Request,
        connection: &crate::server::ConnectionInfo,
        next: &dyn crate::application::Application,
    ) -> Result<crate::response::Response, String> {
        let proxy = &self.inner;
        if let Some(prefix) = &proxy.path_prefix {
            if !request.request_uri.starts_with(prefix.as_str()) {
                return next.execute(request, connection);
            }
        }

        let n = proxy.backends.len();
        if n == 0 {
            return Ok(bad_gateway());
        }
        // Reduce the counter modulo `n` before adding the attempt offset:
        // `fetch_add` wraps silently, but adding to the raw counter value
        // would overflow once it gets close to `usize::MAX`.
        let first = proxy.counter.fetch_add(1, Ordering::Relaxed) % n;
        for attempt in 0..n {
            let backend = &proxy.backends[(first + attempt) % n];
            let outcome = try_backend_h2(
                request,
                &connection.client.ip,
                backend,
                proxy.connect_timeout,
                proxy.read_timeout,
            );
            if let Ok(resp) = outcome {
                return Ok(resp);
            }
        }
        Ok(bad_gateway())
    }
}
419
/// Synchronously forward `request` to one HTTP/2 backend.
///
/// Bridges the synchronous middleware API to the async `forward_h2_async` by
/// blocking on the current tokio runtime.  The whole exchange (connect,
/// handshake, request and response) is bounded by
/// `connect_timeout + read_timeout`, so a backend that accepts the connection
/// but never answers cannot block the calling thread indefinitely.
///
/// NOTE(review): tokio documents that `block_in_place` panics when called
/// from a `current_thread` runtime — confirm the server always runs this on a
/// multi-threaded runtime.
///
/// # Errors
///
/// Returns an error when no tokio runtime is active on the calling thread
/// (there is no HTTP/1.1 fallback; the caller turns the error into a `502`),
/// when the deadline expires, or when `forward_h2_async` fails.
#[cfg(feature = "http2")]
fn try_backend_h2(
    request: &Request,
    client_ip: &str,
    backend: &Backend,
    connect_timeout: Duration,
    read_timeout: Duration,
) -> Result<Response, String> {
    use tokio::runtime::Handle;

    let handle = Handle::try_current()
        .map_err(|_| "no async runtime for H2 proxy; falling back to 502".to_string())?;
    let deadline = connect_timeout.saturating_add(read_timeout);

    tokio::task::block_in_place(|| {
        handle.block_on(async {
            let exchange = forward_h2_async(request, client_ip, backend, connect_timeout);
            match tokio::time::timeout(deadline, exchange).await {
                Ok(result) => result,
                Err(_) => Err(format!(
                    "h2 proxy: no response from {}:{} within {:?}",
                    backend.host, backend.port, deadline
                )),
            }
        })
    })
}
439
/// Forward `request` to `backend` over a fresh cleartext HTTP/2 connection
/// and collect the complete response.
///
/// Steps: connect (bounded by `connect_timeout`), perform the `h2` client
/// handshake, drive the connection on a spawned task, send the request
/// headers and body, then buffer the whole response body in memory.
///
/// Request headers listed in `HOP_BY_HOP` and the incoming `Host` are not
/// forwarded; `host`, `x-forwarded-for` and `via` are set by the proxy.
/// Connection-specific response headers are dropped, and response header
/// values that are not valid visible ASCII are skipped.
///
/// # Errors
///
/// Returns a description of the failure when connecting, the handshake,
/// building or sending the request, or receiving the response fails.
#[cfg(feature = "http2")]
async fn forward_h2_async(
    request: &Request,
    client_ip: &str,
    backend: &Backend,
    connect_timeout: Duration,
) -> Result<Response, String> {
    use bytes::Bytes;
    use http as hc;

    let addr = format!("{}:{}", backend.host, backend.port);

    let tcp = tokio::time::timeout(
        connect_timeout,
        tokio::net::TcpStream::connect(&addr),
    )
    .await
    .map_err(|_| format!("h2 proxy: connect to {} timed out", addr))?
    .map_err(|e| format!("h2 proxy: connect to {} failed: {}", addr, e))?;

    let (send_req, conn) = h2::client::handshake(tcp)
        .await
        .map_err(|e| format!("h2 proxy: handshake with {} failed: {}", addr, e))?;

    // The connection future must be polled for any stream to make progress;
    // its final result is not needed here.
    tokio::spawn(async move {
        let _ = conn.await;
    });

    let uri_str = format!("http://{}{}", addr, request.request_uri);
    let uri: hc::Uri = uri_str.parse().map_err(|e: hc::uri::InvalidUri| e.to_string())?;
    let method = hc::Method::from_bytes(request.method.as_bytes()).map_err(|e| e.to_string())?;

    let mut builder = hc::Request::builder().method(method).uri(uri);
    builder = builder.header("host", &backend.host);
    for h in &request.headers {
        let lower = h.name.to_lowercase();
        if HOP_BY_HOP.contains(&lower.as_str()) || lower == "host" {
            continue;
        }
        builder = builder.header(&h.name, &h.value);
    }
    builder = builder.header("x-forwarded-for", client_ip);
    builder = builder.header("via", "2 rws");

    let request_body = Bytes::from(request.body.clone());
    // With an empty body the HEADERS frame itself ends the stream.
    let end_of_stream = request_body.is_empty();
    let http_req = builder.body(()).map_err(|e| e.to_string())?;

    let mut send_req = send_req.ready().await.map_err(|e| e.to_string())?;
    let (resp_future, mut req_body) = send_req
        .send_request(http_req, end_of_stream)
        .map_err(|e| e.to_string())?;
    if !end_of_stream {
        req_body.send_data(request_body, true).map_err(|e| e.to_string())?;
    }

    let resp = resp_future.await.map_err(|e| e.to_string())?;
    let (parts, mut body) = resp.into_parts();

    let content_type = parts
        .headers
        .get("content-type")
        .and_then(|v| v.to_str().ok())
        .unwrap_or("application/octet-stream")
        .to_string();

    let mut response_body: Vec<u8> = Vec::new();
    while let Some(chunk) = body.data().await {
        let chunk = chunk.map_err(|e| e.to_string())?;
        response_body.extend_from_slice(&chunk);
        // Hand the consumed bytes back to the flow-control window.  Without
        // this the peer runs out of send capacity and a body larger than the
        // receive window never finishes arriving.
        let _ = body.flow_control().release_capacity(chunk.len());
    }

    let mut response = Response::new();
    response.status_code = parts.status.as_u16() as i16;
    response.reason_phrase = parts.status.canonical_reason().unwrap_or("").to_string();

    const H2_HOP: &[&str] = &["connection", "keep-alive", "transfer-encoding",
                                "upgrade", "proxy-connection", "te"];
    for (name, value) in &parts.headers {
        let lower = name.as_str().to_lowercase();
        if H2_HOP.contains(&lower.as_str()) { continue; }
        if let Ok(v) = value.to_str() {
            response.headers.push(crate::header::Header {
                name: name.as_str().to_string(),
                value: v.to_string(),
            });
        }
    }

    if !response_body.is_empty() {
        response.content_range_list = vec![Range::get_content_range(response_body, content_type)];
    }

    Ok(response)
}
534
535// ── gRPC proxy ────────────────────────────────────────────────────────────────
536
/// Reverse proxy middleware for gRPC.
///
/// Requests carrying `Content-Type: application/grpc*` are forwarded to a
/// backend over HTTP/2; every other request goes to the next layer.
///
/// Because gRPC is layered directly on HTTP/2, its DATA frames are forwarded
/// as-is.  HTTP/2 trailers (`grpc-status`, `grpc-message`) are not yet
/// propagated — a known limitation of the current implementation.
///
/// Requires the `http2` Cargo feature.
///
/// # Example
///
/// ```rust,no_run
/// use rust_web_server::app::App;
/// use rust_web_server::core::New;
/// use rust_web_server::proxy::GrpcProxy;
///
/// let app = App::new()
///     .wrap(GrpcProxy::new(["grpc-service:50051"]));
/// ```
#[cfg(feature = "http2")]
pub struct GrpcProxy {
    /// HTTP/2 proxy that performs the actual forwarding.
    inner: H2ReverseProxy,
}

#[cfg(feature = "http2")]
impl GrpcProxy {
    /// Build a proxy that spreads gRPC connections over `backends` in
    /// round-robin order.
    pub fn new<I, S>(backends: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: AsRef<str>,
    {
        Self {
            inner: H2ReverseProxy::new(backends),
        }
    }

    /// Restrict proxying to requests whose URI starts with `prefix`; all
    /// other requests are passed through.
    pub fn path_prefix(self, prefix: impl Into<String>) -> Self {
        Self {
            inner: self.inner.path_prefix(prefix),
        }
    }
}
580
#[cfg(feature = "http2")]
impl crate::middleware::Middleware for GrpcProxy {
    /// Forward the request through the inner HTTP/2 proxy when its
    /// `content-type` header value starts with `application/grpc`; otherwise
    /// (including when the header is absent) pass it to `next`.
    fn handle(
        &self,
        request: &crate::request::Request,
        connection: &crate::server::ConnectionInfo,
        next: &dyn crate::application::Application,
    ) -> Result<crate::response::Response, String> {
        let is_grpc = request
            .get_header("content-type".to_string())
            .map_or(false, |h| h.value.as_str().starts_with("application/grpc"));

        if !is_grpc {
            return next.execute(request, connection);
        }
        self.inner.handle(request, connection, next)
    }
}