1#[cfg(test)]
27mod tests;
28
29use std::io::{Read, Write};
30use std::net::{TcpStream, ToSocketAddrs};
31use std::sync::atomic::{AtomicUsize, Ordering};
32use std::time::Duration;
33
34
35use crate::application::Application;
36use crate::core::New;
37use crate::middleware::Middleware;
38use crate::mime_type::MimeType;
39use crate::range::Range;
40use crate::request::Request;
41use crate::response::{Response, STATUS_CODE_REASON_PHRASE};
42use crate::server::ConnectionInfo;
43
/// Lower-case names of hop-by-hop request headers that a proxy must not
/// forward (RFC 2616 §13.5.1 / RFC 7230 §6.1).
///
/// Besides the classic list this includes `trailer` (the actual header name;
/// `trailers` is the RFC 2616 spelling and is kept for compatibility) and the
/// non-standard `proxy-connection`, which HTTP/2 clients such as the `h2`
/// crate refuse to send because it is connection-specific.
const HOP_BY_HOP: &[&str] = &[
    "connection",
    "keep-alive",
    "proxy-authenticate",
    "proxy-authorization",
    "proxy-connection",
    "te",
    "trailer",
    "trailers",
    "transfer-encoding",
    "upgrade",
];
55
/// Backend-selection strategy accepted by `ReverseProxy::strategy`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum LoadBalancing {
    /// Each request starts at the next backend in rotation.
    RoundRobin,
}
61
62pub struct ReverseProxy {
77 backends: Vec<Backend>,
78 path_prefix: Option<String>,
79 connect_timeout: Duration,
80 read_timeout: Duration,
81 counter: AtomicUsize,
82}
83
84impl ReverseProxy {
85 pub fn new<I, S>(backends: I) -> Self
89 where
90 I: IntoIterator<Item = S>,
91 S: AsRef<str>,
92 {
93 Self {
94 backends: backends
95 .into_iter()
96 .filter_map(|u| Backend::parse(u.as_ref()))
97 .collect(),
98 path_prefix: None,
99 connect_timeout: Duration::from_secs(5),
100 read_timeout: Duration::from_secs(30),
101 counter: AtomicUsize::new(0),
102 }
103 }
104
105 pub fn path_prefix(mut self, prefix: impl Into<String>) -> Self {
110 self.path_prefix = Some(prefix.into());
111 self
112 }
113
114 pub fn strategy(self, _strategy: LoadBalancing) -> Self {
116 self
117 }
118
119 pub fn connect_timeout_ms(mut self, ms: u64) -> Self {
121 self.connect_timeout = Duration::from_millis(ms);
122 self
123 }
124
125 pub fn read_timeout_ms(mut self, ms: u64) -> Self {
127 self.read_timeout = Duration::from_millis(ms);
128 self
129 }
130
131 fn proxy(&self, request: &Request, connection: &ConnectionInfo) -> Result<Response, String> {
132 if self.backends.is_empty() {
133 return Err("no backends configured".to_string());
134 }
135 let n = self.backends.len();
136 let start = self.counter.fetch_add(1, Ordering::Relaxed);
137 for attempt in 0..n {
138 let idx = (start + attempt) % n;
139 match self.try_backend(request, connection, &self.backends[idx]) {
140 Ok(resp) => return Ok(resp),
141 Err(_) if attempt + 1 < n => continue,
142 Err(e) => return Err(e),
143 }
144 }
145 Err("all backends failed".to_string())
146 }
147
148 fn try_backend(
149 &self,
150 request: &Request,
151 connection: &ConnectionInfo,
152 backend: &Backend,
153 ) -> Result<Response, String> {
154 let addr_str = format!("{}:{}", backend.host, backend.port);
155 let sock_addr = addr_str
156 .to_socket_addrs()
157 .map_err(|e| format!("DNS lookup for {} failed: {}", addr_str, e))?
158 .next()
159 .ok_or_else(|| format!("no address resolved for {}", addr_str))?;
160
161 let stream = TcpStream::connect_timeout(&sock_addr, self.connect_timeout)
162 .map_err(|e| format!("connect to {} failed: {}", addr_str, e))?;
163 stream
164 .set_read_timeout(Some(self.read_timeout))
165 .map_err(|e| e.to_string())?;
166 stream
167 .set_write_timeout(Some(Duration::from_secs(10)))
168 .map_err(|e| e.to_string())?;
169
170 let req_bytes = build_request(request, &backend.host, &connection.client.ip);
171 let mut stream = stream;
172 stream
173 .write_all(&req_bytes)
174 .map_err(|e| format!("write to backend failed: {}", e))?;
175
176 let resp_bytes = read_response(&mut stream)?;
177 Response::parse(&resp_bytes)
178 }
179}
180
181impl Middleware for ReverseProxy {
182 fn handle(
183 &self,
184 request: &Request,
185 connection: &ConnectionInfo,
186 next: &dyn Application,
187 ) -> Result<Response, String> {
188 if let Some(prefix) = &self.path_prefix {
189 if !request.request_uri.starts_with(prefix.as_str()) {
190 return next.execute(request, connection);
191 }
192 }
193 match self.proxy(request, connection) {
194 Ok(resp) => Ok(resp),
195 Err(_) => Ok(bad_gateway()),
196 }
197 }
198}
199
200pub(crate) fn build_request(request: &Request, backend_host: &str, client_ip: &str) -> Vec<u8> {
203 let mut out: Vec<u8> = Vec::new();
204 let _ = write!(
205 out,
206 "{} {} HTTP/1.1\r\nHost: {}\r\n",
207 request.method, request.request_uri, backend_host
208 );
209 for h in &request.headers {
210 let lower = h.name.to_lowercase();
211 if HOP_BY_HOP.contains(&lower.as_str()) || lower == "host" {
212 continue;
213 }
214 let _ = write!(out, "{}: {}\r\n", h.name, h.value);
215 }
216 let _ = write!(out, "X-Forwarded-For: {}\r\n", client_ip);
217 let _ = write!(out, "Via: 1.1 rws\r\n");
218 let _ = write!(out, "Connection: close\r\n");
219 if !request.body.is_empty() {
220 let _ = write!(out, "Content-Length: {}\r\n", request.body.len());
221 }
222 let _ = write!(out, "\r\n");
223 out.extend_from_slice(&request.body);
224 out
225}
226
/// Reads one HTTP/1.x response from `stream` and returns its raw bytes
/// (status line, headers and body), ready for `Response::parse`.
///
/// * Interim `1xx` responses other than `101 Switching Protocols` are read and
///   discarded, so an unsolicited `100 Continue` is never mistaken for the
///   final response.
/// * With a `Content-Length` header exactly that many body bytes are kept
///   (anything after them is dropped); without one, the body runs until the
///   backend closes the connection.
/// * If the connection closes early, whatever was received is returned as-is.
///
/// # Errors
/// Returns an error when a read fails (including a read timeout configured on
/// the stream) or when the backend closes the connection before sending any
/// final response bytes.
pub(crate) fn read_response<R: Read>(stream: &mut R) -> Result<Vec<u8>, String> {
    let mut buf: Vec<u8> = Vec::with_capacity(8192);
    let mut tmp = [0u8; 4096];

    // Phase 1: buffer until a complete, final (non-interim) header block is present.
    let header_end = loop {
        if let Some(end) = find_head_end(&buf) {
            if is_interim_head(&buf[..end]) {
                buf.drain(..end);
                continue;
            }
            break end;
        }
        let n = read_retrying(stream, &mut tmp)?;
        if n == 0 {
            return if buf.is_empty() {
                Err("backend closed connection without sending a response".to_string())
            } else {
                Ok(buf)
            };
        }
        buf.extend_from_slice(&tmp[..n]);
    };

    // Phase 2: read the body.
    match head_content_length(&buf[..header_end]) {
        Some(len) => {
            // Saturate so a hostile `Content-Length` cannot overflow the addition.
            let total = header_end.saturating_add(len);
            while buf.len() < total {
                let n = read_retrying(stream, &mut tmp)?;
                if n == 0 {
                    break;
                }
                buf.extend_from_slice(&tmp[..n]);
            }
            buf.truncate(total);
        }
        None => loop {
            let n = read_retrying(stream, &mut tmp)?;
            if n == 0 {
                break;
            }
            buf.extend_from_slice(&tmp[..n]);
        },
    }

    Ok(buf)
}

