Skip to main content

rust_web_server/proxy/
mod.rs

//! Reverse proxy middleware with round-robin load balancing.
//!
//! `ReverseProxy` implements [`Middleware`] — wrap any application with it and
//! all matching requests are forwarded to one of the configured backends over
//! plain HTTP/1.1.  Backends that cannot be reached are skipped and the next
//! one is tried; if no backend responds, the middleware returns
//! `502 Bad Gateway`.
//!
//! # Example
//!
//! ```rust,no_run
//! use rust_web_server::app::App;
//! use rust_web_server::core::New;
//! use rust_web_server::proxy::{LoadBalancing, ReverseProxy};
//!
//! // Proxy every request across two backends in round-robin order.
//! let app = App::new()
//!     .wrap(ReverseProxy::new(["http://backend-1:8080", "http://backend-2:8080"])
//!         .strategy(LoadBalancing::RoundRobin));
//!
//! // Only proxy requests whose URI starts with "/api"; everything else is
//! // handled locally.
//! let app2 = App::new()
//!     .wrap(ReverseProxy::new(["http://api-service:3000"])
//!         .path_prefix("/api"));
//! ```

// Unit tests for this module live in the sibling `tests` module file and are
// compiled only when building with `cfg(test)`.
#[cfg(test)]
mod tests;
28
29use std::io::{Read, Write};
30use std::net::{TcpStream, ToSocketAddrs};
31use std::sync::atomic::{AtomicUsize, Ordering};
32use std::time::Duration;
33
34use crate::application::Application;
35use crate::core::New;
36use crate::middleware::Middleware;
37use crate::mime_type::MimeType;
38use crate::range::Range;
39use crate::request::Request;
40use crate::response::{Response, STATUS_CODE_REASON_PHRASE};
41use crate::server::ConnectionInfo;
42
/// Hop-by-hop headers that are never forwarded to a backend.
///
/// Based on the list in RFC 2616 §13.5.1 (RFC 7230 §6.1 keeps the rule that
/// `Connection` and the headers it names are per-connection only).  Two
/// entries are added beyond that list:
///
/// * `trailer` — the field is actually named `Trailer`; RFC 2616's list spells
///   it `Trailers` (kept as well).  Announced trailers cannot be relayed once
///   the body is re-framed with `Content-Length`.
/// * `proxy-connection` — a non-standard field some clients still send in
///   place of `Connection`.
///
/// Entries are lower-case so they can be matched against a lower-cased name.
const HOP_BY_HOP: &[&str] = &[
    "connection",
    "keep-alive",
    "proxy-authenticate",
    "proxy-authorization",
    "proxy-connection",
    "te",
    "trailer",
    "trailers",
    "transfer-encoding",
    "upgrade",
];
54
/// Load balancing strategy used by [`ReverseProxy`].
///
/// Derives the common value traits so callers can log, copy and compare the
/// configured strategy.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LoadBalancing {
    /// Distribute requests across backends in a cyclic order.
    RoundRobin,
}
60
61/// Reverse proxy middleware.
62///
63/// Forwards incoming requests to one of the configured backends over HTTP/1.1.
64/// On connection failure the next backend in the list is tried; when all
65/// backends have failed the middleware returns `502 Bad Gateway`.
66///
67/// Hop-by-hop headers are stripped before forwarding.  `X-Forwarded-For` and
68/// `Via` are added to every forwarded request.
69///
70/// # Limitations
71///
72/// * Only plain HTTP backends are supported (no TLS to the upstream).
73/// * Chunked transfer encoding from the backend is forwarded as-is; callers
74///   that need decoded bodies should set `Content-Length` on the upstream.
75pub struct ReverseProxy {
76    backends: Vec<Backend>,
77    path_prefix: Option<String>,
78    connect_timeout: Duration,
79    read_timeout: Duration,
80    counter: AtomicUsize,
81}
82
83impl ReverseProxy {
84    /// Create a proxy that distributes requests across `backends` in
85    /// round-robin order.  Each entry must be `"http://host:port"` or
86    /// `"host:port"` (port defaults to 80).
87    pub fn new<I, S>(backends: I) -> Self
88    where
89        I: IntoIterator<Item = S>,
90        S: AsRef<str>,
91    {
92        Self {
93            backends: backends
94                .into_iter()
95                .filter_map(|u| Backend::parse(u.as_ref()))
96                .collect(),
97            path_prefix: None,
98            connect_timeout: Duration::from_secs(5),
99            read_timeout: Duration::from_secs(30),
100            counter: AtomicUsize::new(0),
101        }
102    }
103
104    /// Only proxy requests whose URI starts with `prefix`.
105    ///
106    /// Other requests are passed through to the next layer in the middleware
107    /// chain (or the inner application).
108    pub fn path_prefix(mut self, prefix: impl Into<String>) -> Self {
109        self.path_prefix = Some(prefix.into());
110        self
111    }
112
113    /// Override the load balancing strategy (currently only `RoundRobin`).
114    pub fn strategy(self, _strategy: LoadBalancing) -> Self {
115        self
116    }
117
118    /// Override the TCP connect timeout (default: 5 000 ms).
119    pub fn connect_timeout_ms(mut self, ms: u64) -> Self {
120        self.connect_timeout = Duration::from_millis(ms);
121        self
122    }
123
124    /// Override the response read timeout (default: 30 000 ms).
125    pub fn read_timeout_ms(mut self, ms: u64) -> Self {
126        self.read_timeout = Duration::from_millis(ms);
127        self
128    }
129
130    fn proxy(&self, request: &Request, connection: &ConnectionInfo) -> Result<Response, String> {
131        if self.backends.is_empty() {
132            return Err("no backends configured".to_string());
133        }
134        let n = self.backends.len();
135        let start = self.counter.fetch_add(1, Ordering::Relaxed);
136        for attempt in 0..n {
137            let idx = (start + attempt) % n;
138            match self.try_backend(request, connection, &self.backends[idx]) {
139                Ok(resp) => return Ok(resp),
140                Err(_) if attempt + 1 < n => continue,
141                Err(e) => return Err(e),
142            }
143        }
144        Err("all backends failed".to_string())
145    }
146
147    fn try_backend(
148        &self,
149        request: &Request,
150        connection: &ConnectionInfo,
151        backend: &Backend,
152    ) -> Result<Response, String> {
153        let addr_str = format!("{}:{}", backend.host, backend.port);
154        let sock_addr = addr_str
155            .to_socket_addrs()
156            .map_err(|e| format!("DNS lookup for {} failed: {}", addr_str, e))?
157            .next()
158            .ok_or_else(|| format!("no address resolved for {}", addr_str))?;
159
160        let stream = TcpStream::connect_timeout(&sock_addr, self.connect_timeout)
161            .map_err(|e| format!("connect to {} failed: {}", addr_str, e))?;
162        stream
163            .set_read_timeout(Some(self.read_timeout))
164            .map_err(|e| e.to_string())?;
165        stream
166            .set_write_timeout(Some(Duration::from_secs(10)))
167            .map_err(|e| e.to_string())?;
168
169        let req_bytes = build_request(request, &backend.host, &connection.client.ip);
170        let mut stream = stream;
171        stream
172            .write_all(&req_bytes)
173            .map_err(|e| format!("write to backend failed: {}", e))?;
174
175        let resp_bytes = read_response(&mut stream)?;
176        Response::parse(&resp_bytes)
177    }
178}
179
180impl Middleware for ReverseProxy {
181    fn handle(
182        &self,
183        request: &Request,
184        connection: &ConnectionInfo,
185        next: &dyn Application,
186    ) -> Result<Response, String> {
187        if let Some(prefix) = &self.path_prefix {
188            if !request.request_uri.starts_with(prefix.as_str()) {
189                return next.execute(request, connection);
190            }
191        }
192        match self.proxy(request, connection) {
193            Ok(resp) => Ok(resp),
194            Err(_) => Ok(bad_gateway()),
195        }
196    }
197}
198
// ── helpers ───────────────────────────────────────────────────────────────────
200
201fn build_request(request: &Request, backend_host: &str, client_ip: &str) -> Vec<u8> {
202    let mut out: Vec<u8> = Vec::new();
203    let _ = write!(
204        out,
205        "{} {} HTTP/1.1\r\nHost: {}\r\n",
206        request.method, request.request_uri, backend_host
207    );
208    for h in &request.headers {
209        let lower = h.name.to_lowercase();
210        if HOP_BY_HOP.contains(&lower.as_str()) || lower == "host" {
211            continue;
212        }
213        let _ = write!(out, "{}: {}\r\n", h.name, h.value);
214    }
215    let _ = write!(out, "X-Forwarded-For: {}\r\n", client_ip);
216    let _ = write!(out, "Via: 1.1 rws\r\n");
217    let _ = write!(out, "Connection: close\r\n");
218    if !request.body.is_empty() {
219        let _ = write!(out, "Content-Length: {}\r\n", request.body.len());
220    }
221    let _ = write!(out, "\r\n");
222    out.extend_from_slice(&request.body);
223    out
224}
225
/// Read a complete HTTP response from `stream` and return its raw bytes.
///
/// Bytes are accumulated until the end of the header block (`\r\n\r\n`).
/// When the headers carry a parseable `Content-Length`, reading stops once
/// that many body bytes are buffered, or earlier if the peer closes the
/// socket.  Without one, the body is read until end-of-stream.  A peer that
/// closes mid-header yields whatever was received so far.
///
/// # Errors
///
/// Fails when a socket read fails (including a read timeout).  Also fails
/// when the peer closes the connection before sending a single byte.
fn read_response(stream: &mut TcpStream) -> Result<Vec<u8>, String> {
    const HEADER_TERMINATOR: &[u8] = b"\r\n\r\n";

    let mut received: Vec<u8> = Vec::with_capacity(8192);
    let mut chunk = [0u8; 4096];

    // Phase 1: accumulate bytes until the header block is complete.
    let body_offset = loop {
        let count = stream.read(&mut chunk).map_err(|e| e.to_string())?;
        if count == 0 {
            if received.is_empty() {
                return Err("backend closed connection without sending a response".to_string());
            }
            return Ok(received);
        }
        let already_scanned = received.len();
        received.extend_from_slice(&chunk[..count]);
        // Older bytes held no terminator, so one can only begin within the
        // last three old bytes or in the newly read ones.
        let search_start = already_scanned.saturating_sub(HEADER_TERMINATOR.len() - 1);
        let found = received[search_start..]
            .windows(HEADER_TERMINATOR.len())
            .position(|window| window == HEADER_TERMINATOR);
        if let Some(offset) = found {
            break search_start + offset + HEADER_TERMINATOR.len();
        }
    };

    // Phase 2: take the first Content-Length header whose value parses.
    let declared_len = std::str::from_utf8(&received[..body_offset])
        .unwrap_or("")
        .lines()
        .filter(|line| line.to_lowercase().starts_with("content-length:"))
        .find_map(|line| {
            line.split_once(':')
                .and_then(|(_, value)| value.trim().parse::<usize>().ok())
        });

    // Phase 3: read the body — up to the declared length, or until EOF.
    let target = declared_len.map(|len| body_offset + len);
    while !matches!(target, Some(end) if received.len() >= end) {
        let count = stream.read(&mut chunk).map_err(|e| e.to_string())?;
        if count == 0 {
            break;
        }
        received.extend_from_slice(&chunk[..count]);
    }

    Ok(received)
}
278
279fn bad_gateway() -> Response {
280    let cr = Range::get_content_range(
281        b"502 Bad Gateway".to_vec(),
282        MimeType::TEXT_PLAIN.to_string(),
283    );
284    let mut r = Response::new();
285    r.status_code = *STATUS_CODE_REASON_PHRASE.n502_bad_gateway.status_code;
286    r.reason_phrase = STATUS_CODE_REASON_PHRASE
287        .n502_bad_gateway
288        .reason_phrase
289        .to_string();
290    r.content_range_list = vec![cr];
291    r
292}
293
// ── Backend URL parsing ───────────────────────────────────────────────────────
295
/// An upstream server address: host plus TCP port.
#[derive(Debug, Clone, PartialEq, Eq)]
struct Backend {
    /// Host name or IP literal (IPv6 literals keep their brackets).
    host: String,
    /// TCP port; 80 when the URL gives none.
    port: u16,
}

