Skip to main content

rust_web_server/proxy/
mod.rs

1//! Reverse proxy middleware with round-robin load balancing.
2//!
3//! `ReverseProxy` implements [`Middleware`] — wrap any application with it and
4//! all matching requests are forwarded to one of the configured backends over
5//! plain HTTP/1.1.  Failed backends are skipped and the next one is tried
6//! before returning `502 Bad Gateway`.
7//!
8//! # Example
9//!
10//! ```rust,no_run
11//! use rust_web_server::app::App;
12//! use rust_web_server::core::New;
13//! use rust_web_server::proxy::{LoadBalancing, ReverseProxy};
14//!
15//! // Proxy every request across two backends in round-robin order.
16//! let app = App::new()
17//!     .wrap(ReverseProxy::new(["http://backend-1:8080", "http://backend-2:8080"])
18//!         .strategy(LoadBalancing::RoundRobin));
19//!
20//! // Only proxy /api/* requests; everything else is handled locally.
21//! let app2 = App::new()
22//!     .wrap(ReverseProxy::new(["http://api-service:3000"])
23//!         .path_prefix("/api"));
24//! ```
25
26#[cfg(test)]
27mod tests;
28
29use std::io::{Read, Write};
30use std::net::{TcpStream, ToSocketAddrs};
31use std::sync::atomic::{AtomicUsize, Ordering};
32use std::time::Duration;
33
34
35use crate::application::Application;
36use crate::core::New;
37use crate::middleware::Middleware;
38use crate::mime_type::MimeType;
39use crate::range::Range;
40use crate::request::Request;
41use crate::response::{Response, STATUS_CODE_REASON_PHRASE};
42use crate::server::ConnectionInfo;
43
// Hop-by-hop headers that must not be forwarded (RFC 7230 §6.1).
//
// Every entry is lower-case: callers lower-case a header name before looking
// it up here.
//
// * `trailer` is the actual field name (RFC 7230 §4.4).  `trailers` — the
//   spelling used by the old RFC 2616 §13.5.1 list — is kept so that anything
//   filtered before is still filtered.
// * `proxy-connection` is a non-standard header that is nevertheless
//   connection-scoped and must not travel past a proxy (RFC 7230 Appendix A.1.2).
const HOP_BY_HOP: &[&str] = &[
    "connection",
    "keep-alive",
    "proxy-authenticate",
    "proxy-authorization",
    "proxy-connection",
    "te",
    "trailer",
    "trailers",
    "transfer-encoding",
    "upgrade",
];
55
/// Load balancing strategy used by [`ReverseProxy`].
///
/// The enum is a plain, field-less value, so it is cheap to copy, compare and
/// print in diagnostics.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LoadBalancing {
    /// Distribute requests across backends in a cyclic order.
    RoundRobin,
}
61
62/// Reverse proxy middleware.
63///
64/// Forwards incoming requests to one of the configured backends over HTTP/1.1.
65/// On connection failure the next backend in the list is tried; when all
66/// backends have failed the middleware returns `502 Bad Gateway`.
67///
68/// Hop-by-hop headers are stripped before forwarding.  `X-Forwarded-For` and
69/// `Via` are added to every forwarded request.
70///
71/// # Limitations
72///
73/// * Only plain HTTP backends are supported (no TLS to the upstream).
74/// * Chunked transfer encoding from the backend is forwarded as-is; callers
75///   that need decoded bodies should set `Content-Length` on the upstream.
76pub struct ReverseProxy {
77    backends: Vec<Backend>,
78    path_prefix: Option<String>,
79    connect_timeout: Duration,
80    read_timeout: Duration,
81    counter: AtomicUsize,
82}
83
84impl ReverseProxy {
85    /// Create a proxy that distributes requests across `backends` in
86    /// round-robin order.  Each entry must be `"http://host:port"` or
87    /// `"host:port"` (port defaults to 80).
88    pub fn new<I, S>(backends: I) -> Self
89    where
90        I: IntoIterator<Item = S>,
91        S: AsRef<str>,
92    {
93        Self {
94            backends: backends
95                .into_iter()
96                .filter_map(|u| Backend::parse(u.as_ref()))
97                .collect(),
98            path_prefix: None,
99            connect_timeout: Duration::from_secs(5),
100            read_timeout: Duration::from_secs(30),
101            counter: AtomicUsize::new(0),
102        }
103    }
104
105    /// Only proxy requests whose URI starts with `prefix`.
106    ///
107    /// Other requests are passed through to the next layer in the middleware
108    /// chain (or the inner application).
109    pub fn path_prefix(mut self, prefix: impl Into<String>) -> Self {
110        self.path_prefix = Some(prefix.into());
111        self
112    }
113
114    /// Override the load balancing strategy (currently only `RoundRobin`).
115    pub fn strategy(self, _strategy: LoadBalancing) -> Self {
116        self
117    }
118
119    /// Override the TCP connect timeout (default: 5 000 ms).
120    pub fn connect_timeout_ms(mut self, ms: u64) -> Self {
121        self.connect_timeout = Duration::from_millis(ms);
122        self
123    }
124
125    /// Override the response read timeout (default: 30 000 ms).
126    pub fn read_timeout_ms(mut self, ms: u64) -> Self {
127        self.read_timeout = Duration::from_millis(ms);
128        self
129    }
130
131    fn proxy(&self, request: &Request, connection: &ConnectionInfo) -> Result<Response, String> {
132        if self.backends.is_empty() {
133            return Err("no backends configured".to_string());
134        }
135        let n = self.backends.len();
136        let start = self.counter.fetch_add(1, Ordering::Relaxed);
137        for attempt in 0..n {
138            let idx = (start + attempt) % n;
139            match self.try_backend(request, connection, &self.backends[idx]) {
140                Ok(resp) => return Ok(resp),
141                Err(_) if attempt + 1 < n => continue,
142                Err(e) => return Err(e),
143            }
144        }
145        Err("all backends failed".to_string())
146    }
147
148    fn try_backend(
149        &self,
150        request: &Request,
151        connection: &ConnectionInfo,
152        backend: &Backend,
153    ) -> Result<Response, String> {
154        let addr_str = format!("{}:{}", backend.host, backend.port);
155        let sock_addr = addr_str
156            .to_socket_addrs()
157            .map_err(|e| format!("DNS lookup for {} failed: {}", addr_str, e))?
158            .next()
159            .ok_or_else(|| format!("no address resolved for {}", addr_str))?;
160
161        let stream = TcpStream::connect_timeout(&sock_addr, self.connect_timeout)
162            .map_err(|e| format!("connect to {} failed: {}", addr_str, e))?;
163        stream
164            .set_read_timeout(Some(self.read_timeout))
165            .map_err(|e| e.to_string())?;
166        stream
167            .set_write_timeout(Some(Duration::from_secs(10)))
168            .map_err(|e| e.to_string())?;
169
170        let req_bytes = build_request(request, &backend.host, &connection.client.ip);
171        let mut stream = stream;
172        stream
173            .write_all(&req_bytes)
174            .map_err(|e| format!("write to backend failed: {}", e))?;
175
176        let resp_bytes = read_response(&mut stream)?;
177        Response::parse(&resp_bytes)
178    }
179}
180
181impl Middleware for ReverseProxy {
182    fn handle(
183        &self,
184        request: &Request,
185        connection: &ConnectionInfo,
186        next: &dyn Application,
187    ) -> Result<Response, String> {
188        if let Some(prefix) = &self.path_prefix {
189            if !request.request_uri.starts_with(prefix.as_str()) {
190                return next.execute(request, connection);
191            }
192        }
193        match self.proxy(request, connection) {
194            Ok(resp) => Ok(resp),
195            Err(_) => Ok(bad_gateway()),
196        }
197    }
198}
199
200// ── helpers ───────────────────────────────────────────────────────────────────
201
202pub(crate) fn build_request(request: &Request, backend_host: &str, client_ip: &str) -> Vec<u8> {
203    let mut out: Vec<u8> = Vec::new();
204    let _ = write!(
205        out,
206        "{} {} HTTP/1.1\r\nHost: {}\r\n",
207        request.method, request.request_uri, backend_host
208    );
209    for h in &request.headers {
210        let lower = h.name.to_lowercase();
211        if HOP_BY_HOP.contains(&lower.as_str()) || lower == "host" {
212            continue;
213        }
214        let _ = write!(out, "{}: {}\r\n", h.name, h.value);
215    }
216    let _ = write!(out, "X-Forwarded-For: {}\r\n", client_ip);
217    let _ = write!(out, "Via: 1.1 rws\r\n");
218    let _ = write!(out, "Connection: close\r\n");
219    if !request.body.is_empty() {
220        let _ = write!(out, "Content-Length: {}\r\n", request.body.len());
221    }
222    let _ = write!(out, "\r\n");
223    out.extend_from_slice(&request.body);
224    out
225}
226
/// Read one complete HTTP/1.1 response from `stream` and return its raw bytes.
///
/// The reader is generic so that any byte source can be used; passing a
/// `&mut TcpStream` works exactly as before.
///
/// 1. Bytes are read until the blank line (`\r\n\r\n`) that ends the header
///    block.  If the peer closes first, whatever was received is returned, or
///    an error when nothing was received at all.
/// 2. If the headers contain a parsable `Content-Length`, reading continues
///    until at least that many body bytes have arrived or the peer closes.
/// 3. Otherwise everything up to end-of-stream is read.
///
/// # Errors
///
/// Returns the I/O error text when a read fails, or a fixed message when the
/// peer closes the connection without sending a single byte.
pub(crate) fn read_response<R: Read>(stream: &mut R) -> Result<Vec<u8>, String> {
    const HEADER_END: &[u8] = b"\r\n\r\n";

    let mut buf: Vec<u8> = Vec::with_capacity(8192);
    let mut tmp = [0u8; 4096];

    // Read until the header block ends (\r\n\r\n)
    let header_end = loop {
        let n = stream.read(&mut tmp).map_err(|e| e.to_string())?;
        if n == 0 {
            return if buf.is_empty() {
                Err("backend closed connection without sending a response".to_string())
            } else {
                Ok(buf)
            };
        }
        // Only the new bytes (plus the last three old ones, in case the
        // terminator straddles two reads) need scanning; everything before
        // that was already searched on a previous iteration.
        let scan_from = buf.len().saturating_sub(HEADER_END.len() - 1);
        buf.extend_from_slice(&tmp[..n]);
        if let Some(pos) = buf[scan_from..]
            .windows(HEADER_END.len())
            .position(|w| w == HEADER_END)
        {
            break scan_from + pos + HEADER_END.len();
        }
    };

    // A header block that is not valid UTF-8 is treated as having no
    // Content-Length, which falls through to read-until-close.
    let content_length = std::str::from_utf8(&buf[..header_end])
        .unwrap_or("")
        .lines()
        .find_map(content_length_value);

    match content_length {
        Some(len) => {
            // Saturate: a hostile or corrupt Content-Length near usize::MAX
            // must not overflow the addition.
            let total = header_end.saturating_add(len);
            while buf.len() < total {
                let n = stream.read(&mut tmp).map_err(|e| e.to_string())?;
                if n == 0 {
                    break;
                }
                buf.extend_from_slice(&tmp[..n]);
            }
        }
        None => loop {
            let n = stream.read(&mut tmp).map_err(|e| e.to_string())?;
            if n == 0 {
                break;
            }
            buf.extend_from_slice(&tmp[..n]);
        },
    }

    Ok(buf)
}