/// Returns the index just past the first `\r\n\r\n` in `buf`, if any.
fn find_head_end(buf: &[u8]) -> Option<usize> {
    buf.windows(4)
        .position(|w| w == b"\r\n\r\n")
        .map(|pos| pos + 4)
}

/// True when the header block `head` carries a 1xx status other than 101,
/// i.e. an interim response that precedes the real one.
fn is_interim_head(head: &[u8]) -> bool {
    let status_line = head.split(|&b| b == b'\n').next().unwrap_or(head);
    let code = std::str::from_utf8(status_line)
        .ok()
        .and_then(|line| line.split_whitespace().nth(1))
        .and_then(|code| code.parse::<u16>().ok());
    matches!(code, Some(c) if (100..200).contains(&c) && c != 101)
}

/// Extracts a parseable `Content-Length` value from a header block.
fn head_content_length(head: &[u8]) -> Option<usize> {
    let text = String::from_utf8_lossy(head);
    let found = text.lines().find_map(|line| {
        let colon = line.find(':')?;
        if line[..colon].eq_ignore_ascii_case("content-length") {
            line[colon + 1..].trim().parse::<usize>().ok()
        } else {
            None
        }
    });
    found
}

/// `Read::read` that transparently retries on `ErrorKind::Interrupted`.
fn read_retrying<R: Read>(stream: &mut R, buf: &mut [u8]) -> Result<usize, String> {
    loop {
        match stream.read(buf) {
            Ok(n) => return Ok(n),
            Err(e) if e.kind() == std::io::ErrorKind::Interrupted => continue,
            Err(e) => return Err(e.to_string()),
        }
    }
}
279
280pub(crate) fn proxy_http1(
286 request: &Request,
287 client_ip: &str,
288 host: &str,
289 port: u16,
290 connect_timeout: Duration,
291 read_timeout: Duration,
292) -> Result<Response, String> {
293 use std::net::ToSocketAddrs;
294 let addr_str = format!("{}:{}", host, port);
295 let sock_addr = addr_str
296 .to_socket_addrs()
297 .map_err(|e| format!("DNS lookup for {} failed: {}", addr_str, e))?
298 .next()
299 .ok_or_else(|| format!("no address resolved for {}", addr_str))?;
300 let stream = TcpStream::connect_timeout(&sock_addr, connect_timeout)
301 .map_err(|e| format!("connect to {} failed: {}", addr_str, e))?;
302 stream.set_read_timeout(Some(read_timeout)).map_err(|e| e.to_string())?;
303 stream.set_write_timeout(Some(Duration::from_secs(10))).map_err(|e| e.to_string())?;
304 let req_bytes = build_request(request, host, client_ip);
305 let mut stream = stream;
306 stream.write_all(&req_bytes).map_err(|e| format!("write to backend failed: {}", e))?;
307 let resp_bytes = read_response(&mut stream)?;
308 Response::parse(&resp_bytes)
309}
310
311fn bad_gateway() -> Response {
312 let cr = Range::get_content_range(
313 b"502 Bad Gateway".to_vec(),
314 MimeType::TEXT_PLAIN.to_string(),
315 );
316 let mut r = Response::new();
317 r.status_code = *STATUS_CODE_REASON_PHRASE.n502_bad_gateway.status_code;
318 r.reason_phrase = STATUS_CODE_REASON_PHRASE
319 .n502_bad_gateway
320 .reason_phrase
321 .to_string();
322 r.content_range_list = vec![cr];
323 r
324}
325
/// A single upstream target resolved from a backend URL string.
#[derive(Debug, Clone, PartialEq, Eq)]
struct Backend {
    /// Host name or IP literal (IPv6 literals keep their brackets, e.g. `[::1]`).
    host: String,
    /// TCP port; 80 when the URL does not specify a valid one.
    port: u16,
}

