1#[cfg(test)]
27mod tests;
28
29use std::io::{Read, Write};
30use std::net::{TcpStream, ToSocketAddrs};
31use std::sync::atomic::{AtomicUsize, Ordering};
32use std::time::Duration;
33
34
35use crate::application::Application;
36use crate::core::New;
37use crate::middleware::Middleware;
38use crate::mime_type::MimeType;
39use crate::range::Range;
40use crate::request::Request;
41use crate::response::{Response, STATUS_CODE_REASON_PHRASE};
42use crate::server::ConnectionInfo;
43
/// Lower-case names of hop-by-hop header fields that describe a single transport hop and are
/// therefore never forwarded to a backend (RFC 9110 §7.6.1).
///
/// * `trailer` is the real field name; `trailers` is the misspelling printed in RFC 2616 §13.5.1
///   and is kept so existing behaviour does not change.
/// * `proxy-connection` is a non-standard legacy field that RFC 9113 §8.2.2 lists as
///   connection-specific; HTTP/2 peers must not send it, so it must be stripped before forwarding
///   over h2 as well.
const HOP_BY_HOP: &[&str] = &[
    "connection",
    "keep-alive",
    "proxy-authenticate",
    "proxy-authorization",
    "proxy-connection",
    "te",
    "trailer",
    "trailers",
    "transfer-encoding",
    "upgrade",
];
55
/// Backend-selection strategy accepted by `ReverseProxy::strategy`.
///
/// Round-robin is currently the only strategy, so the choice has no runtime effect yet.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum LoadBalancing {
    /// Rotate through the backends in order, starting one position later for each request.
    RoundRobin,
}
61
/// Blocking HTTP/1.1 reverse-proxy middleware that forwards requests to a pool of backends.
///
/// Backends are tried in round-robin order, failing over to the next one on error; when every
/// backend fails the client receives `502 Bad Gateway`. Configure it with the builder methods
/// (`path_prefix`, `strategy`, `connect_timeout_ms`, `read_timeout_ms`).
///
/// The fields are also read directly by `H2ReverseProxy`, which reuses this type for its
/// configuration and round-robin state.
pub struct ReverseProxy {
    /// Upstream servers parsed by `ReverseProxy::new`; unparsable entries were dropped there.
    backends: Vec<Backend>,
    /// When set, only requests whose URI starts with this prefix are proxied; the rest go to `next`.
    path_prefix: Option<String>,
    /// Timeout for establishing a TCP connection to a backend.
    connect_timeout: Duration,
    /// The HTTP/1 path installs this via `set_read_timeout` on the backend socket;
    /// `H2ReverseProxy` also passes it to `try_backend_h2`.
    read_timeout: Duration,
    /// Incremented once per request that reaches backend selection; picks the starting backend.
    counter: AtomicUsize,
}
83
84impl ReverseProxy {
85 pub fn new<I, S>(backends: I) -> Self
89 where
90 I: IntoIterator<Item = S>,
91 S: AsRef<str>,
92 {
93 Self {
94 backends: backends
95 .into_iter()
96 .filter_map(|u| Backend::parse(u.as_ref()))
97 .collect(),
98 path_prefix: None,
99 connect_timeout: Duration::from_secs(5),
100 read_timeout: Duration::from_secs(30),
101 counter: AtomicUsize::new(0),
102 }
103 }
104
105 pub fn path_prefix(mut self, prefix: impl Into<String>) -> Self {
110 self.path_prefix = Some(prefix.into());
111 self
112 }
113
114 pub fn strategy(self, _strategy: LoadBalancing) -> Self {
116 self
117 }
118
119 pub fn connect_timeout_ms(mut self, ms: u64) -> Self {
121 self.connect_timeout = Duration::from_millis(ms);
122 self
123 }
124
125 pub fn read_timeout_ms(mut self, ms: u64) -> Self {
127 self.read_timeout = Duration::from_millis(ms);
128 self
129 }
130
131 fn proxy(&self, request: &Request, connection: &ConnectionInfo) -> Result<Response, String> {
132 if self.backends.is_empty() {
133 return Err("no backends configured".to_string());
134 }
135 let n = self.backends.len();
136 let start = self.counter.fetch_add(1, Ordering::Relaxed);
137 for attempt in 0..n {
138 let idx = (start + attempt) % n;
139 match self.try_backend(request, connection, &self.backends[idx]) {
140 Ok(resp) => return Ok(resp),
141 Err(_) if attempt + 1 < n => continue,
142 Err(e) => return Err(e),
143 }
144 }
145 Err("all backends failed".to_string())
146 }
147
148 fn try_backend(
149 &self,
150 request: &Request,
151 connection: &ConnectionInfo,
152 backend: &Backend,
153 ) -> Result<Response, String> {
154 let addr_str = format!("{}:{}", backend.host, backend.port);
155 let sock_addr = addr_str
156 .to_socket_addrs()
157 .map_err(|e| format!("DNS lookup for {} failed: {}", addr_str, e))?
158 .next()
159 .ok_or_else(|| format!("no address resolved for {}", addr_str))?;
160
161 let stream = TcpStream::connect_timeout(&sock_addr, self.connect_timeout)
162 .map_err(|e| format!("connect to {} failed: {}", addr_str, e))?;
163 stream
164 .set_read_timeout(Some(self.read_timeout))
165 .map_err(|e| e.to_string())?;
166 stream
167 .set_write_timeout(Some(Duration::from_secs(10)))
168 .map_err(|e| e.to_string())?;
169
170 let req_bytes = build_request(request, &backend.host, &connection.client.ip);
171 let mut stream = stream;
172 stream
173 .write_all(&req_bytes)
174 .map_err(|e| format!("write to backend failed: {}", e))?;
175
176 let resp_bytes = read_response(&mut stream)?;
177 Response::parse(&resp_bytes)
178 }
179}
180
181impl Middleware for ReverseProxy {
182 fn handle(
183 &self,
184 request: &Request,
185 connection: &ConnectionInfo,
186 next: &dyn Application,
187 ) -> Result<Response, String> {
188 if let Some(prefix) = &self.path_prefix {
189 if !request.request_uri.starts_with(prefix.as_str()) {
190 return next.execute(request, connection);
191 }
192 }
193 match self.proxy(request, connection) {
194 Ok(resp) => Ok(resp),
195 Err(_) => Ok(bad_gateway()),
196 }
197 }
198}
199
200fn build_request(request: &Request, backend_host: &str, client_ip: &str) -> Vec<u8> {
203 let mut out: Vec<u8> = Vec::new();
204 let _ = write!(
205 out,
206 "{} {} HTTP/1.1\r\nHost: {}\r\n",
207 request.method, request.request_uri, backend_host
208 );
209 for h in &request.headers {
210 let lower = h.name.to_lowercase();
211 if HOP_BY_HOP.contains(&lower.as_str()) || lower == "host" {
212 continue;
213 }
214 let _ = write!(out, "{}: {}\r\n", h.name, h.value);
215 }
216 let _ = write!(out, "X-Forwarded-For: {}\r\n", client_ip);
217 let _ = write!(out, "Via: 1.1 rws\r\n");
218 let _ = write!(out, "Connection: close\r\n");
219 if !request.body.is_empty() {
220 let _ = write!(out, "Content-Length: {}\r\n", request.body.len());
221 }
222 let _ = write!(out, "\r\n");
223 out.extend_from_slice(&request.body);
224 out
225}
226
/// Reads one HTTP/1.x response from `stream` and returns its raw bytes (status line, headers and
/// body), ready for `Response::parse`.
///
/// * Interim `1xx` responses (other than `101`), e.g. a `100 Continue` triggered by a forwarded
///   `Expect` header, are discarded so the caller always receives the final response.
/// * With a `Content-Length`, exactly that many body bytes are kept. If the peer closes early, the
///   truncated message is returned as-is. A huge or hostile length cannot overflow.
/// * Without one, the body runs until EOF. This relies on the request having been sent with
///   `Connection: close`.
/// * If the peer closes before the header block is complete, the partial bytes are returned.
///
/// # Errors
/// Read errors (including socket timeouts), or EOF before any byte of a final response arrived.
fn read_response<R: Read>(stream: &mut R) -> Result<Vec<u8>, String> {
    let mut buf: Vec<u8> = Vec::with_capacity(8192);
    let mut tmp = [0u8; 4096];

    let header_end = loop {
        if let Some(pos) = buf.windows(4).position(|w| w == b"\r\n\r\n") {
            let end = pos + 4;
            if is_interim_response(&buf[..end]) {
                // Drop the interim head and look for the next one in what is already buffered.
                buf.drain(..end);
                continue;
            }
            break end;
        }
        let n = stream.read(&mut tmp).map_err(|e| e.to_string())?;
        if n == 0 {
            return if buf.is_empty() {
                Err("backend closed connection without sending a response".to_string())
            } else {
                Ok(buf)
            };
        }
        buf.extend_from_slice(&tmp[..n]);
    };

    match parse_content_length(&buf[..header_end]) {
        Some(len) => {
            let total = header_end.saturating_add(len);
            while buf.len() < total {
                let n = stream.read(&mut tmp).map_err(|e| e.to_string())?;
                if n == 0 {
                    break;
                }
                buf.extend_from_slice(&tmp[..n]);
            }
            // Anything past the declared length is not part of this message.
            buf.truncate(total);
        }
        None => {
            stream.read_to_end(&mut buf).map_err(|e| e.to_string())?;
        }
    }

    Ok(buf)
}