/// Return the numeric value of `line` when it is a `Content-Length` header
/// (name matched case-insensitively) with a value that parses as `usize`.
fn content_length_value(line: &str) -> Option<usize> {
    const NAME: &str = "content-length:";
    let name = line.get(..NAME.len())?;
    if !name.eq_ignore_ascii_case(NAME) {
        return None;
    }
    line[NAME.len()..].trim().parse::<usize>().ok()
}
279
280/// Forward a single HTTP/1.1 request to `host:port` and return the response.
281///
282/// This is the shared low-level building block used by [`crate::canary`] and
283/// [`crate::ingress`] so they don't have to duplicate the TCP + request/response
284/// marshalling code.
285pub(crate) fn proxy_http1(
286    request: &Request,
287    client_ip: &str,
288    host: &str,
289    port: u16,
290    connect_timeout: Duration,
291    read_timeout: Duration,
292) -> Result<Response, String> {
293    use std::net::ToSocketAddrs;
294    let addr_str = format!("{}:{}", host, port);
295    let sock_addr = addr_str
296        .to_socket_addrs()
297        .map_err(|e| format!("DNS lookup for {} failed: {}", addr_str, e))?
298        .next()
299        .ok_or_else(|| format!("no address resolved for {}", addr_str))?;
300    let stream = TcpStream::connect_timeout(&sock_addr, connect_timeout)
301        .map_err(|e| format!("connect to {} failed: {}", addr_str, e))?;
302    stream.set_read_timeout(Some(read_timeout)).map_err(|e| e.to_string())?;
303    stream.set_write_timeout(Some(Duration::from_secs(10))).map_err(|e| e.to_string())?;
304    let req_bytes = build_request(request, host, client_ip);
305    let mut stream = stream;
306    stream.write_all(&req_bytes).map_err(|e| format!("write to backend failed: {}", e))?;
307    let resp_bytes = read_response(&mut stream)?;
308    Response::parse(&resp_bytes)
309}
310
311fn bad_gateway() -> Response {
312    let cr = Range::get_content_range(
313        b"502 Bad Gateway".to_vec(),
314        MimeType::TEXT_PLAIN.to_string(),
315    );
316    let mut r = Response::new();
317    r.status_code = *STATUS_CODE_REASON_PHRASE.n502_bad_gateway.status_code;
318    r.reason_phrase = STATUS_CODE_REASON_PHRASE
319        .n502_bad_gateway
320        .reason_phrase
321        .to_string();
322    r.content_range_list = vec![cr];
323    r
324}
325
326// ── Backend URL parsing ───────────────────────────────────────────────────────
327
/// Address of one upstream server.
#[derive(Debug, Clone, PartialEq, Eq)]
struct Backend {
    /// Host name or IP literal, without scheme, port or path.
    host: String,
    /// TCP port to connect to.
    port: u16,
}