impl Backend {
    /// Parses `url` into a host/port pair.
    ///
    /// Accepts `host` or `host:port`, optionally prefixed by any `scheme://`
    /// (matched case-insensitively, e.g. `http://`, `HTTPS://`, `h2c://`) and
    /// optionally followed by a path, query or fragment, which are ignored.
    /// An empty or non-numeric port falls back to 80. Returns `None` when no
    /// host remains.
    fn parse(url: &str) -> Option<Self> {
        let url = url.trim();

        // Strip `scheme://` only if everything before `://` is a valid scheme,
        // so a `://` inside a path or query is not mistaken for one.
        let rest = match url.find("://") {
            Some(idx)
                if idx > 0
                    && url[..idx]
                        .chars()
                        .all(|c| c.is_ascii_alphanumeric() || matches!(c, '+' | '-' | '.')) =>
            {
                &url[idx + 3..]
            }
            _ => url,
        };

        let authority = rest
            .split(|c| c == '/' || c == '?' || c == '#')
            .next()
            .unwrap_or(rest);

        let (host, port) = match authority.rfind(':') {
            Some(colon) => {
                let port_str = &authority[colon + 1..];
                if port_str.is_empty() {
                    (&authority[..colon], 80)
                } else if let Ok(p) = port_str.parse::<u16>() {
                    (&authority[..colon], p)
                } else {
                    // Not a port (e.g. the tail of a bracketed IPv6 literal).
                    (authority, 80)
                }
            }
            None => (authority, 80),
        };

        if host.is_empty() {
            return None;
        }
        Some(Backend {
            host: host.to_string(),
            port,
        })
    }
}
358
/// HTTP/2 variant of `ReverseProxy`: same backend list, path-prefix and
/// timeout settings, but requests are forwarded to the backends over HTTP/2.
#[cfg(feature = "http2")]
pub struct H2ReverseProxy {
    inner: ReverseProxy,
}