/// Returns `true` if `head` (a complete header block) is an interim `1xx` response other than
/// `101 Switching Protocols`; such responses precede the final response and carry no body.
fn is_interim_response(head: &[u8]) -> bool {
    let status_line = head.split(|&b| b == b'\n').next().unwrap_or(&[]);
    let mut fields = status_line.split(|&b| b == b' ' || b == b'\r');
    let version = fields.next().unwrap_or(&[]);
    let code = fields.next().unwrap_or(&[]);
    version.starts_with(b"HTTP/")
        && code.len() == 3
        && code[0] == b'1'
        && code.iter().all(u8::is_ascii_digit)
        && code != b"101"
}

/// Returns the value of the first parseable `Content-Length` field in the header block `head`.
/// Non-UTF-8 bytes elsewhere in the headers do not prevent the lookup.
fn parse_content_length(head: &[u8]) -> Option<usize> {
    String::from_utf8_lossy(head).lines().find_map(|line| {
        let (name, value) = line.split_once(':')?;
        if name.eq_ignore_ascii_case("content-length") {
            value.trim().parse::<usize>().ok()
        } else {
            None
        }
    })
}
279
280fn bad_gateway() -> Response {
281 let cr = Range::get_content_range(
282 b"502 Bad Gateway".to_vec(),
283 MimeType::TEXT_PLAIN.to_string(),
284 );
285 let mut r = Response::new();
286 r.status_code = *STATUS_CODE_REASON_PHRASE.n502_bad_gateway.status_code;
287 r.reason_phrase = STATUS_CODE_REASON_PHRASE
288 .n502_bad_gateway
289 .reason_phrase
290 .to_string();
291 r.content_range_list = vec![cr];
292 r
293}
294
/// One upstream server: the address a proxy opens its TCP connection to.
struct Backend {
    /// Hostname or IP literal (bracketed IPv6 literals keep their brackets, e.g. `[::1]`).
    host: String,
    /// TCP port; 80 when the URL has none or the suffix is not a valid `u16`.
    port: u16,
}