impl Backend {
    /// Parse `"[scheme://]host[:port][/path]"` into a [`Backend`].
    ///
    /// * A leading `https://`, `http://` or `h2://` is removed.  The scheme has
    ///   no other effect: it neither changes the default port nor enables TLS.
    /// * Everything from the first `/` on is discarded.
    /// * The text after the last `:` is the port.  A missing or empty port
    ///   (`"host"`, `"host:"`) means port 80.  If that text is not a valid
    ///   `u16` the whole authority is kept as the host with port 80 — this is
    ///   what makes a bracketed IPv6 literal without a port (`"[::1]"`) work.
    ///
    /// Returns `None` when the resulting host is empty.
    fn parse(url: &str) -> Option<Self> {
        const DEFAULT_PORT: u16 = 80;

        let rest = url
            .strip_prefix("https://")
            .or_else(|| url.strip_prefix("http://"))
            .or_else(|| url.strip_prefix("h2://"))
            .unwrap_or(url);
        // Drop any path component
        let host_port = rest.split('/').next().unwrap_or(rest);

        let (host, port) = match host_port.rfind(':') {
            Some(colon) => {
                let name = &host_port[..colon];
                let port_str = &host_port[colon + 1..];
                if port_str.is_empty() {
                    // "host:" — an empty port means "use the default"; the
                    // dangling colon must not end up in the host name.
                    (name, DEFAULT_PORT)
                } else {
                    match port_str.parse::<u16>() {
                        Ok(p) => (name, p),
                        Err(_) => (host_port, DEFAULT_PORT),
                    }
                }
            }
            None => (host_port, DEFAULT_PORT),
        };

        if host.is_empty() {
            return None;
        }
        Some(Backend {
            host: host.to_string(),
            port,
        })
    }
}
358
359// ── HTTP/2 reverse proxy ──────────────────────────────────────────────────────
360
/// Reverse proxy that talks HTTP/2 (`h2`) to its backends.
///
/// Configuration (backend list, path prefix, timeouts, round-robin counter) is
/// stored in an inner [`ReverseProxy`]; only the upstream protocol differs.
/// Requires the `http2` Cargo feature.
///
/// gRPC traffic (`Content-Type: application/grpc*`) passes through unchanged
/// as far as DATA frames are concerned, because gRPC is HTTP/2.  HTTP/2
/// trailers — which gRPC uses for `grpc-status` and `grpc-message` — are not
/// yet propagated.
///
/// # Example
///
/// ```rust,no_run
/// use rust_web_server::app::App;
/// use rust_web_server::core::New;
/// use rust_web_server::proxy::H2ReverseProxy;
///
/// let app = App::new()
///     .wrap(H2ReverseProxy::new(["grpc-service:9090"])
///         .path_prefix("/svc.MyService"));
/// ```
#[cfg(feature = "http2")]
pub struct H2ReverseProxy {
    /// Holds the shared configuration; its HTTP/1.1 forwarding path is unused.
    inner: ReverseProxy,
}