#[cfg(feature = "http2")]
impl H2ReverseProxy {
    /// Creates an HTTP/2 proxy for the given backend URLs (parsed exactly as
    /// by `ReverseProxy::new`).
    pub fn new<I, S>(backends: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: AsRef<str>,
    {
        let inner = ReverseProxy::new(backends);
        Self { inner }
    }

    /// Restricts proxying to requests whose request-URI starts with `prefix`.
    pub fn path_prefix(self, prefix: impl Into<String>) -> Self {
        Self {
            inner: self.inner.path_prefix(prefix),
        }
    }

    /// Sets the backend connect timeout in milliseconds.
    pub fn connect_timeout_ms(self, ms: u64) -> Self {
        Self {
            inner: self.inner.connect_timeout_ms(ms),
        }
    }

    /// Sets the backend read timeout in milliseconds.
    pub fn read_timeout_ms(self, ms: u64) -> Self {
        Self {
            inner: self.inner.read_timeout_ms(ms),
        }
    }
}
419
#[cfg(feature = "http2")]
impl crate::middleware::Middleware for H2ReverseProxy {
    /// Proxies in-scope requests over HTTP/2, trying each backend at most once
    /// starting from the next one in rotation.
    ///
    /// Requests outside `path_prefix` go to `next`. When every backend fails,
    /// or none is configured, the client receives `502 Bad Gateway`.
    fn handle(
        &self,
        request: &crate::request::Request,
        connection: &crate::server::ConnectionInfo,
        next: &dyn crate::application::Application,
    ) -> Result<crate::response::Response, String> {
        let inner = &self.inner;
        if let Some(prefix) = &inner.path_prefix {
            if !request.request_uri.starts_with(prefix.as_str()) {
                return next.execute(request, connection);
            }
        }
        let n = inner.backends.len();
        if n == 0 {
            return Ok(bad_gateway());
        }
        // Reduce modulo `n` first so adding `attempt` cannot overflow once the
        // (wrapping) counter nears `usize::MAX`.
        let first = inner.counter.fetch_add(1, Ordering::Relaxed) % n;
        for attempt in 0..n {
            let backend = &inner.backends[(first + attempt) % n];
            let result = try_backend_h2(
                request,
                &connection.client.ip,
                backend,
                inner.connect_timeout,
                inner.read_timeout,
            );
            if let Ok(resp) = result {
                return Ok(resp);
            }
        }
        Ok(bad_gateway())
    }
}
450
/// Forwards `request` to `backend` over HTTP/2, blocking the calling worker
/// thread until the exchange completes.
///
/// Returns an error (so the caller can answer 502) when no tokio runtime is
/// available on this thread.
/// NOTE(review): `tokio::task::block_in_place` panics on a `current_thread`
/// runtime, so this path assumes a multi-threaded runtime — confirm.
#[cfg(feature = "http2")]
fn try_backend_h2(
    request: &Request,
    client_ip: &str,
    backend: &Backend,
    connect_timeout: Duration,
    read_timeout: Duration,
) -> Result<Response, String> {
    use tokio::runtime::Handle;
    let handle = match Handle::try_current() {
        Ok(handle) => handle,
        Err(_) => {
            return Err("no async runtime for H2 proxy; falling back to 502".to_string());
        }
    };
    tokio::task::block_in_place(|| {
        handle.block_on(forward_h2_async(
            request,
            client_ip,
            backend,
            connect_timeout,
            read_timeout,
        ))
    })
}

/// Performs one HTTP/2 exchange with `backend` and converts the reply into a
/// `Response`.
///
/// * `connect_timeout` bounds the TCP connect.
/// * `read_timeout` bounds the wait for the response headers and for each
///   subsequent body chunk, mirroring the per-read socket timeout used by the
///   HTTP/1.1 path.
/// * Hop-by-hop request headers are dropped, except `te: trailers`, which is
///   forwarded: HTTP/2 allows `te` only with that value and gRPC requires it.
/// * Response body chunks are released back to the flow-control window as
///   they are consumed, so bodies larger than the initial window keep flowing.
#[cfg(feature = "http2")]
async fn forward_h2_async(
    request: &Request,
    client_ip: &str,
    backend: &Backend,
    connect_timeout: Duration,
    read_timeout: Duration,
) -> Result<Response, String> {
    use bytes::Bytes;
    use http as hc;
    use tokio::time::timeout;

    let addr = format!("{}:{}", backend.host, backend.port);

    let tcp = timeout(connect_timeout, tokio::net::TcpStream::connect(&addr))
        .await
        .map_err(|_| format!("h2 proxy: connect to {} timed out", addr))?
        .map_err(|e| format!("h2 proxy: connect to {} failed: {}", addr, e))?;

    let (send_req, conn) = h2::client::handshake(tcp)
        .await
        .map_err(|e| format!("h2 proxy: handshake with {} failed: {}", addr, e))?;

    // The connection future performs all socket I/O for this session and must
    // run concurrently with the request below.
    tokio::spawn(async move {
        let _ = conn.await;
    });

    let uri: hc::Uri = format!("http://{}{}", addr, request.request_uri)
        .parse()
        .map_err(|e: hc::uri::InvalidUri| e.to_string())?;
    let method = hc::Method::from_bytes(request.method.as_bytes()).map_err(|e| e.to_string())?;

    let mut builder = hc::Request::builder()
        .method(method)
        .uri(uri)
        .header("host", &backend.host);
    for h in &request.headers {
        let lower = h.name.to_lowercase();
        if lower == "host" {
            continue;
        }
        if HOP_BY_HOP.contains(&lower.as_str()) {
            let wants_trailers = lower == "te"
                && h.value
                    .split(',')
                    .any(|t| t.trim().eq_ignore_ascii_case("trailers"));
            if wants_trailers {
                // Normalised, since HTTP/2 permits no other `te` value.
                builder = builder.header("te", "trailers");
            }
            continue;
        }
        builder = builder.header(&h.name, &h.value);
    }
    builder = builder
        .header("x-forwarded-for", client_ip)
        .header("via", "2 rws");
    let http_req = builder.body(()).map_err(|e| e.to_string())?;

    let payload = Bytes::from(request.body.clone());
    let end_of_stream = payload.is_empty();

    let mut send_req = send_req.ready().await.map_err(|e| e.to_string())?;
    let (resp_future, mut req_body) = send_req
        .send_request(http_req, end_of_stream)
        .map_err(|e| e.to_string())?;
    if !end_of_stream {
        req_body
            .send_data(payload, true)
            .map_err(|e| e.to_string())?;
    }

    let resp = timeout(read_timeout, resp_future)
        .await
        .map_err(|_| format!("h2 proxy: response from {} timed out", addr))?
        .map_err(|e| e.to_string())?;
    let (parts, mut body) = resp.into_parts();

    let content_type = parts
        .headers
        .get("content-type")
        .and_then(|v| v.to_str().ok())
        .unwrap_or("application/octet-stream")
        .to_string();

    let mut resp_body: Vec<u8> = Vec::new();
    loop {
        let next_chunk = timeout(read_timeout, body.data())
            .await
            .map_err(|_| format!("h2 proxy: reading body from {} timed out", addr))?;
        let chunk = match next_chunk {
            Some(chunk) => chunk.map_err(|e| e.to_string())?,
            None => break,
        };
        // Hand the consumed bytes back to the flow-control window; without this
        // the backend stalls once the initial window (65,535 bytes by default
        // in HTTP/2) is used up. A failure here only means the stream is gone.
        let _ = body.flow_control().release_capacity(chunk.len());
        resp_body.extend_from_slice(&chunk);
    }

    let mut response = Response::new();
    // HTTP status codes are at most three digits, so they always fit in i16.
    response.status_code = parts.status.as_u16() as i16;
    response.reason_phrase = parts.status.canonical_reason().unwrap_or("").to_string();

    // Connection-specific response headers that must not be relayed downstream.
    const H2_HOP: &[&str] = &[
        "connection",
        "keep-alive",
        "transfer-encoding",
        "upgrade",
        "proxy-connection",
        "te",
    ];
    for (name, value) in &parts.headers {
        let lower = name.as_str().to_lowercase();
        if H2_HOP.contains(&lower.as_str()) {
            continue;
        }
        if let Ok(v) = value.to_str() {
            response.headers.push(crate::header::Header {
                name: name.as_str().to_string(),
                value: v.to_string(),
            });
        }
    }

    if !resp_body.is_empty() {
        response.content_range_list = vec![Range::get_content_range(resp_body, content_type)];
    }

    Ok(response)
}
565
/// Middleware that forwards gRPC calls to HTTP/2 backends.
///
/// Requests whose `content-type` starts with `application/grpc` are handled by
/// an inner `H2ReverseProxy`; every other request goes to the next application.
#[cfg(feature = "http2")]
pub struct GrpcProxy {
    inner: H2ReverseProxy,
}

#[cfg(feature = "http2")]
impl GrpcProxy {
    /// Creates a gRPC proxy for the given backend URLs (same format as
    /// `ReverseProxy::new`).
    pub fn new<I, S>(backends: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: AsRef<str>,
    {
        GrpcProxy {
            inner: H2ReverseProxy::new(backends),
        }
    }

    /// Only proxies gRPC requests whose request-URI starts with `prefix`.
    pub fn path_prefix(mut self, prefix: impl Into<String>) -> Self {
        self.inner = self.inner.path_prefix(prefix);
        self
    }

    /// Sets the backend connect timeout in milliseconds (default 5 s).
    pub fn connect_timeout_ms(mut self, ms: u64) -> Self {
        self.inner = self.inner.connect_timeout_ms(ms);
        self
    }

    /// Sets the backend read timeout in milliseconds (default 30 s).
    pub fn read_timeout_ms(mut self, ms: u64) -> Self {
        self.inner = self.inner.read_timeout_ms(ms);
        self
    }
}
611
#[cfg(feature = "http2")]
impl crate::middleware::Middleware for GrpcProxy {
    /// Sends requests whose `content-type` begins with `application/grpc` to
    /// the HTTP/2 proxy, and everything else to `next`.
    fn handle(
        &self,
        request: &crate::request::Request,
        connection: &crate::server::ConnectionInfo,
        next: &dyn crate::application::Application,
    ) -> Result<crate::response::Response, String> {
        let is_grpc = request
            .get_header("content-type".to_string())
            .map_or(false, |header| header.value.starts_with("application/grpc"));
        if is_grpc {
            return self.inner.handle(request, connection, next);
        }
        next.execute(request, connection)
    }
}