1pub mod pool;
27
28#[cfg(test)]
29mod tests;
30
31use std::io::{Read, Write};
32use std::net::{TcpStream, ToSocketAddrs};
33use std::sync::atomic::{AtomicUsize, Ordering};
34use std::sync::Arc;
35use std::time::Duration;
36
37pub use pool::ConnPool;
38
39use crate::application::Application;
40use crate::core::New;
41use crate::middleware::Middleware;
42use crate::mime_type::MimeType;
43use crate::range::Range;
44use crate::request::Request;
45use crate::response::{Response, STATUS_CODE_REASON_PHRASE};
46use crate::server::ConnectionInfo;
47
/// Lowercase names of hop-by-hop request headers that must not be forwarded to a backend.
///
/// Based on RFC 7230 §6.1 and RFC 2616 §13.5.1. `"trailer"` is the actual field name defined
/// in RFC 7230 §4.4. RFC 2616 misspelled it as `"trailers"`, which is kept for compatibility.
/// `"proxy-connection"` is a legacy, non-standard connection header that HTTP/2 forbids
/// (RFC 9113 §8.2.2).
const HOP_BY_HOP: &[&str] = &[
    "connection",
    "keep-alive",
    "proxy-authenticate",
    "proxy-authorization",
    "proxy-connection",
    "te",
    "trailer",
    "trailers",
    "transfer-encoding",
    "upgrade",
];
59
/// Strategy used to choose which backend receives a request.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum LoadBalancing {
    /// Cycle through the configured backends in order.
    RoundRobin,
}
65
66pub struct ReverseProxy {
81 backends: Vec<Backend>,
82 path_prefix: Option<String>,
83 connect_timeout: Duration,
84 read_timeout: Duration,
85 counter: AtomicUsize,
86 pool: Arc<ConnPool>,
87}
88
89impl ReverseProxy {
90 pub fn new<I, S>(backends: I) -> Self
94 where
95 I: IntoIterator<Item = S>,
96 S: AsRef<str>,
97 {
98 Self {
99 backends: backends
100 .into_iter()
101 .filter_map(|u| Backend::parse(u.as_ref()))
102 .collect(),
103 path_prefix: None,
104 connect_timeout: Duration::from_secs(5),
105 read_timeout: Duration::from_secs(30),
106 counter: AtomicUsize::new(0),
107 pool: Arc::new(ConnPool::new_default()),
108 }
109 }
110
111 pub fn path_prefix(mut self, prefix: impl Into<String>) -> Self {
116 self.path_prefix = Some(prefix.into());
117 self
118 }
119
120 pub fn strategy(self, _strategy: LoadBalancing) -> Self {
122 self
123 }
124
125 pub fn connect_timeout_ms(mut self, ms: u64) -> Self {
127 self.connect_timeout = Duration::from_millis(ms);
128 self
129 }
130
131 pub fn read_timeout_ms(mut self, ms: u64) -> Self {
133 self.read_timeout = Duration::from_millis(ms);
134 self
135 }
136
137 pub fn with_pool(mut self, pool: Arc<ConnPool>) -> Self {
142 self.pool = pool;
143 self
144 }
145
146 pub fn max_idle_conns(mut self, n: usize) -> Self {
148 self.pool = Arc::new(ConnPool::new(n, Duration::from_secs(60)));
149 self
150 }
151
152 fn proxy(&self, request: &Request, connection: &ConnectionInfo) -> Result<Response, String> {
153 if self.backends.is_empty() {
154 return Err("no backends configured".to_string());
155 }
156 let n = self.backends.len();
157 let start = self.counter.fetch_add(1, Ordering::Relaxed);
158 for attempt in 0..n {
159 let idx = (start + attempt) % n;
160 match self.try_backend(request, connection, &self.backends[idx]) {
161 Ok(resp) => return Ok(resp),
162 Err(_) if attempt + 1 < n => continue,
163 Err(e) => return Err(e),
164 }
165 }
166 Err("all backends failed".to_string())
167 }
168
169 fn try_backend(
170 &self,
171 request: &Request,
172 connection: &ConnectionInfo,
173 backend: &Backend,
174 ) -> Result<Response, String> {
175 let key = format!("{}:{}", backend.host, backend.port);
176
177 let stream = if let Some(pooled) = self.pool.acquire(&key) {
179 pooled
180 } else {
181 let addr_str = key.as_str();
182 let sock_addr = addr_str
183 .to_socket_addrs()
184 .map_err(|e| format!("DNS lookup for {} failed: {}", addr_str, e))?
185 .next()
186 .ok_or_else(|| format!("no address resolved for {}", addr_str))?;
187 TcpStream::connect_timeout(&sock_addr, self.connect_timeout)
188 .map_err(|e| format!("connect to {} failed: {}", addr_str, e))?
189 };
190
191 stream.set_read_timeout(Some(self.read_timeout)).map_err(|e| e.to_string())?;
192 stream.set_write_timeout(Some(Duration::from_secs(10))).map_err(|e| e.to_string())?;
193
194 let req_bytes = build_request(request, &backend.host, &connection.client.ip, true);
197 let mut stream = stream;
198 stream.write_all(&req_bytes).map_err(|e| format!("write to backend failed: {}", e))?;
199
200 let mut tmp = [0u8; 4096];
201 let (header_bytes, body_prefix) = read_headers_only(&mut stream, &mut tmp)?;
202 let header_lower =
203 std::str::from_utf8(&header_bytes).unwrap_or("").to_ascii_lowercase();
204
205 if should_stream_response(&header_lower) {
206 let mut resp = parse_status_and_headers(&header_bytes)?;
209 resp.stream_pipe =
210 Some(Box::new(ConcatReader::new(body_prefix, stream)));
211 Ok(resp)
212 } else {
213 let (resp_bytes, reusable) =
216 read_response_from_partial(&mut stream, header_bytes, body_prefix, &mut tmp)?;
217 if reusable {
218 self.pool.release(&key, stream);
219 }
220 Response::parse(&resp_bytes)
221 }
222 }
223}
224
225impl Middleware for ReverseProxy {
226 fn handle(
227 &self,
228 request: &Request,
229 connection: &ConnectionInfo,
230 next: &dyn Application,
231 ) -> Result<Response, String> {
232 if let Some(prefix) = &self.path_prefix {
233 if !request.request_uri.starts_with(prefix.as_str()) {
234 return next.execute(request, connection);
235 }
236 }
237 match self.proxy(request, connection) {
238 Ok(resp) => Ok(resp),
239 Err(_) => Ok(bad_gateway()),
240 }
241 }
242}
243
244pub(crate) fn build_request(
247 request: &Request,
248 backend_host: &str,
249 client_ip: &str,
250 keep_alive: bool,
251) -> Vec<u8> {
252 let mut out: Vec<u8> = Vec::new();
253 let _ = write!(
254 out,
255 "{} {} HTTP/1.1\r\nHost: {}\r\n",
256 request.method, request.request_uri, backend_host
257 );
258 for h in &request.headers {
259 let lower = h.name.to_lowercase();
260 if HOP_BY_HOP.contains(&lower.as_str()) || lower == "host" {
261 continue;
262 }
263 let _ = write!(out, "{}: {}\r\n", h.name, h.value);
264 }
265 let _ = write!(out, "X-Forwarded-For: {}\r\n", client_ip);
266 let _ = write!(out, "Via: 1.1 rws\r\n");
267 if keep_alive {
268 let _ = write!(out, "Connection: keep-alive\r\n");
269 } else {
270 let _ = write!(out, "Connection: close\r\n");
271 }
272 if !request.body.is_empty() {
273 let _ = write!(out, "Content-Length: {}\r\n", request.body.len());
274 }
275 let _ = write!(out, "\r\n");
276 out.extend_from_slice(&request.body);
277 out
278}
279
280pub(crate) fn read_response_poolable(stream: &mut TcpStream) -> Result<(Vec<u8>, bool), String> {
291 let mut buf: Vec<u8> = Vec::with_capacity(8192);
292 let mut tmp = [0u8; 4096];
293
294 let header_end = loop {
296 let n = stream.read(&mut tmp).map_err(|e| e.to_string())?;
297 if n == 0 {
298 return if buf.is_empty() {
299 Err("backend closed connection without sending a response".to_string())
300 } else {
301 Ok((buf, false))
302 };
303 }
304 buf.extend_from_slice(&tmp[..n]);
305 if let Some(pos) = buf.windows(4).position(|w| w == b"\r\n\r\n") {
306 break pos + 4;
307 }
308 };
309
310 let header_str_lower =
311 std::str::from_utf8(&buf[..header_end]).unwrap_or("").to_ascii_lowercase();
312
313 let connection_close =
314 header_str_lower.lines().any(|l| l.starts_with("connection:") && l.contains("close"));
315
316 let is_chunked = header_str_lower
317 .lines()
318 .any(|l| l.starts_with("transfer-encoding:") && l.contains("chunked"));
319
320 let content_length: Option<usize> = header_str_lower.lines().find_map(|l| {
321 l.strip_prefix("content-length:")?
322 .trim()
323 .parse()
324 .ok()
325 });
326
327 if is_chunked {
328 let decoded = decode_chunked(stream, &buf, header_end, &mut tmp)?;
329 rewrite_as_content_length(&mut buf, header_end, &decoded);
330 Ok((buf, !connection_close))
331 } else if let Some(len) = content_length {
332 while buf.len() < header_end + len {
333 let n = stream.read(&mut tmp).map_err(|e| e.to_string())?;
334 if n == 0 {
335 break;
336 }
337 buf.extend_from_slice(&tmp[..n]);
338 }
339 Ok((buf, !connection_close))
340 } else {
341 loop {
343 let n = stream.read(&mut tmp).map_err(|e| e.to_string())?;
344 if n == 0 {
345 break;
346 }
347 buf.extend_from_slice(&tmp[..n]);
348 }
349 Ok((buf, false))
350 }
351}
352
/// Decodes a `Transfer-Encoding: chunked` body and returns the concatenated chunk data.
///
/// `buf[header_end..]` holds body bytes that were already read; more are pulled from `stream`
/// as needed, using `tmp` as scratch space. After the terminating zero-size chunk, the
/// trailer section is consumed up to and including its final blank line. Trailer fields are
/// discarded. A stream that ends during the trailer section is tolerated.
///
/// # Errors
/// Fails on I/O errors, on EOF inside a chunk-size line or chunk body, on a non-UTF-8 or
/// non-hex chunk size, and on a chunk size too large to represent.
fn decode_chunked<R: Read + ?Sized>(
    stream: &mut R,
    buf: &[u8],
    header_end: usize,
    tmp: &mut [u8],
) -> Result<Vec<u8>, String> {
    /// Returns the offset of the first CRLF in `raw`, reading from `stream` until one
    /// appears. Returns `Ok(None)` if the stream ends first.
    fn next_crlf<S: Read + ?Sized>(
        stream: &mut S,
        raw: &mut Vec<u8>,
        tmp: &mut [u8],
    ) -> Result<Option<usize>, String> {
        loop {
            if let Some(p) = raw.windows(2).position(|w| w == b"\r\n") {
                return Ok(Some(p));
            }
            let n = stream.read(tmp).map_err(|e| e.to_string())?;
            if n == 0 {
                return Ok(None);
            }
            raw.extend_from_slice(&tmp[..n]);
        }
    }

    let mut raw: Vec<u8> = buf[header_end..].to_vec();
    let mut decoded: Vec<u8> = Vec::new();

    loop {
        let crlf = next_crlf(stream, &mut raw, tmp)?
            .ok_or_else(|| "chunked: premature EOF reading chunk size".to_string())?;
        let size_line = std::str::from_utf8(&raw[..crlf])
            .map_err(|_| "chunked: non-UTF-8 chunk size line".to_string())?;
        // Chunk extensions (`;name=value`) are ignored.
        let size_str = size_line.split(';').next().unwrap_or("").trim();
        let chunk_size = usize::from_str_radix(size_str, 16)
            .map_err(|_| format!("chunked: invalid chunk size '{}'", size_str))?;
        // Data plus its trailing CRLF. Checked so a hostile size cannot overflow and panic.
        let frame_len = chunk_size
            .checked_add(2)
            .ok_or_else(|| format!("chunked: chunk size '{}' too large", size_str))?;
        raw.drain(..crlf + 2);

        if chunk_size == 0 {
            // Last chunk: consume the trailer section through its blank line so nothing is
            // left unread on a connection that may be reused.
            while let Some(line_end) = next_crlf(stream, &mut raw, tmp)? {
                raw.drain(..line_end + 2);
                if line_end == 0 {
                    break;
                }
            }
            break;
        }

        while raw.len() < frame_len {
            let n = stream.read(tmp).map_err(|e| e.to_string())?;
            if n == 0 {
                return Err("chunked: premature EOF reading chunk body".to_string());
            }
            raw.extend_from_slice(&tmp[..n]);
        }
        decoded.extend_from_slice(&raw[..chunk_size]);
        raw.drain(..frame_len);
    }

    Ok(decoded)
}
414
/// Replaces the chunked framing of a response with a plain `Content-Length` body.
///
/// `buf[..header_end]` is the original response head. On return, `buf` contains that head
/// with every `Transfer-Encoding` and `Content-Length` line removed, a new
/// `Content-Length: decoded.len()` header, the blank line, and then `decoded`.
///
/// The head is processed as raw bytes, so header values that are not valid UTF-8 are kept
/// intact rather than discarding the whole head. Any prior `Content-Length` is dropped
/// because RFC 7230 §3.3.3 says it must be ignored alongside `Transfer-Encoding`, and keeping
/// it would produce two conflicting values.
fn rewrite_as_content_length(buf: &mut Vec<u8>, header_end: usize, decoded: &[u8]) {
    /// ASCII case-insensitive prefix test on a header line.
    fn starts_with_ignore_case(line: &[u8], prefix: &[u8]) -> bool {
        line.get(..prefix.len())
            .map_or(false, |head| head.eq_ignore_ascii_case(prefix))
    }

    let head = buf[..header_end].to_vec();
    buf.clear();
    for line in head.split(|&b| b == b'\n') {
        let line = line.strip_suffix(b"\r").unwrap_or(line);
        if line.is_empty()
            || starts_with_ignore_case(line, b"transfer-encoding:")
            || starts_with_ignore_case(line, b"content-length:")
        {
            continue;
        }
        buf.extend_from_slice(line);
        buf.extend_from_slice(b"\r\n");
    }
    let _ = write!(buf, "Content-Length: {}\r\n\r\n", decoded.len());
    buf.extend_from_slice(decoded);
}
430
431pub(crate) fn read_response(stream: &mut TcpStream) -> Result<Vec<u8>, String> {
434 read_response_from(stream)
435}
436
/// Reads one HTTP/1.x response from `stream` and returns its raw bytes (head and body).
///
/// The head is read up to its blank line. If the head declares a `Content-Length`, reading
/// stops once that many body bytes have arrived, or at EOF. Otherwise the body runs to EOF,
/// so this suits `Connection: close` exchanges. If the stream ends mid-head, the partial
/// bytes are returned as-is.
///
/// NOTE(review): chunked transfer coding is not decoded here. A chunked body is returned
/// with its chunk framing intact.
///
/// # Errors
/// I/O errors are returned as strings, and so is an EOF before any byte was received.
pub(crate) fn read_response_from<R: Read>(stream: &mut R) -> Result<Vec<u8>, String> {
    let mut buf: Vec<u8> = Vec::with_capacity(8192);
    let mut tmp = [0u8; 4096];

    let header_end = loop {
        let n = stream.read(&mut tmp).map_err(|e| e.to_string())?;
        if n == 0 {
            return if buf.is_empty() {
                Err("backend closed connection without sending a response".to_string())
            } else {
                Ok(buf)
            };
        }
        buf.extend_from_slice(&tmp[..n]);
        if let Some(pos) = buf.windows(4).position(|w| w == b"\r\n\r\n") {
            break pos + 4;
        }
    };

    let content_length = std::str::from_utf8(&buf[..header_end])
        .unwrap_or("")
        .lines()
        .find_map(|line| {
            line.to_lowercase()
                .starts_with("content-length:")
                .then(|| line.splitn(2, ':').nth(1)?.trim().parse::<usize>().ok())
                .flatten()
        });

    // `saturating_add` keeps an absurd Content-Length from overflowing (a panic in debug
    // builds). The loop then simply runs until EOF.
    let body_end = content_length.map(|len| header_end.saturating_add(len));
    loop {
        if body_end.map_or(false, |end| buf.len() >= end) {
            break;
        }
        let n = stream.read(&mut tmp).map_err(|e| e.to_string())?;
        if n == 0 {
            break;
        }
        buf.extend_from_slice(&tmp[..n]);
    }

    Ok(buf)
}
487
488pub(crate) fn proxy_http1(
494 request: &Request,
495 client_ip: &str,
496 host: &str,
497 port: u16,
498 connect_timeout: Duration,
499 read_timeout: Duration,
500) -> Result<Response, String> {
501 use std::net::ToSocketAddrs;
502 let addr_str = format!("{}:{}", host, port);
503 let sock_addr = addr_str
504 .to_socket_addrs()
505 .map_err(|e| format!("DNS lookup for {} failed: {}", addr_str, e))?
506 .next()
507 .ok_or_else(|| format!("no address resolved for {}", addr_str))?;
508 let stream = TcpStream::connect_timeout(&sock_addr, connect_timeout)
509 .map_err(|e| format!("connect to {} failed: {}", addr_str, e))?;
510 stream.set_read_timeout(Some(read_timeout)).map_err(|e| e.to_string())?;
511 stream.set_write_timeout(Some(Duration::from_secs(10))).map_err(|e| e.to_string())?;
512 let req_bytes = build_request(request, host, client_ip, false);
513 let mut stream = stream;
514 stream.write_all(&req_bytes).map_err(|e| format!("write to backend failed: {}", e))?;
515 let resp_bytes = read_response(&mut stream)?;
516 Response::parse(&resp_bytes)
517}
518
/// Sends `request` to `host:port` over a fresh TLS connection and returns the parsed response.
///
/// The exchange uses HTTP/1.1 with `Connection: close`. The server certificate is verified
/// against the bundled `webpki_roots` trust anchors, and `host` is used as the expected
/// server name. The connect attempt is bounded by `connect_timeout`, socket reads by
/// `read_timeout`, and writes by a fixed 10 s timeout.
///
/// # Errors
/// DNS, connect, TLS setup, write, read and response-parse failures are returned as strings.
#[cfg(any(feature = "http-client", feature = "http2"))]
pub(crate) fn proxy_https1(
    request: &Request,
    client_ip: &str,
    host: &str,
    port: u16,
    connect_timeout: Duration,
    read_timeout: Duration,
) -> Result<Response, String> {
    use rustls::pki_types::ServerName;
    use rustls::{ClientConfig, ClientConnection, RootCertStore, StreamOwned};

    let addr_str = format!("{}:{}", host, port);
    let first_addr = addr_str
        .to_socket_addrs()
        .map_err(|e| format!("DNS lookup for {} failed: {}", addr_str, e))?
        .next();
    let sock_addr = match first_addr {
        Some(addr) => addr,
        None => return Err(format!("no address resolved for {}", addr_str)),
    };

    let tcp = TcpStream::connect_timeout(&sock_addr, connect_timeout)
        .map_err(|e| format!("connect to {} failed: {}", addr_str, e))?;
    tcp.set_read_timeout(Some(read_timeout))
        .map_err(|e| e.to_string())?;
    tcp.set_write_timeout(Some(Duration::from_secs(10)))
        .map_err(|e| e.to_string())?;

    let roots = RootCertStore::from_iter(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
    let tls_config = ClientConfig::builder()
        .with_root_certificates(roots)
        .with_no_client_auth();
    let server_name = ServerName::try_from(host.to_string())
        .map_err(|e| format!("invalid upstream hostname '{}': {}", host, e))?;
    let session =
        ClientConnection::new(Arc::new(tls_config), server_name).map_err(|e| e.to_string())?;
    let mut tls = StreamOwned::new(session, tcp);

    tls.write_all(&build_request(request, host, client_ip, false))
        .map_err(|e| format!("write to upstream failed: {}", e))?;
    let raw = read_response_from(&mut tls)?;
    Response::parse(&raw)
}
567
568fn bad_gateway() -> Response {
569 let cr = Range::get_content_range(
570 b"502 Bad Gateway".to_vec(),
571 MimeType::TEXT_PLAIN.to_string(),
572 );
573 let mut r = Response::new();
574 r.status_code = *STATUS_CODE_REASON_PHRASE.n502_bad_gateway.status_code;
575 r.reason_phrase = STATUS_CODE_REASON_PHRASE
576 .n502_bad_gateway
577 .reason_phrase
578 .to_string();
579 r.content_range_list = vec![cr];
580 r
581}
582
/// Declared `Content-Length` (1 MiB) above which a backend response is streamed, not buffered.
const STREAM_THRESHOLD: usize = 1024 * 1024;

/// A reader that yields the bytes of `prefix` first and then everything from `inner`.
///
/// Used to pass on body bytes that were read together with the response head, followed by
/// the rest of the backend stream.
pub(crate) struct ConcatReader<R: Read + Send> {
    /// Bytes served before `inner` is touched.
    prefix: Vec<u8>,
    /// Number of `prefix` bytes already served.
    prefix_pos: usize,
    /// Underlying stream, read once `prefix` is exhausted.
    inner: R,
}

impl<R: Read + Send> ConcatReader<R> {
    /// Wraps `inner` so that `prefix` is read out before it.
    fn new(prefix: Vec<u8>, inner: R) -> Self {
        Self {
            prefix,
            prefix_pos: 0,
            inner,
        }
    }
}

impl<R: Read + Send> Read for ConcatReader<R> {
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        let mut pending: &[u8] = &self.prefix[self.prefix_pos..];
        if pending.is_empty() {
            return self.inner.read(buf);
        }
        // `Read for &[u8]` copies `min(buf.len(), pending.len())` bytes.
        let copied = pending.read(buf)?;
        self.prefix_pos += copied;
        Ok(copied)
    }
}
614
/// Reads from `stream` until the end of the response head (`\r\n\r\n`).
///
/// Returns `(head, body_prefix)`. `head` includes the terminating blank line, and
/// `body_prefix` holds any body bytes that arrived in the same reads. `tmp` is scratch space
/// for each `read` call.
///
/// # Errors
/// Fails on I/O errors, or if the stream ends before the head is complete.
fn read_headers_only<R: Read + ?Sized>(
    stream: &mut R,
    tmp: &mut [u8],
) -> Result<(Vec<u8>, Vec<u8>), String> {
    let mut buf: Vec<u8> = Vec::with_capacity(4096);
    loop {
        let n = stream.read(tmp).map_err(|e| e.to_string())?;
        if n == 0 {
            return Err("backend closed connection before headers were complete".to_string());
        }
        // A terminator can only complete within the new bytes plus 3 bytes of overlap, so
        // scanning from there avoids rescanning the whole buffer on every read.
        let scan_from = buf.len().saturating_sub(3);
        buf.extend_from_slice(&tmp[..n]);
        if let Some(pos) = buf[scan_from..].windows(4).position(|w| w == b"\r\n\r\n") {
            let body_prefix = buf.split_off(scan_from + pos + 4);
            return Ok((buf, body_prefix));
        }
    }
}
634
635pub(crate) fn should_stream_response(header_lower: &str) -> bool {
642 let is_sse = header_lower.lines().any(|l| {
643 l.starts_with("content-type:") && l.contains("text/event-stream")
644 });
645 let is_chunked = header_lower.lines().any(|l| {
646 l.starts_with("transfer-encoding:") && l.contains("chunked")
647 });
648 let content_length: Option<usize> = header_lower.lines().find_map(|l| {
649 l.strip_prefix("content-length:")?.trim().parse().ok()
650 });
651 let is_large = content_length.map_or(false, |n| n > STREAM_THRESHOLD);
652 is_sse || is_chunked || is_large
653}
654
655fn parse_status_and_headers(header_bytes: &[u8]) -> Result<Response, String> {
657 let s = std::str::from_utf8(header_bytes)
658 .map_err(|e| format!("non-UTF-8 response headers: {}", e))?;
659 let mut lines = s.lines();
660 let status_line = lines.next().ok_or("empty backend response")?;
661 let mut parts = status_line.splitn(3, ' ');
662 let http_version = parts.next().unwrap_or("HTTP/1.1").to_string();
663 let status_code: i16 = parts
664 .next()
665 .unwrap_or("502")
666 .parse()
667 .map_err(|_| format!("invalid status code in '{}'", status_line))?;
668 let reason_phrase = parts.next().unwrap_or("").trim_end_matches('\r').to_string();
669 let mut headers = Vec::new();
670 for line in lines {
671 let line = line.trim_end_matches('\r');
672 if line.is_empty() { break; }
673 if let Some(colon) = line.find(':') {
674 headers.push(crate::header::Header {
675 name: line[..colon].trim().to_string(),
676 value: line[colon + 1..].trim().to_string(),
677 });
678 }
679 }
680 Ok(Response {
681 http_version,
682 status_code,
683 reason_phrase,
684 headers,
685 content_range_list: vec![],
686 stream_file: None,
687 stream_pipe: None,
688 })
689}
690
691fn read_response_from_partial(
697 stream: &mut TcpStream,
698 header_bytes: Vec<u8>,
699 body_prefix: Vec<u8>,
700 tmp: &mut [u8],
701) -> Result<(Vec<u8>, bool), String> {
702 let header_end = header_bytes.len();
703 let mut buf = header_bytes;
704 buf.extend_from_slice(&body_prefix);
705
706 let header_lower =
707 std::str::from_utf8(&buf[..header_end]).unwrap_or("").to_ascii_lowercase();
708 let connection_close =
709 header_lower.lines().any(|l| l.starts_with("connection:") && l.contains("close"));
710 let is_chunked = header_lower
711 .lines()
712 .any(|l| l.starts_with("transfer-encoding:") && l.contains("chunked"));
713 let content_length: Option<usize> = header_lower.lines().find_map(|l| {
714 l.strip_prefix("content-length:")?.trim().parse().ok()
715 });
716
717 if is_chunked {
718 let decoded = decode_chunked(stream, &buf, header_end, tmp)?;
719 rewrite_as_content_length(&mut buf, header_end, &decoded);
720 Ok((buf, !connection_close))
721 } else if let Some(len) = content_length {
722 while buf.len() < header_end + len {
723 let n = stream.read(tmp).map_err(|e| e.to_string())?;
724 if n == 0 { break; }
725 buf.extend_from_slice(&tmp[..n]);
726 }
727 Ok((buf, !connection_close))
728 } else {
729 loop {
730 let n = stream.read(tmp).map_err(|e| e.to_string())?;
731 if n == 0 { break; }
732 buf.extend_from_slice(&tmp[..n]);
733 }
734 Ok((buf, false))
735 }
736}
737
/// An upstream target parsed from a backend URL.
struct Backend {
    /// Host name or IP literal. IPv6 literals keep their brackets.
    host: String,
    /// TCP port, either explicit or the scheme default (80 or 443).
    port: u16,
    /// Whether the connection must use TLS (`https`, `h2s` and `grpcs` schemes).
    ///
    /// Only the HTTP/2 proxy path reads this field, so it is dead code without that feature.
    #[cfg_attr(not(feature = "http2"), allow(dead_code))]
    tls: bool,
}