impl Backend {
    /// Parses `[scheme://]host[:port][/path]` into a backend address.
    ///
    /// A leading `https://`, `http://` or `h2://` is stripped without affecting the port or
    /// transport. Everything from the first `/` on is ignored. A `:port` suffix is honoured only
    /// if it parses as `u16`; otherwise the whole authority becomes the host and port 80 is used.
    /// Returns `None` when the resulting host is empty.
    fn parse(url: &str) -> Option<Self> {
        const SCHEMES: [&str; 3] = ["https://", "http://", "h2://"];

        let rest = SCHEMES
            .iter()
            .find_map(|&scheme| url.strip_prefix(scheme))
            .unwrap_or(url);
        let authority = rest.split('/').next().unwrap_or(rest);

        let (host, port) = authority
            .rfind(':')
            .and_then(|colon| {
                authority[colon + 1..]
                    .parse::<u16>()
                    .ok()
                    .map(|port| (&authority[..colon], port))
            })
            .unwrap_or((authority, 80));

        if host.is_empty() {
            return None;
        }
        Some(Backend {
            host: host.to_string(),
            port,
        })
    }
}
327
#[cfg(feature = "http2")]
/// Reverse-proxy middleware that talks HTTP/2 over cleartext TCP to its backends.
///
/// Configuration and round-robin state are reused from `ReverseProxy`; only the transport differs.
/// Each request blocks the calling thread while the exchange runs on the ambient tokio runtime.
/// Outside a tokio runtime every request is answered with `502 Bad Gateway`.
pub struct H2ReverseProxy {
    // Backends, path prefix, timeouts and the round-robin counter.
    inner: ReverseProxy,
}
355
#[cfg(feature = "http2")]
impl H2ReverseProxy {
    /// Creates an HTTP/2 proxy over `backends`, parsed exactly as `ReverseProxy::new` does.
    pub fn new<I, S>(backends: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: AsRef<str>,
    {
        Self {
            inner: ReverseProxy::new(backends),
        }
    }

    /// Only proxy requests whose URI starts with `prefix`; see `ReverseProxy::path_prefix`.
    pub fn path_prefix(self, prefix: impl Into<String>) -> Self {
        Self {
            inner: self.inner.path_prefix(prefix),
        }
    }

    /// Sets the backend connect timeout in milliseconds.
    pub fn connect_timeout_ms(self, ms: u64) -> Self {
        Self {
            inner: self.inner.connect_timeout_ms(ms),
        }
    }

    /// Sets the backend read timeout in milliseconds (forwarded to the inner `ReverseProxy`).
    pub fn read_timeout_ms(self, ms: u64) -> Self {
        Self {
            inner: self.inner.read_timeout_ms(ms),
        }
    }
}
388
#[cfg(feature = "http2")]
impl crate::middleware::Middleware for H2ReverseProxy {
    /// Proxies requests inside the configured path prefix over HTTP/2 and delegates the rest to
    /// `next`.
    ///
    /// Backends are tried round-robin with failover; if none is configured or all fail, the client
    /// receives `502 Bad Gateway` rather than an `Err`.
    fn handle(
        &self,
        request: &crate::request::Request,
        connection: &crate::server::ConnectionInfo,
        next: &dyn crate::application::Application,
    ) -> Result<crate::response::Response, String> {
        let proxy = &self.inner;
        if let Some(prefix) = &proxy.path_prefix {
            if !request.request_uri.starts_with(prefix.as_str()) {
                return next.execute(request, connection);
            }
        }

        let n = proxy.backends.len();
        if n == 0 {
            return Ok(bad_gateway());
        }
        let start = proxy.counter.fetch_add(1, Ordering::Relaxed);
        for attempt in 0..n {
            // The counter wraps at usize::MAX; `wrapping_add` avoids an overflow panic in debug builds.
            let backend = &proxy.backends[start.wrapping_add(attempt) % n];
            if let Ok(resp) = try_backend_h2(
                request,
                &connection.client.ip,
                backend,
                proxy.connect_timeout,
                proxy.read_timeout,
            ) {
                return Ok(resp);
            }
        }
        Ok(bad_gateway())
    }
}
419
#[cfg(feature = "http2")]
/// Synchronously runs one HTTP/2 exchange with `backend` on the ambient tokio runtime.
///
/// The exchange is driven with `block_in_place` + `Handle::block_on`, which requires a
/// multi-thread runtime. Outside any runtime, or on a current-thread runtime (where
/// `block_in_place` would panic), an `Err` is returned so the caller answers `502` instead.
///
/// The whole exchange (connect, handshake, request and full response body) must complete within
/// `connect_timeout + read_timeout`. The HTTP/1 path's per-read socket timeout has no direct
/// equivalent here, so one overall deadline is used; without it a silent backend would block
/// the worker thread forever.
fn try_backend_h2(
    request: &Request,
    client_ip: &str,
    backend: &Backend,
    connect_timeout: Duration,
    read_timeout: Duration,
) -> Result<Response, String> {
    use tokio::runtime::{Handle, RuntimeFlavor};

    let handle = match Handle::try_current() {
        Ok(handle) => handle,
        Err(_) => {
            return Err("no async runtime for H2 proxy; falling back to 502".to_string());
        }
    };
    if handle.runtime_flavor() == RuntimeFlavor::CurrentThread {
        return Err(
            "H2 proxy cannot block on a current-thread tokio runtime; falling back to 502"
                .to_string(),
        );
    }

    let deadline = connect_timeout.saturating_add(read_timeout);
    tokio::task::block_in_place(|| {
        handle.block_on(async move {
            let exchange = forward_h2_async(request, client_ip, backend, connect_timeout);
            match tokio::time::timeout(deadline, exchange).await {
                Ok(result) => result,
                Err(_) => Err(format!(
                    "h2 proxy: no complete response from {}:{} within {:?}",
                    backend.host, backend.port, deadline
                )),
            }
        })
    })
}
439
#[cfg(feature = "http2")]
/// Performs one HTTP/2 request/response exchange with `backend` over a new cleartext TCP
/// connection (HTTP/2 with prior knowledge, no TLS or upgrade).
///
/// Forwards method, URI path and headers, minus hop-by-hop fields and `host`, plus
/// `x-forwarded-for` and `via`. It then sends the buffered body with END_STREAM and buffers the
/// entire response body into the returned `Response`.
///
/// No explicit `host` field is sent. h2 derives `:authority` from the request URI (`host:port`),
/// and RFC 9113 §8.3.1 forbids a `Host` field that differs from it.
///
/// NOTE(review): response trailers (e.g. gRPC's `grpc-status`) are never read — `trailers()` is
/// not called — so they are not relayed to the client.
/// NOTE(review): backend `content-type`/`content-length` fields are copied into `headers` while the
/// body is also stored as a content range with the same type. Confirm the serializer does not emit
/// them twice.
async fn forward_h2_async(
    request: &Request,
    client_ip: &str,
    backend: &Backend,
    connect_timeout: Duration,
) -> Result<Response, String> {
    use bytes::Bytes;
    use http as hc;

    let addr = format!("{}:{}", backend.host, backend.port);

    let tcp = tokio::time::timeout(connect_timeout, tokio::net::TcpStream::connect(&addr))
        .await
        .map_err(|_| format!("h2 proxy: connect to {} timed out", addr))?
        .map_err(|e| format!("h2 proxy: connect to {} failed: {}", addr, e))?;

    let (send_req, conn) = h2::client::handshake(tcp)
        .await
        .map_err(|e| format!("h2 proxy: handshake with {} failed: {}", addr, e))?;

    // The connection future performs all frame I/O; it must be polled for the request to progress.
    tokio::spawn(async move {
        let _ = conn.await;
    });

    let uri_str = format!("http://{}{}", addr, request.request_uri);
    let uri: hc::Uri = uri_str.parse().map_err(|e: hc::uri::InvalidUri| e.to_string())?;
    let method = hc::Method::from_bytes(request.method.as_bytes()).map_err(|e| e.to_string())?;

    let mut builder = hc::Request::builder().method(method).uri(uri);
    for h in &request.headers {
        let lower = h.name.to_lowercase();
        if HOP_BY_HOP.contains(&lower.as_str()) || lower == "host" {
            continue;
        }
        builder = builder.header(&h.name, &h.value);
    }
    builder = builder.header("x-forwarded-for", client_ip);
    builder = builder.header("via", "2 rws");

    let req_body = Bytes::from(request.body.clone());
    let end_of_stream = req_body.is_empty();
    let http_req = builder.body(()).map_err(|e| e.to_string())?;

    let mut send_req = send_req.ready().await.map_err(|e| e.to_string())?;
    let (resp_future, mut body_tx) = send_req
        .send_request(http_req, end_of_stream)
        .map_err(|e| e.to_string())?;
    if !end_of_stream {
        body_tx.send_data(req_body, true).map_err(|e| e.to_string())?;
    }

    let resp = resp_future.await.map_err(|e| e.to_string())?;
    let (parts, mut body) = resp.into_parts();

    let content_type = parts
        .headers
        .get("content-type")
        .and_then(|v| v.to_str().ok())
        .unwrap_or("application/octet-stream")
        .to_string();

    let mut resp_body: Vec<u8> = Vec::new();
    while let Some(chunk) = body.data().await {
        let chunk = chunk.map_err(|e| e.to_string())?;
        resp_body.extend_from_slice(&chunk);
        // h2 does not replenish the flow-control window automatically. Without releasing the
        // consumed bytes, the peer stops sending once the initial window (~64 KiB) is used up.
        body.flow_control()
            .release_capacity(chunk.len())
            .map_err(|e| e.to_string())?;
    }

    let mut response = Response::new();
    response.status_code = parts.status.as_u16() as i16;
    response.reason_phrase = parts.status.canonical_reason().unwrap_or("").to_string();

    // Connection-specific fields must not be relayed back to the client.
    const H2_HOP: &[&str] = &["connection", "keep-alive", "transfer-encoding",
        "upgrade", "proxy-connection", "te"];
    for (name, value) in &parts.headers {
        let lower = name.as_str().to_lowercase();
        if H2_HOP.contains(&lower.as_str()) {
            continue;
        }
        if let Ok(v) = value.to_str() {
            response.headers.push(crate::header::Header {
                name: name.as_str().to_string(),
                value: v.to_string(),
            });
        }
    }

    if !resp_body.is_empty() {
        response.content_range_list = vec![Range::get_content_range(resp_body, content_type)];
    }

    Ok(response)
}
534
#[cfg(feature = "http2")]
/// gRPC-aware front for `H2ReverseProxy`.
///
/// Only requests whose `content-type` starts with `application/grpc` are proxied over HTTP/2;
/// every other request is passed to the next application.
pub struct GrpcProxy {
    // Performs the actual HTTP/2 forwarding (and applies its own path-prefix filter).
    inner: H2ReverseProxy,
}
562
#[cfg(feature = "http2")]
impl GrpcProxy {
    /// Creates a gRPC proxy over `backends`, parsed as `ReverseProxy::new` does.
    pub fn new<I, S>(backends: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: AsRef<str>,
    {
        GrpcProxy { inner: H2ReverseProxy::new(backends) }
    }

    /// Only proxy gRPC requests whose URI starts with `prefix`.
    pub fn path_prefix(mut self, prefix: impl Into<String>) -> Self {
        self.inner = self.inner.path_prefix(prefix);
        self
    }

    /// Sets the backend connect timeout in milliseconds; forwards to
    /// `H2ReverseProxy::connect_timeout_ms`.
    pub fn connect_timeout_ms(mut self, ms: u64) -> Self {
        self.inner = self.inner.connect_timeout_ms(ms);
        self
    }

    /// Sets the backend read timeout in milliseconds; forwards to
    /// `H2ReverseProxy::read_timeout_ms`.
    pub fn read_timeout_ms(mut self, ms: u64) -> Self {
        self.inner = self.inner.read_timeout_ms(ms);
        self
    }
}
580
#[cfg(feature = "http2")]
impl crate::middleware::Middleware for GrpcProxy {
    /// Routes requests whose `content-type` begins with `application/grpc` (a prefix that also
    /// matches `application/grpc-web…`) through the inner HTTP/2 proxy. Anything else, including
    /// requests without a content type, goes to `next` untouched.
    fn handle(
        &self,
        request: &crate::request::Request,
        connection: &crate::server::ConnectionInfo,
        next: &dyn crate::application::Application,
    ) -> Result<crate::response::Response, String> {
        let is_grpc = match request.get_header("content-type".to_string()) {
            Some(header) => header.value.starts_with("application/grpc"),
            None => false,
        };
        if is_grpc {
            self.inner.handle(request, connection, next)
        } else {
            next.execute(request, connection)
        }
    }
}