1#[cfg(test)]
27mod tests;
28
29use std::io::{Read, Write};
30use std::net::{TcpStream, ToSocketAddrs};
31use std::sync::atomic::{AtomicUsize, Ordering};
32use std::time::Duration;
33
34use crate::application::Application;
35use crate::core::New;
36use crate::middleware::Middleware;
37use crate::mime_type::MimeType;
38use crate::range::Range;
39use crate::request::Request;
40use crate::response::{Response, STATUS_CODE_REASON_PHRASE};
41use crate::server::ConnectionInfo;
42
/// Lower-case names of hop-by-hop headers that must not be forwarded to a backend
/// (RFC 9110 §7.6.1). Header names are lower-cased before being looked up here.
const HOP_BY_HOP: &[&str] = &[
    "connection",
    "keep-alive",
    "proxy-authenticate",
    "proxy-authorization",
    // Non-standard, but still sent by some clients; never meaningful end to end.
    "proxy-connection",
    "te",
    // The header is actually named `Trailer`; `trailers` (as spelled in RFC 2616's
    // list) is kept so nothing previously filtered starts leaking through.
    "trailer",
    "trailers",
    "transfer-encoding",
    "upgrade",
];
54
/// Strategy used to choose which backend receives a request.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum LoadBalancing {
    /// Rotate through the backends in order, advancing one step per proxied request.
    RoundRobin,
}
60
61pub struct ReverseProxy {
76 backends: Vec<Backend>,
77 path_prefix: Option<String>,
78 connect_timeout: Duration,
79 read_timeout: Duration,
80 counter: AtomicUsize,
81}
82
83impl ReverseProxy {
84 pub fn new<I, S>(backends: I) -> Self
88 where
89 I: IntoIterator<Item = S>,
90 S: AsRef<str>,
91 {
92 Self {
93 backends: backends
94 .into_iter()
95 .filter_map(|u| Backend::parse(u.as_ref()))
96 .collect(),
97 path_prefix: None,
98 connect_timeout: Duration::from_secs(5),
99 read_timeout: Duration::from_secs(30),
100 counter: AtomicUsize::new(0),
101 }
102 }
103
104 pub fn path_prefix(mut self, prefix: impl Into<String>) -> Self {
109 self.path_prefix = Some(prefix.into());
110 self
111 }
112
113 pub fn strategy(self, _strategy: LoadBalancing) -> Self {
115 self
116 }
117
118 pub fn connect_timeout_ms(mut self, ms: u64) -> Self {
120 self.connect_timeout = Duration::from_millis(ms);
121 self
122 }
123
124 pub fn read_timeout_ms(mut self, ms: u64) -> Self {
126 self.read_timeout = Duration::from_millis(ms);
127 self
128 }
129
130 fn proxy(&self, request: &Request, connection: &ConnectionInfo) -> Result<Response, String> {
131 if self.backends.is_empty() {
132 return Err("no backends configured".to_string());
133 }
134 let n = self.backends.len();
135 let start = self.counter.fetch_add(1, Ordering::Relaxed);
136 for attempt in 0..n {
137 let idx = (start + attempt) % n;
138 match self.try_backend(request, connection, &self.backends[idx]) {
139 Ok(resp) => return Ok(resp),
140 Err(_) if attempt + 1 < n => continue,
141 Err(e) => return Err(e),
142 }
143 }
144 Err("all backends failed".to_string())
145 }
146
147 fn try_backend(
148 &self,
149 request: &Request,
150 connection: &ConnectionInfo,
151 backend: &Backend,
152 ) -> Result<Response, String> {
153 let addr_str = format!("{}:{}", backend.host, backend.port);
154 let sock_addr = addr_str
155 .to_socket_addrs()
156 .map_err(|e| format!("DNS lookup for {} failed: {}", addr_str, e))?
157 .next()
158 .ok_or_else(|| format!("no address resolved for {}", addr_str))?;
159
160 let stream = TcpStream::connect_timeout(&sock_addr, self.connect_timeout)
161 .map_err(|e| format!("connect to {} failed: {}", addr_str, e))?;
162 stream
163 .set_read_timeout(Some(self.read_timeout))
164 .map_err(|e| e.to_string())?;
165 stream
166 .set_write_timeout(Some(Duration::from_secs(10)))
167 .map_err(|e| e.to_string())?;
168
169 let req_bytes = build_request(request, &backend.host, &connection.client.ip);
170 let mut stream = stream;
171 stream
172 .write_all(&req_bytes)
173 .map_err(|e| format!("write to backend failed: {}", e))?;
174
175 let resp_bytes = read_response(&mut stream)?;
176 Response::parse(&resp_bytes)
177 }
178}
179
180impl Middleware for ReverseProxy {
181 fn handle(
182 &self,
183 request: &Request,
184 connection: &ConnectionInfo,
185 next: &dyn Application,
186 ) -> Result<Response, String> {
187 if let Some(prefix) = &self.path_prefix {
188 if !request.request_uri.starts_with(prefix.as_str()) {
189 return next.execute(request, connection);
190 }
191 }
192 match self.proxy(request, connection) {
193 Ok(resp) => Ok(resp),
194 Err(_) => Ok(bad_gateway()),
195 }
196 }
197}
198
199fn build_request(request: &Request, backend_host: &str, client_ip: &str) -> Vec<u8> {
202 let mut out: Vec<u8> = Vec::new();
203 let _ = write!(
204 out,
205 "{} {} HTTP/1.1\r\nHost: {}\r\n",
206 request.method, request.request_uri, backend_host
207 );
208 for h in &request.headers {
209 let lower = h.name.to_lowercase();
210 if HOP_BY_HOP.contains(&lower.as_str()) || lower == "host" {
211 continue;
212 }
213 let _ = write!(out, "{}: {}\r\n", h.name, h.value);
214 }
215 let _ = write!(out, "X-Forwarded-For: {}\r\n", client_ip);
216 let _ = write!(out, "Via: 1.1 rws\r\n");
217 let _ = write!(out, "Connection: close\r\n");
218 if !request.body.is_empty() {
219 let _ = write!(out, "Content-Length: {}\r\n", request.body.len());
220 }
221 let _ = write!(out, "\r\n");
222 out.extend_from_slice(&request.body);
223 out
224}
225
226fn read_response(stream: &mut TcpStream) -> Result<Vec<u8>, String> {
227 let mut buf: Vec<u8> = Vec::with_capacity(8192);
228 let mut tmp = [0u8; 4096];
229
230 let header_end = loop {
232 let n = stream.read(&mut tmp).map_err(|e| e.to_string())?;
233 if n == 0 {
234 return if buf.is_empty() {
235 Err("backend closed connection without sending a response".to_string())
236 } else {
237 Ok(buf)
238 };
239 }
240 buf.extend_from_slice(&tmp[..n]);
241 if let Some(pos) = buf.windows(4).position(|w| w == b"\r\n\r\n") {
242 break pos + 4;
243 }
244 };
245
246 let content_length = std::str::from_utf8(&buf[..header_end])
248 .unwrap_or("")
249 .lines()
250 .find_map(|line| {
251 line.to_lowercase()
252 .starts_with("content-length:")
253 .then(|| line.splitn(2, ':').nth(1)?.trim().parse::<usize>().ok())
254 .flatten()
255 });
256
257 match content_length {
258 Some(len) => {
259 while buf.len() < header_end + len {
260 let n = stream.read(&mut tmp).map_err(|e| e.to_string())?;
261 if n == 0 {
262 break;
263 }
264 buf.extend_from_slice(&tmp[..n]);
265 }
266 }
267 None => loop {
268 let n = stream.read(&mut tmp).map_err(|e| e.to_string())?;
269 if n == 0 {
270 break;
271 }
272 buf.extend_from_slice(&tmp[..n]);
273 },
274 }
275
276 Ok(buf)
277}
278
279fn bad_gateway() -> Response {
280 let cr = Range::get_content_range(
281 b"502 Bad Gateway".to_vec(),
282 MimeType::TEXT_PLAIN.to_string(),
283 );
284 let mut r = Response::new();
285 r.status_code = *STATUS_CODE_REASON_PHRASE.n502_bad_gateway.status_code;
286 r.reason_phrase = STATUS_CODE_REASON_PHRASE
287 .n502_bad_gateway
288 .reason_phrase
289 .to_string();
290 r.content_range_list = vec![cr];
291 r
292}
293
/// A backend target, reduced to the host and port to connect to.
struct Backend {
    /// Host name or IP literal (IPv6 literals keep their brackets).
    host: String,
    /// TCP port; 80 when the URL does not give a usable one.
    port: u16,
}