impl Backend {
    /// Parses `url` into a backend, or returns `None` if it has no host.
    ///
    /// Recognised schemes (matched case-insensitively):
    ///
    /// * `http`, `h2` and `grpc` are plaintext, with default port 80.
    /// * `https`, `h2s` and `grpcs` use TLS, with default port 443.
    ///
    /// A URL without a scheme is treated as plaintext. Only the authority is used: everything
    /// from the first `/`, `?` or `#` onwards is ignored. A trailing `:port` that is not a
    /// valid `u16` stays part of the host.
    fn parse(url: &str) -> Option<Self> {
        const SCHEMES: &[(&str, bool, u16)] = &[
            ("https://", true, 443),
            ("h2s://", true, 443),
            ("grpcs://", true, 443),
            ("http://", false, 80),
            ("h2://", false, 80),
            ("grpc://", false, 80),
        ];

        let (rest, tls, default_port) = SCHEMES
            .iter()
            .find_map(|&(scheme, tls, port)| {
                let head = url.get(..scheme.len())?;
                head.eq_ignore_ascii_case(scheme)
                    .then(|| (&url[scheme.len()..], tls, port))
            })
            .unwrap_or((url, false, 80));

        let host_port = rest.split(&['/', '?', '#'][..]).next().unwrap_or(rest);
        let (host, port) = match host_port.rfind(':') {
            Some(colon) => match host_port[colon + 1..].parse::<u16>() {
                Ok(port) => (&host_port[..colon], port),
                Err(_) => (host_port, default_port),
            },
            None => (host_port, default_port),
        };
        if host.is_empty() {
            return None;
        }
        Some(Backend {
            host: host.to_string(),
            port,
            tls,
        })
    }
}
784
/// Reverse proxy that forwards requests to backends over HTTP/2.
///
/// Connections are either cleartext or TLS with ALPN `h2`. The backend list, path prefix,
/// timeouts and round-robin counter come from an inner `ReverseProxy`.
#[cfg(feature = "http2")]
pub struct H2ReverseProxy {
    inner: ReverseProxy,
}