impl Backend {
    /// Parse a backend URL of the form
    /// `[http://|https://][user@]host[:port][/path][?query][#fragment]`.
    ///
    /// * The scheme prefix is stripped; no TLS is performed either way.
    /// * Everything from the first `/`, `?` or `#` onward is ignored.
    /// * Userinfo before `@` is discarded.
    /// * A missing or empty port means 80.
    /// * A non-numeric port keeps the whole `host:xyz` text as the host,
    ///   with port 80.
    ///
    /// Returns `None` when no host remains.
    fn parse(url: &str) -> Option<Self> {
        let rest = url
            .strip_prefix("https://")
            .or_else(|| url.strip_prefix("http://"))
            .unwrap_or(url);
        // Keep only the authority: cut at the first path, query or fragment
        // delimiter.
        let authority = rest
            .split(|c: char| matches!(c, '/' | '?' | '#'))
            .next()
            .unwrap_or(rest);
        // `user[:password]@` is not used for connecting.
        let host_port = authority
            .rsplit_once('@')
            .map_or(authority, |(_, host_port)| host_port);

        let (host, port) = match host_port.rfind(':') {
            Some(colon) => {
                let port_str = &host_port[colon + 1..];
                if port_str.is_empty() {
                    (&host_port[..colon], 80)
                } else if let Ok(p) = port_str.parse::<u16>() {
                    (&host_port[..colon], p)
                } else {
                    // e.g. a bracketed IPv6 literal without a port, "[::1]".
                    (host_port, 80)
                }
            }
            None => (host_port, 80),
        };
        if host.is_empty() {
            return None;
        }
        Some(Backend {
            host: host.to_string(),
            port,
        })
    }
}