impl Backend {
    /// Parses a backend URL such as `http://host:8080/base`, `host:8080` or `host`.
    ///
    /// The following parts are discarded:
    /// - an `http://` or `https://` scheme (the scheme does not affect the default port);
    /// - any `user:pass@` userinfo;
    /// - everything from the first `/`, `?` or `#`.
    ///
    /// A missing or empty port defaults to 80. A port that is not a valid `u16` is
    /// left as part of the host. Returns `None` when no host remains.
    fn parse(url: &str) -> Option<Self> {
        let rest = url
            .strip_prefix("https://")
            .or_else(|| url.strip_prefix("http://"))
            .unwrap_or(url);
        let authority = rest
            .split(|c: char| matches!(c, '/' | '?' | '#'))
            .next()
            .unwrap_or(rest);
        // Drop userinfo; the last `@`-separated piece is the host[:port].
        let host_port = authority.rsplit('@').next().unwrap_or(authority);

        let (host, port) = match host_port.rsplit_once(':') {
            // `host:` with nothing after the colon: use the default port.
            Some((host, "")) => (host, 80),
            Some((host, port_str)) => match port_str.parse::<u16>() {
                Ok(port) => (host, port),
                // e.g. a bracketed IPv6 literal without a port: keep it whole.
                Err(_) => (host_port, 80),
            },
            None => (host_port, 80),
        };
        if host.is_empty() {
            return None;
        }
        Some(Backend {
            host: host.to_string(),
            port,
        })
    }
}
324}