#[cfg(feature = "http2")]
impl H2ReverseProxy {
    /// Creates an HTTP/2 proxy for the given backend URLs.
    ///
    /// URLs are parsed exactly as by `ReverseProxy::new`.
    pub fn new<I, S>(backends: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: AsRef<str>,
    {
        Self {
            inner: ReverseProxy::new(backends),
        }
    }

    /// Only proxy requests whose URI starts with `prefix`.
    pub fn path_prefix(self, prefix: impl Into<String>) -> Self {
        Self {
            inner: self.inner.path_prefix(prefix),
        }
    }

    /// Sets the TCP connect timeout for backend connections, in milliseconds.
    pub fn connect_timeout_ms(self, ms: u64) -> Self {
        Self {
            inner: self.inner.connect_timeout_ms(ms),
        }
    }

    /// Sets the read timeout used for backend exchanges, in milliseconds.
    pub fn read_timeout_ms(self, ms: u64) -> Self {
        Self {
            inner: self.inner.read_timeout_ms(ms),
        }
    }
}
840
#[cfg(feature = "http2")]
impl crate::middleware::Middleware for H2ReverseProxy {
    /// Proxies in-scope requests over HTTP/2 and passes the rest to `next`.
    ///
    /// A request is in scope when it matches `path_prefix`, or always if no prefix is set.
    /// Backends are tried once each in round-robin order. If none is configured or all fail,
    /// `502 Bad Gateway` is returned.
    fn handle(
        &self,
        request: &crate::request::Request,
        connection: &crate::server::ConnectionInfo,
        next: &dyn crate::application::Application,
    ) -> Result<crate::response::Response, String> {
        if let Some(prefix) = &self.inner.path_prefix {
            if !request.request_uri.starts_with(prefix.as_str()) {
                return next.execute(request, connection);
            }
        }
        let backends = &self.inner.backends;
        let n = backends.len();
        if n == 0 {
            return Ok(bad_gateway());
        }
        // Reduce modulo `n` first so `start + attempt` cannot overflow once the counter wraps.
        let start = self.inner.counter.fetch_add(1, Ordering::Relaxed) % n;
        for attempt in 0..n {
            let backend = &backends[(start + attempt) % n];
            if let Ok(resp) = try_backend_h2(
                request,
                &connection.client.ip,
                backend,
                self.inner.connect_timeout,
                self.inner.read_timeout,
            ) {
                return Ok(resp);
            }
        }
        Ok(bad_gateway())
    }
}
871
/// Runs one HTTP/2 exchange with `backend` synchronously on the current tokio runtime.
///
/// `connect_timeout` bounds the TCP connect. The whole exchange (connect, request, and
/// reading the full response) is additionally bounded by `connect_timeout + read_timeout`.
/// A stalled backend therefore cannot block the calling thread indefinitely; previously
/// `read_timeout` was ignored.
///
/// # Errors
/// Fails when called outside a tokio runtime, on timeout, or on any connect, TLS or h2 error.
///
/// NOTE(review): `tokio::task::block_in_place` panics on a current-thread runtime. Confirm
/// that this is only reached from the multi-threaded runtime.
#[cfg(feature = "http2")]
fn try_backend_h2(
    request: &Request,
    client_ip: &str,
    backend: &Backend,
    connect_timeout: Duration,
    read_timeout: Duration,
) -> Result<Response, String> {
    use tokio::runtime::Handle;

    let handle = match Handle::try_current() {
        Ok(handle) => handle,
        Err(_) => return Err("no async runtime for H2 proxy; falling back to 502".to_string()),
    };
    let deadline = connect_timeout.saturating_add(read_timeout);
    tokio::task::block_in_place(|| {
        handle.block_on(async {
            tokio::time::timeout(
                deadline,
                forward_h2_async(request, client_ip, backend, connect_timeout),
            )
            .await
            .unwrap_or_else(|_| {
                Err(format!(
                    "h2 proxy: no complete response from {}:{} within {:?}",
                    backend.host, backend.port, deadline
                ))
            })
        })
    })
}
890
/// Connects to `backend` and performs one HTTP/2 exchange for `request`.
///
/// The TCP connect is bounded by `connect_timeout`. For TLS backends, the certificate is
/// verified against the bundled `webpki_roots` anchors, and ALPN advertises only `h2`.
#[cfg(feature = "http2")]
async fn forward_h2_async(
    request: &Request,
    client_ip: &str,
    backend: &Backend,
    connect_timeout: Duration,
) -> Result<Response, String> {
    use rustls::pki_types::ServerName;
    use tokio_rustls::TlsConnector;

    let addr = format!("{}:{}", backend.host, backend.port);
    let connect = tokio::net::TcpStream::connect(&addr);
    let tcp = match tokio::time::timeout(connect_timeout, connect).await {
        Err(_elapsed) => return Err(format!("h2 proxy: connect to {} timed out", addr)),
        Ok(Err(e)) => return Err(format!("h2 proxy: connect to {} failed: {}", addr, e)),
        Ok(Ok(stream)) => stream,
    };

    if !backend.tls {
        return send_h2_request(request, client_ip, backend, tcp).await;
    }

    let roots =
        rustls::RootCertStore::from_iter(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
    let mut tls_config = rustls::ClientConfig::builder()
        .with_root_certificates(roots)
        .with_no_client_auth();
    tls_config.alpn_protocols = vec![b"h2".to_vec()];
    let connector = TlsConnector::from(Arc::new(tls_config));
    let server_name = ServerName::try_from(backend.host.as_str())
        .map_err(|e| format!("invalid upstream hostname '{}': {}", backend.host, e))?
        .to_owned();
    let tls_stream = connector
        .connect(server_name, tcp)
        .await
        .map_err(|e| format!("h2 proxy: TLS handshake with {} failed: {}", addr, e))?;
    send_h2_request(request, client_ip, backend, tls_stream).await
}
934
/// Sends `request` over an established HTTP/2 transport and buffers the full response.
///
/// Steps:
///
/// 1. Perform the h2 client handshake on `stream` and drive the connection on a spawned
///    tokio task.
/// 2. Forward the request without hop-by-hop headers, adding `X-Forwarded-For` and
///    `Via: 2 rws`.
/// 3. Collect the whole response body, returning flow-control capacity as data arrives.
///
/// Connection-specific response headers are dropped.
///
/// NOTE(review): response trailers (e.g. gRPC's `grpc-status`) are never read, so they are
/// not passed on to the client.
#[cfg(feature = "http2")]
async fn send_h2_request<T>(
    request: &Request,
    client_ip: &str,
    backend: &Backend,
    stream: T,
) -> Result<Response, String>
where
    T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static,
{
    use bytes::Bytes;
    use http as hc;

    let addr = format!("{}:{}", backend.host, backend.port);

    let (send_req, conn) = h2::client::handshake(stream)
        .await
        .map_err(|e| format!("h2 proxy: handshake with {} failed: {}", addr, e))?;

    // The connection future performs the frame I/O and must be polled independently of the
    // request and response handles below.
    tokio::spawn(async move {
        let _ = conn.await;
    });

    let scheme = if backend.tls { "https" } else { "http" };
    let uri_str = format!("{}://{}{}", scheme, addr, request.request_uri);
    let uri: hc::Uri = uri_str.parse().map_err(|e: hc::uri::InvalidUri| e.to_string())?;
    let method = hc::Method::from_bytes(request.method.as_bytes()).map_err(|e| e.to_string())?;

    // The target authority is sent as `:authority`, derived from `uri` (`host:port`).
    // RFC 9113 §8.3.1 forbids a `host` header that differs from it, so the client's `Host`
    // is dropped and none is added.
    let mut builder = hc::Request::builder().method(method).uri(uri);
    for h in &request.headers {
        let lower = h.name.to_lowercase();
        if lower == "te" {
            // HTTP/2 only permits `te: trailers` (RFC 9113 §8.2.2). gRPC relies on it to
            // detect proxies that would drop trailers, so pass exactly that form through.
            if h.value.trim().eq_ignore_ascii_case("trailers") {
                builder = builder.header("te", "trailers");
            }
            continue;
        }
        // `proxy-connection` is also connection-specific (RFC 9113 §8.2.2), and the h2 crate
        // refuses to send requests carrying it.
        if HOP_BY_HOP.contains(&lower.as_str()) || lower == "host" || lower == "proxy-connection"
        {
            continue;
        }
        builder = builder.header(&h.name, &h.value);
    }
    builder = builder.header("x-forwarded-for", client_ip);
    builder = builder.header("via", "2 rws");

    let body_bytes = Bytes::from(request.body.clone());
    let end_of_stream = body_bytes.is_empty();
    let http_req = builder.body(()).map_err(|e| e.to_string())?;

    let mut send_req = send_req.ready().await.map_err(|e| e.to_string())?;
    let (resp_future, mut req_body) = send_req
        .send_request(http_req, end_of_stream)
        .map_err(|e| e.to_string())?;
    if !end_of_stream {
        req_body.send_data(body_bytes, true).map_err(|e| e.to_string())?;
    }

    let resp = resp_future.await.map_err(|e| e.to_string())?;
    let (parts, mut body) = resp.into_parts();

    let content_type = parts
        .headers
        .get("content-type")
        .and_then(|v| v.to_str().ok())
        .unwrap_or("application/octet-stream")
        .to_string();

    let mut body_bytes: Vec<u8> = Vec::new();
    while let Some(chunk) = body.data().await {
        let chunk = chunk.map_err(|e| e.to_string())?;
        // Return the received bytes to h2's flow control. Without this the receive window is
        // never replenished, and a body larger than it (65,535 bytes by default) stalls.
        // Releasing exactly what was received cannot fail, so the result is ignored.
        let _ = body.flow_control().release_capacity(chunk.len());
        body_bytes.extend_from_slice(&chunk);
    }

    let mut response = Response::new();
    response.status_code = parts.status.as_u16() as i16;
    response.reason_phrase = parts.status.canonical_reason().unwrap_or("").to_string();

    const H2_HOP: &[&str] = &[
        "connection", "keep-alive", "transfer-encoding", "upgrade", "proxy-connection", "te",
    ];
    for (name, value) in &parts.headers {
        let lower = name.as_str().to_lowercase();
        if H2_HOP.contains(&lower.as_str()) {
            continue;
        }
        if let Ok(v) = value.to_str() {
            response.headers.push(crate::header::Header {
                name: name.as_str().to_string(),
                value: v.to_string(),
            });
        }
    }

    if !body_bytes.is_empty() {
        response.content_range_list = vec![Range::get_content_range(body_bytes, content_type)];
    }

    Ok(response)
}
1032
/// gRPC-aware proxy built on an inner `H2ReverseProxy`.
///
/// Requests whose `Content-Type` starts with `application/grpc` go to the backends over
/// HTTP/2. All other requests pass through to the next application.
#[cfg(feature = "http2")]
pub struct GrpcProxy {
    inner: H2ReverseProxy,
}

#[cfg(feature = "http2")]
impl GrpcProxy {
    /// Creates a gRPC proxy for the given backend URLs (e.g. `grpc://svc:50051`).
    pub fn new<I, S>(backends: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: AsRef<str>,
    {
        Self {
            inner: H2ReverseProxy::new(backends),
        }
    }

    /// Only proxy gRPC requests whose URI starts with `prefix`.
    pub fn path_prefix(self, prefix: impl Into<String>) -> Self {
        Self {
            inner: self.inner.path_prefix(prefix),
        }
    }
}
1082
#[cfg(feature = "http2")]
impl crate::middleware::Middleware for GrpcProxy {
    /// Sends gRPC requests (`content-type` starting with `application/grpc`) to the inner
    /// HTTP/2 proxy, and lets every other request fall through to `next`.
    fn handle(
        &self,
        request: &crate::request::Request,
        connection: &crate::server::ConnectionInfo,
        next: &dyn crate::application::Application,
    ) -> Result<crate::response::Response, String> {
        let is_grpc = request
            .get_header("content-type".to_string())
            .map_or(false, |h| h.value.starts_with("application/grpc"));
        if !is_grpc {
            return next.execute(request, connection);
        }
        self.inner.handle(request, connection, next)
    }
}
1102
/// Unit tests for `Backend::parse`: scheme detection, the TLS flag, and default ports.
#[cfg(test)]
mod backend_parse_tests {
    use super::Backend;

    /// Flattens a parsed backend into a comparable `(host, port, tls)` tuple.
    fn parse(url: &str) -> Option<(String, u16, bool)> {
        Backend::parse(url).map(|Backend { host, port, tls }| (host, port, tls))
    }

    /// Asserts that `url` parses to exactly the given host, port and TLS flag.
    fn expect(url: &str, host: &str, port: u16, tls: bool) {
        assert_eq!(parse(url), Some((host.to_string(), port, tls)), "parsing {:?}", url);
    }

    #[test]
    fn bare_host_port() {
        expect("api.example.com:8080", "api.example.com", 8080, false);
    }

    #[test]
    fn http_scheme() {
        expect("http://backend:3000", "backend", 3000, false);
    }

    #[test]
    fn h2_scheme_plain() {
        expect("h2://svc:50051", "svc", 50051, false);
    }

    #[test]
    fn grpc_scheme_plain() {
        expect("grpc://svc:50051", "svc", 50051, false);
    }

    #[test]
    fn https_scheme_sets_tls_and_default_port() {
        expect("https://api.example.com", "api.example.com", 443, true);
    }

    #[test]
    fn https_scheme_explicit_port() {
        expect("https://api.example.com:8443", "api.example.com", 8443, true);
    }

    #[test]
    fn h2s_scheme_sets_tls() {
        expect("h2s://svc", "svc", 443, true);
    }

    #[test]
    fn h2s_scheme_explicit_port() {
        expect("h2s://svc:8443", "svc", 8443, true);
    }

    #[test]
    fn grpcs_scheme_sets_tls() {
        expect("grpcs://grpc-svc", "grpc-svc", 443, true);
    }

    #[test]
    fn grpcs_scheme_explicit_port() {
        expect("grpcs://grpc-svc:50052", "grpc-svc", 50052, true);
    }

    #[test]
    fn empty_host_returns_none() {
        assert_eq!(parse("https://"), None);
    }

    #[test]
    fn bare_host_no_port_defaults_to_80() {
        expect("myhost", "myhost", 80, false);
    }
}