#[cfg(feature = "http2")]
impl H2ReverseProxy {
    /// Build a proxy that spreads requests over `backends` round-robin.
    /// Each entry must be `"host:port"` or `"h2://host:port"`.
    pub fn new<I, S>(backends: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: AsRef<str>,
    {
        let inner = ReverseProxy::new(backends);
        Self { inner }
    }

    /// Restrict proxying to request URIs beginning with `prefix`; all other
    /// requests go to the next layer.
    pub fn path_prefix(self, prefix: impl Into<String>) -> Self {
        Self {
            inner: self.inner.path_prefix(prefix),
        }
    }

    /// Override the TCP connect timeout (default: 5 s).
    pub fn connect_timeout_ms(self, ms: u64) -> Self {
        Self {
            inner: self.inner.connect_timeout_ms(ms),
        }
    }

    /// Override the response read timeout (default: 30 s).
    pub fn read_timeout_ms(self, ms: u64) -> Self {
        Self {
            inner: self.inner.read_timeout_ms(ms),
        }
    }
}
419
#[cfg(feature = "http2")]
impl crate::middleware::Middleware for H2ReverseProxy {
    /// Proxy the request over HTTP/2, or hand it to `next` when it falls
    /// outside the configured path prefix.
    ///
    /// Backends are tried in round-robin order starting from the one selected
    /// by the shared counter.  When there are no backends, or every backend
    /// fails, a `502 Bad Gateway` response is returned; this method itself
    /// never produces `Err` for a proxied request.
    fn handle(
        &self,
        request: &crate::request::Request,
        connection: &crate::server::ConnectionInfo,
        next: &dyn crate::application::Application,
    ) -> Result<crate::response::Response, String> {
        let proxy = &self.inner;

        if let Some(prefix) = &proxy.path_prefix {
            if !request.request_uri.starts_with(prefix.as_str()) {
                return next.execute(request, connection);
            }
        }

        let n = proxy.backends.len();
        if n == 0 {
            return Ok(bad_gateway());
        }

        // Reduce the counter modulo `n` *before* adding the attempt offset:
        // `counter + attempt` would overflow once the counter gets close to
        // `usize::MAX`, whereas `first + attempt` is always below `2 * n`.
        let first = proxy.counter.fetch_add(1, Ordering::Relaxed) % n;
        for attempt in 0..n {
            let backend = &proxy.backends[(first + attempt) % n];
            let outcome = try_backend_h2(
                request,
                &connection.client.ip,
                backend,
                proxy.connect_timeout,
                proxy.read_timeout,
            );
            if let Ok(resp) = outcome {
                return Ok(resp);
            }
        }
        Ok(bad_gateway())
    }
}
450
/// Forward `request` to one HTTP/2 backend from synchronous code.
///
/// The async implementation is driven on the ambient tokio runtime via
/// `block_in_place` + `block_on`.  When no runtime is available the call fails
/// immediately (there is no HTTP/1.1 fallback), which the caller turns into a
/// `502`.
///
/// The whole exchange — connect, handshake, request and complete response —
/// must finish within `connect_timeout + read_timeout`.  Without this bound a
/// backend that accepts the connection and then goes silent would block the
/// calling worker indefinitely.
///
/// # Errors
///
/// Returns a description of the failure when there is no runtime, when the
/// deadline expires, or when the forwarding itself fails.
#[cfg(feature = "http2")]
fn try_backend_h2(
    request: &Request,
    client_ip: &str,
    backend: &Backend,
    connect_timeout: Duration,
    read_timeout: Duration,
) -> Result<Response, String> {
    use tokio::runtime::Handle;

    let handle = Handle::try_current()
        .map_err(|_| "no async runtime for H2 proxy; falling back to 502".to_string())?;

    // `checked_add` guards against absurdly large configured durations.
    let deadline = connect_timeout
        .checked_add(read_timeout)
        .unwrap_or(read_timeout);

    tokio::task::block_in_place(|| {
        handle.block_on(async {
            let exchange = forward_h2_async(request, client_ip, backend, connect_timeout);
            match tokio::time::timeout(deadline, exchange).await {
                Ok(result) => result,
                Err(_) => Err(format!(
                    "h2 proxy: no complete response from {}:{} within {:?}",
                    backend.host, backend.port, deadline
                )),
            }
        })
    })
}
470
/// Send `request` to `backend` over a fresh prior-knowledge HTTP/2 connection
/// (plain TCP, no TLS) and collect the complete response.
///
/// * Connecting is bounded by `connect_timeout`.
/// * Request headers are copied except `Host` and the hop-by-hop set;
///   `host`, `x-forwarded-for` and `via` are added.
/// * The response body is buffered in full.  Connection-specific response
///   headers and header values that are not valid visible ASCII are dropped.
/// * The body's media type is taken from the response `content-type`, falling
///   back to `application/octet-stream`.
///
/// # Errors
///
/// Returns a description of the failure when connecting, the handshake,
/// building or sending the request, or receiving the response fails.
#[cfg(feature = "http2")]
async fn forward_h2_async(
    request: &Request,
    client_ip: &str,
    backend: &Backend,
    connect_timeout: Duration,
) -> Result<Response, String> {
    use bytes::Bytes;
    use http as hc;

    let addr = format!("{}:{}", backend.host, backend.port);

    let tcp = tokio::time::timeout(
        connect_timeout,
        tokio::net::TcpStream::connect(&addr),
    )
    .await
    .map_err(|_| format!("h2 proxy: connect to {} timed out", addr))?
    .map_err(|e| format!("h2 proxy: connect to {} failed: {}", addr, e))?;

    let (send_req, conn) = h2::client::handshake(tcp)
        .await
        .map_err(|e| format!("h2 proxy: handshake with {} failed: {}", addr, e))?;

    // The connection future has to be polled for any stream to make progress.
    tokio::spawn(async move {
        let _ = conn.await;
    });

    let uri_str = format!("http://{}{}", addr, request.request_uri);
    let uri: hc::Uri = uri_str.parse().map_err(|e: hc::uri::InvalidUri| e.to_string())?;
    let method = hc::Method::from_bytes(request.method.as_bytes()).map_err(|e| e.to_string())?;

    let mut builder = hc::Request::builder().method(method).uri(uri);
    builder = builder.header("host", &backend.host);
    for h in &request.headers {
        let lower = h.name.to_lowercase();
        if HOP_BY_HOP.contains(&lower.as_str()) || lower == "host" {
            continue;
        }
        builder = builder.header(&h.name, &h.value);
    }
    builder = builder.header("x-forwarded-for", client_ip);
    builder = builder.header("via", "2 rws");

    let request_body = Bytes::from(request.body.clone());
    // With no body the HEADERS frame itself ends the stream.
    let end_of_stream = request_body.is_empty();
    let http_req = builder.body(()).map_err(|e| e.to_string())?;

    let mut send_req = send_req.ready().await.map_err(|e| e.to_string())?;
    let (resp_future, mut req_body) = send_req
        .send_request(http_req, end_of_stream)
        .map_err(|e| e.to_string())?;
    if !end_of_stream {
        req_body
            .send_data(request_body, true)
            .map_err(|e| e.to_string())?;
    }

    let resp = resp_future.await.map_err(|e| e.to_string())?;
    let (parts, mut body) = resp.into_parts();

    let content_type = parts
        .headers
        .get("content-type")
        .and_then(|v| v.to_str().ok())
        .unwrap_or("application/octet-stream")
        .to_string();

    let mut response_body: Vec<u8> = Vec::new();
    while let Some(chunk) = body.data().await {
        let chunk = chunk.map_err(|e| e.to_string())?;
        // h2 does not hand receive-window capacity back automatically.  Unless
        // it is released here the peer stops sending once the initial window
        // is used up and any larger response never completes.
        body.flow_control()
            .release_capacity(chunk.len())
            .map_err(|e| e.to_string())?;
        response_body.extend_from_slice(&chunk);
    }

    let mut response = Response::new();
    response.status_code = parts.status.as_u16() as i16;
    response.reason_phrase = parts.status.canonical_reason().unwrap_or("").to_string();

    // Connection-specific headers have no meaning beyond the upstream hop.
    const H2_HOP: &[&str] = &["connection", "keep-alive", "transfer-encoding",
                                "upgrade", "proxy-connection", "te"];
    for (name, value) in &parts.headers {
        let lower = name.as_str().to_lowercase();
        if H2_HOP.contains(&lower.as_str()) {
            continue;
        }
        if let Ok(v) = value.to_str() {
            response.headers.push(crate::header::Header {
                name: name.as_str().to_string(),
                value: v.to_string(),
            });
        }
    }

    if !response_body.is_empty() {
        response.content_range_list =
            vec![Range::get_content_range(response_body, content_type)];
    }

    Ok(response)
}
565
566// ── gRPC proxy ────────────────────────────────────────────────────────────────
567
/// gRPC reverse proxy middleware.
///
/// Requests whose `Content-Type` starts with `application/grpc` are handed to
/// an inner [`H2ReverseProxy`], which forwards them to a backend over HTTP/2;
/// every other request goes to the next layer.
///
/// Since gRPC sits directly on top of HTTP/2, DATA frames are passed along
/// unchanged.  HTTP/2 trailers (`grpc-status`, `grpc-message`) are not yet
/// propagated — a known limitation of the current implementation.
///
/// Requires the `http2` Cargo feature.
///
/// # Example
///
/// ```rust,no_run
/// use rust_web_server::app::App;
/// use rust_web_server::core::New;
/// use rust_web_server::proxy::GrpcProxy;
///
/// let app = App::new()
///     .wrap(GrpcProxy::new(["grpc-service:50051"]));
/// ```
#[cfg(feature = "http2")]
pub struct GrpcProxy {
    /// Performs the actual HTTP/2 forwarding.
    inner: H2ReverseProxy,
}

#[cfg(feature = "http2")]
impl GrpcProxy {
    /// Build a proxy that spreads gRPC requests over `backends` round-robin.
    pub fn new<I, S>(backends: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: AsRef<str>,
    {
        let inner = H2ReverseProxy::new(backends);
        Self { inner }
    }

    /// Restrict proxying to request URIs beginning with `prefix`; all other
    /// requests go to the next layer.
    pub fn path_prefix(self, prefix: impl Into<String>) -> Self {
        Self {
            inner: self.inner.path_prefix(prefix),
        }
    }
}
611
#[cfg(feature = "http2")]
impl crate::middleware::Middleware for GrpcProxy {
    /// Route gRPC requests to the inner HTTP/2 proxy and everything else to
    /// `next`.
    ///
    /// A request counts as gRPC when it has a `content-type` header whose value
    /// starts with `application/grpc`.
    fn handle(
        &self,
        request: &crate::request::Request,
        connection: &crate::server::ConnectionInfo,
        next: &dyn crate::application::Application,
    ) -> Result<crate::response::Response, String> {
        let is_grpc = request
            .get_header("content-type".to_string())
            .map_or(false, |h| h.value.starts_with("application/grpc"));

        if !is_grpc {
            return next.execute(request, connection);
        }
        self.inner.handle(request, connection, next)
    }
}