1pub mod pool;
27
28#[cfg(test)]
29mod tests;
30
31use std::io::{Read, Write};
32use std::net::{TcpStream, ToSocketAddrs};
33use std::sync::atomic::{AtomicUsize, Ordering};
34use std::sync::Arc;
35use std::time::Duration;
36
37pub use pool::ConnPool;
38
39use crate::application::Application;
40use crate::core::New;
41use crate::middleware::Middleware;
42use crate::mime_type::MimeType;
43use crate::range::Range;
44use crate::request::Request;
45use crate::response::{Response, STATUS_CODE_REASON_PHRASE};
46use crate::server::ConnectionInfo;
47
/// Header fields scoped to a single transport hop, which a proxy must not
/// forward (RFC 7230 §6.1; RFC 2616 §13.5.1).
///
/// All entries are lowercase; compare them against a lowercased field name.
/// `trailer` is the field the RFCs actually define (`trailers` is kept for
/// backward compatibility), and `proxy-connection` is a non-standard field
/// that HTTP/2 treats as connection-specific and therefore malformed
/// (RFC 7540 §8.1.2.2).
const HOP_BY_HOP: &[&str] = &[
    "connection",
    "keep-alive",
    "proxy-authenticate",
    "proxy-authorization",
    "proxy-connection",
    "te",
    "trailer",
    "trailers",
    "transfer-encoding",
    "upgrade",
];
59
/// Strategy used to pick a backend for each proxied request.
///
/// Round-robin is the only variant; `ReverseProxy::strategy` currently
/// ignores the value it is given.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum LoadBalancing {
    /// Rotate through the configured backends, one request at a time.
    #[default]
    RoundRobin,
}
65
66pub struct ReverseProxy {
81 backends: Vec<Backend>,
82 path_prefix: Option<String>,
83 connect_timeout: Duration,
84 read_timeout: Duration,
85 counter: AtomicUsize,
86 pool: Arc<ConnPool>,
87}
88
89impl ReverseProxy {
90 pub fn new<I, S>(backends: I) -> Self
94 where
95 I: IntoIterator<Item = S>,
96 S: AsRef<str>,
97 {
98 Self {
99 backends: backends
100 .into_iter()
101 .filter_map(|u| Backend::parse(u.as_ref()))
102 .collect(),
103 path_prefix: None,
104 connect_timeout: Duration::from_secs(5),
105 read_timeout: Duration::from_secs(30),
106 counter: AtomicUsize::new(0),
107 pool: Arc::new(ConnPool::new_default()),
108 }
109 }
110
111 pub fn path_prefix(mut self, prefix: impl Into<String>) -> Self {
116 self.path_prefix = Some(prefix.into());
117 self
118 }
119
120 pub fn strategy(self, _strategy: LoadBalancing) -> Self {
122 self
123 }
124
125 pub fn connect_timeout_ms(mut self, ms: u64) -> Self {
127 self.connect_timeout = Duration::from_millis(ms);
128 self
129 }
130
131 pub fn read_timeout_ms(mut self, ms: u64) -> Self {
133 self.read_timeout = Duration::from_millis(ms);
134 self
135 }
136
137 pub fn with_pool(mut self, pool: Arc<ConnPool>) -> Self {
142 self.pool = pool;
143 self
144 }
145
146 pub fn max_idle_conns(mut self, n: usize) -> Self {
148 self.pool = Arc::new(ConnPool::new(n, Duration::from_secs(60)));
149 self
150 }
151
152 fn proxy(&self, request: &Request, connection: &ConnectionInfo) -> Result<Response, String> {
153 if self.backends.is_empty() {
154 return Err("no backends configured".to_string());
155 }
156 let n = self.backends.len();
157 let start = self.counter.fetch_add(1, Ordering::Relaxed);
158 for attempt in 0..n {
159 let idx = (start + attempt) % n;
160 match self.try_backend(request, connection, &self.backends[idx]) {
161 Ok(resp) => return Ok(resp),
162 Err(_) if attempt + 1 < n => continue,
163 Err(e) => return Err(e),
164 }
165 }
166 Err("all backends failed".to_string())
167 }
168
169 fn try_backend(
170 &self,
171 request: &Request,
172 connection: &ConnectionInfo,
173 backend: &Backend,
174 ) -> Result<Response, String> {
175 let key = format!("{}:{}", backend.host, backend.port);
176
177 let stream = if let Some(pooled) = self.pool.acquire(&key) {
179 pooled
180 } else {
181 let addr_str = key.as_str();
182 let sock_addr = addr_str
183 .to_socket_addrs()
184 .map_err(|e| format!("DNS lookup for {} failed: {}", addr_str, e))?
185 .next()
186 .ok_or_else(|| format!("no address resolved for {}", addr_str))?;
187 TcpStream::connect_timeout(&sock_addr, self.connect_timeout)
188 .map_err(|e| format!("connect to {} failed: {}", addr_str, e))?
189 };
190
191 stream.set_read_timeout(Some(self.read_timeout)).map_err(|e| e.to_string())?;
192 stream.set_write_timeout(Some(Duration::from_secs(10))).map_err(|e| e.to_string())?;
193
194 let req_bytes = build_request(request, &backend.host, &connection.client.ip, true);
197 let mut stream = stream;
198 stream.write_all(&req_bytes).map_err(|e| format!("write to backend failed: {}", e))?;
199
200 let mut tmp = [0u8; 4096];
201 let (header_bytes, body_prefix) = read_headers_only(&mut stream, &mut tmp)?;
202 let header_lower =
203 std::str::from_utf8(&header_bytes).unwrap_or("").to_ascii_lowercase();
204
205 if should_stream_response(&header_lower) {
206 let mut resp = parse_status_and_headers(&header_bytes)?;
209 resp.stream_pipe =
210 Some(Box::new(ConcatReader::new(body_prefix, stream)));
211 Ok(resp)
212 } else {
213 let (resp_bytes, reusable) =
216 read_response_from_partial(&mut stream, header_bytes, body_prefix, &mut tmp)?;
217 if reusable {
218 self.pool.release(&key, stream);
219 }
220 Response::parse(&resp_bytes)
221 }
222 }
223}
224
225impl Middleware for ReverseProxy {
226 fn handle(
227 &self,
228 request: &Request,
229 connection: &ConnectionInfo,
230 next: &dyn Application,
231 ) -> Result<Response, String> {
232 if let Some(prefix) = &self.path_prefix {
233 if !request.request_uri.starts_with(prefix.as_str()) {
234 return next.execute(request, connection);
235 }
236 }
237 match self.proxy(request, connection) {
238 Ok(resp) => Ok(resp),
239 Err(_) => Ok(bad_gateway()),
240 }
241 }
242}
243
244pub(crate) fn build_request(
247 request: &Request,
248 backend_host: &str,
249 client_ip: &str,
250 keep_alive: bool,
251) -> Vec<u8> {
252 let mut out: Vec<u8> = Vec::new();
253 let _ = write!(
254 out,
255 "{} {} HTTP/1.1\r\nHost: {}\r\n",
256 request.method, request.request_uri, backend_host
257 );
258 for h in &request.headers {
259 let lower = h.name.to_lowercase();
260 if HOP_BY_HOP.contains(&lower.as_str()) || lower == "host" {
261 continue;
262 }
263 let _ = write!(out, "{}: {}\r\n", h.name, h.value);
264 }
265 let _ = write!(out, "X-Forwarded-For: {}\r\n", client_ip);
266 let _ = write!(out, "Via: 1.1 rws\r\n");
267 if keep_alive {
268 let _ = write!(out, "Connection: keep-alive\r\n");
269 } else {
270 let _ = write!(out, "Connection: close\r\n");
271 }
272 if !request.body.is_empty() {
273 let _ = write!(out, "Content-Length: {}\r\n", request.body.len());
274 }
275 let _ = write!(out, "\r\n");
276 out.extend_from_slice(&request.body);
277 out
278}
279
/// Decodes a `Transfer-Encoding: chunked` body.
///
/// `buf[header_end..]` holds body bytes that were already read past the
/// header block. Further bytes are pulled from `stream` through the scratch
/// buffer `tmp`. Chunk extensions (`;name=value`) are ignored. After the
/// zero-size chunk, any trailer fields and the final empty line are consumed.
/// This leaves a keep-alive connection at the start of the next message. A peer
/// that closes before that final line is tolerated.
///
/// Returns the concatenated chunk data.
///
/// # Errors
/// Premature EOF in a chunk-size line or chunk body, a non-UTF-8 or invalid
/// chunk size, a size too large to address, or an I/O error from `stream`.
fn decode_chunked<R: Read + ?Sized>(
    stream: &mut R,
    buf: &[u8],
    header_end: usize,
    tmp: &mut [u8],
) -> Result<Vec<u8>, String> {
    /// Ensures `raw` contains a CRLF, reading more from `stream` as needed.
    /// Returns the CRLF's offset, or `None` if the stream ended first.
    fn find_crlf<R: Read + ?Sized>(
        stream: &mut R,
        raw: &mut Vec<u8>,
        tmp: &mut [u8],
    ) -> Result<Option<usize>, String> {
        loop {
            if let Some(p) = raw.windows(2).position(|w| w == b"\r\n") {
                return Ok(Some(p));
            }
            let n = stream.read(tmp).map_err(|e| e.to_string())?;
            if n == 0 {
                return Ok(None);
            }
            raw.extend_from_slice(&tmp[..n]);
        }
    }

    // Undecoded bytes; consumed from the front as chunks are parsed.
    let mut raw: Vec<u8> = buf[header_end..].to_vec();
    let mut decoded: Vec<u8> = Vec::new();

    loop {
        let line_end = match find_crlf(stream, &mut raw, tmp)? {
            Some(p) => p,
            None => return Err("chunked: premature EOF reading chunk size".to_string()),
        };

        let (chunk_size, needed) = {
            let size_line = std::str::from_utf8(&raw[..line_end])
                .map_err(|_| "chunked: non-UTF-8 chunk size line".to_string())?;
            let size_str = size_line.split(';').next().unwrap_or("").trim();
            let size = usize::from_str_radix(size_str, 16)
                .map_err(|_| format!("chunked: invalid chunk size '{}'", size_str))?;
            // Chunk data is followed by its own CRLF. Guard the addition: a
            // hostile size such as `ffffffffffffffff` would otherwise overflow.
            let needed = size
                .checked_add(2)
                .ok_or_else(|| format!("chunked: chunk size '{}' too large", size_str))?;
            (size, needed)
        };
        raw.drain(..line_end + 2);

        if chunk_size == 0 {
            // Trailer section: zero or more field lines, then an empty line.
            loop {
                match find_crlf(stream, &mut raw, tmp)? {
                    Some(0) => {
                        raw.drain(..2);
                        break;
                    }
                    Some(p) => {
                        raw.drain(..p + 2);
                    }
                    // Peer closed without the final CRLF; the body is complete.
                    None => break,
                }
            }
            return Ok(decoded);
        }

        while raw.len() < needed {
            let n = stream.read(tmp).map_err(|e| e.to_string())?;
            if n == 0 {
                return Err("chunked: premature EOF reading chunk body".to_string());
            }
            raw.extend_from_slice(&tmp[..n]);
        }
        decoded.extend_from_slice(&raw[..chunk_size]);
        raw.drain(..needed);
    }
}
341
/// Replaces the chunked message held in `buf` with an equivalent fixed-length one.
///
/// The header block `buf[..header_end]` is re-emitted line by line with CRLF
/// endings. `Transfer-Encoding` fields are dropped, and so are any
/// `Content-Length` fields, which a chunked message must not rely on
/// (RFC 7230 §3.3.3). A fresh `Content-Length` for `decoded` follows, then the
/// blank line and `decoded` itself.
///
/// Works on raw bytes so header values that are not valid UTF-8 (e.g. Latin-1)
/// survive unchanged instead of taking the whole header block down with them.
fn rewrite_as_content_length(buf: &mut Vec<u8>, header_end: usize, decoded: &[u8]) {
    fn starts_with_ignore_case(line: &[u8], prefix: &[u8]) -> bool {
        line.len() >= prefix.len() && line[..prefix.len()].eq_ignore_ascii_case(prefix)
    }

    let header = buf[..header_end].to_vec();
    buf.clear();
    buf.reserve(header.len() + decoded.len() + 32);
    for line in header.split(|&b| b == b'\n') {
        let line = line.strip_suffix(b"\r").unwrap_or(line);
        if line.is_empty()
            || starts_with_ignore_case(line, b"transfer-encoding:")
            || starts_with_ignore_case(line, b"content-length:")
        {
            continue;
        }
        buf.extend_from_slice(line);
        buf.extend_from_slice(b"\r\n");
    }
    let _ = write!(buf, "Content-Length: {}\r\n\r\n", decoded.len());
    buf.extend_from_slice(decoded);
}
357
358pub(crate) fn read_response(stream: &mut TcpStream) -> Result<Vec<u8>, String> {
361 read_response_from(stream)
362}
363
/// Reads one HTTP/1.x response (status line, headers and body) from `stream`.
///
/// If `Content-Length` is present, the body is read to exactly that length and
/// any excess bytes are discarded. Otherwise the body runs to end of stream, so
/// callers should have sent `Connection: close`. A chunked body is returned
/// still chunk-encoded.
///
/// `Interrupted` reads are retried. `UnexpectedEof` is treated as a normal end
/// of stream; this is what rustls reports when a TLS peer closes without
/// `close_notify`, which many servers do.
///
/// If the stream ends before the header block is complete, the bytes received
/// so far are returned.
///
/// # Errors
/// The stream ends before any byte arrives, or an I/O error occurs.
pub(crate) fn read_response_from<R: Read>(stream: &mut R) -> Result<Vec<u8>, String> {
    /// One `read` call, retrying `Interrupted` and mapping `UnexpectedEof` to EOF.
    fn read_some<R: Read>(stream: &mut R, tmp: &mut [u8]) -> Result<usize, String> {
        loop {
            match stream.read(tmp) {
                Ok(n) => return Ok(n),
                Err(e) if e.kind() == std::io::ErrorKind::Interrupted => continue,
                Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(0),
                Err(e) => return Err(e.to_string()),
            }
        }
    }

    let mut buf: Vec<u8> = Vec::with_capacity(8192);
    let mut tmp = [0u8; 4096];

    let header_end = loop {
        let n = read_some(stream, &mut tmp)?;
        if n == 0 {
            return if buf.is_empty() {
                Err("backend closed connection without sending a response".to_string())
            } else {
                Ok(buf)
            };
        }
        // The terminator may straddle reads: only the new bytes plus three
        // bytes of overlap need searching.
        let search_from = buf.len().saturating_sub(3);
        buf.extend_from_slice(&tmp[..n]);
        if let Some(pos) = buf[search_from..].windows(4).position(|w| w == b"\r\n\r\n") {
            break search_from + pos + 4;
        }
    };

    let content_length = std::str::from_utf8(&buf[..header_end])
        .unwrap_or("")
        .lines()
        .find_map(|line| {
            let (name, value) = line.split_once(':')?;
            if name.eq_ignore_ascii_case("content-length") {
                value.trim().parse::<usize>().ok()
            } else {
                None
            }
        });

    match content_length {
        Some(len) => {
            // Saturate: a hostile Content-Length must not overflow the sum.
            let total = header_end.saturating_add(len);
            while buf.len() < total {
                let n = read_some(stream, &mut tmp)?;
                if n == 0 {
                    break;
                }
                buf.extend_from_slice(&tmp[..n]);
            }
            // Bytes past the declared length are not part of this message.
            buf.truncate(total);
        }
        None => loop {
            let n = read_some(stream, &mut tmp)?;
            if n == 0 {
                break;
            }
            buf.extend_from_slice(&tmp[..n]);
        },
    }

    Ok(buf)
}
414
415pub(crate) fn proxy_http1(
421 request: &Request,
422 client_ip: &str,
423 host: &str,
424 port: u16,
425 connect_timeout: Duration,
426 read_timeout: Duration,
427) -> Result<Response, String> {
428 use std::net::ToSocketAddrs;
429 let addr_str = format!("{}:{}", host, port);
430 let sock_addr = addr_str
431 .to_socket_addrs()
432 .map_err(|e| format!("DNS lookup for {} failed: {}", addr_str, e))?
433 .next()
434 .ok_or_else(|| format!("no address resolved for {}", addr_str))?;
435 let stream = TcpStream::connect_timeout(&sock_addr, connect_timeout)
436 .map_err(|e| format!("connect to {} failed: {}", addr_str, e))?;
437 stream.set_read_timeout(Some(read_timeout)).map_err(|e| e.to_string())?;
438 stream.set_write_timeout(Some(Duration::from_secs(10))).map_err(|e| e.to_string())?;
439 let req_bytes = build_request(request, host, client_ip, false);
440 let mut stream = stream;
441 stream.write_all(&req_bytes).map_err(|e| format!("write to backend failed: {}", e))?;
442 let resp_bytes = read_response(&mut stream)?;
443 Response::parse(&resp_bytes)
444}
445
/// Forwards `request` to `host:port` over a fresh TLS (rustls, webpki roots)
/// HTTP/1.1 connection, sent with `Connection: close`, and parses the reply.
///
/// Every address the host resolves to is tried in order until one accepts a
/// TCP connection. `host` is also used as the TLS server name.
#[cfg(any(feature = "http-client", feature = "http2"))]
pub(crate) fn proxy_https1(
    request: &Request,
    client_ip: &str,
    host: &str,
    port: u16,
    connect_timeout: Duration,
    read_timeout: Duration,
) -> Result<Response, String> {
    use rustls::pki_types::ServerName;
    use rustls::ClientConfig;
    use std::sync::Arc;

    /// Connects to the first resolved address of `addr_str` that accepts
    /// within `timeout`.
    fn connect_first_reachable(addr_str: &str, timeout: Duration) -> Result<TcpStream, String> {
        let candidates = addr_str
            .to_socket_addrs()
            .map_err(|e| format!("DNS lookup for {} failed: {}", addr_str, e))?;
        let mut last_err = None;
        for sock_addr in candidates {
            match TcpStream::connect_timeout(&sock_addr, timeout) {
                Ok(stream) => return Ok(stream),
                Err(e) => last_err = Some(e),
            }
        }
        Err(match last_err {
            Some(e) => format!("connect to {} failed: {}", addr_str, e),
            None => format!("no address resolved for {}", addr_str),
        })
    }

    let addr_str = format!("{}:{}", host, port);
    let stream = connect_first_reachable(&addr_str, connect_timeout)?;
    stream.set_read_timeout(Some(read_timeout)).map_err(|e| e.to_string())?;
    stream.set_write_timeout(Some(Duration::from_secs(10))).map_err(|e| e.to_string())?;

    let root_store =
        rustls::RootCertStore::from_iter(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
    let config = Arc::new(
        ClientConfig::builder()
            .with_root_certificates(root_store)
            .with_no_client_auth(),
    );
    let server_name = ServerName::try_from(host.to_string())
        .map_err(|e| format!("invalid upstream hostname '{}': {}", host, e))?;
    let conn = rustls::ClientConnection::new(config, server_name).map_err(|e| e.to_string())?;
    let mut tls = rustls::StreamOwned::new(conn, stream);

    let req_bytes = build_request(request, host, client_ip, false);
    tls.write_all(&req_bytes)
        .map_err(|e| format!("write to upstream failed: {}", e))?;

    let resp_bytes = read_response_from(&mut tls)?;
    Response::parse(&resp_bytes)
}
494
495fn bad_gateway() -> Response {
496 let cr = Range::get_content_range(
497 b"502 Bad Gateway".to_vec(),
498 MimeType::TEXT_PLAIN.to_string(),
499 );
500 let mut r = Response::new();
501 r.status_code = *STATUS_CODE_REASON_PHRASE.n502_bad_gateway.status_code;
502 r.reason_phrase = STATUS_CODE_REASON_PHRASE
503 .n502_bad_gateway
504 .reason_phrase
505 .to_string();
506 r.content_range_list = vec![cr];
507 r
508}
509
/// Responses whose declared `Content-Length` exceeds this many bytes (1 MiB)
/// are streamed to the client instead of being buffered.
const STREAM_THRESHOLD: usize = 1024 * 1024;

/// Reader that yields every byte of `prefix` first and then continues with
/// `inner`.
///
/// Used to put body bytes that arrived together with the response headers
/// back in front of the rest of the upstream stream.
pub(crate) struct ConcatReader<R: Read + Send> {
    /// Bytes handed out before `inner` is touched.
    prefix: Vec<u8>,
    /// How many `prefix` bytes have been returned so far.
    prefix_pos: usize,
    /// Source read once `prefix` is exhausted.
    inner: R,
}

impl<R: Read + Send> ConcatReader<R> {
    fn new(prefix: Vec<u8>, inner: R) -> Self {
        Self {
            inner,
            prefix_pos: 0,
            prefix,
        }
    }
}

impl<R: Read + Send> Read for ConcatReader<R> {
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        let pending = &self.prefix[self.prefix_pos..];
        if pending.is_empty() {
            return self.inner.read(buf);
        }
        let count = pending.len().min(buf.len());
        buf[..count].copy_from_slice(&pending[..count]);
        self.prefix_pos += count;
        Ok(count)
    }
}
541
/// Reads from `stream` until the end of the response header block (`\r\n\r\n`).
///
/// Returns `(headers, body_prefix)`: the header block including its
/// terminating blank line, and whatever bytes arrived after it in the same
/// reads. Each pass searches only the newly read bytes, plus three bytes of
/// overlap, instead of rescanning the whole buffer. A header block larger
/// than 1 MiB is rejected so a misbehaving backend cannot grow the buffer
/// without bound. `Interrupted` reads are retried.
///
/// # Errors
/// EOF before the header block is complete, an oversized header block, or an
/// I/O error.
fn read_headers_only<R: Read + ?Sized>(
    stream: &mut R,
    tmp: &mut [u8],
) -> Result<(Vec<u8>, Vec<u8>), String> {
    const MAX_HEADER_BYTES: usize = 1024 * 1024;

    let mut buf: Vec<u8> = Vec::with_capacity(4096);
    loop {
        let n = match stream.read(tmp) {
            Ok(n) => n,
            Err(e) if e.kind() == std::io::ErrorKind::Interrupted => continue,
            Err(e) => return Err(e.to_string()),
        };
        if n == 0 {
            return Err("backend closed connection before headers were complete".to_string());
        }
        let search_from = buf.len().saturating_sub(3);
        buf.extend_from_slice(&tmp[..n]);
        if let Some(pos) = buf[search_from..].windows(4).position(|w| w == b"\r\n\r\n") {
            let body_prefix = buf.split_off(search_from + pos + 4);
            return Ok((buf, body_prefix));
        }
        if buf.len() > MAX_HEADER_BYTES {
            return Err(format!(
                "backend response headers exceed {} bytes",
                MAX_HEADER_BYTES
            ));
        }
    }
}
561
562pub(crate) fn should_stream_response(header_lower: &str) -> bool {
569 let is_sse = header_lower.lines().any(|l| {
570 l.starts_with("content-type:") && l.contains("text/event-stream")
571 });
572 let is_chunked = header_lower.lines().any(|l| {
573 l.starts_with("transfer-encoding:") && l.contains("chunked")
574 });
575 let content_length: Option<usize> = header_lower.lines().find_map(|l| {
576 l.strip_prefix("content-length:")?.trim().parse().ok()
577 });
578 let is_large = content_length.map_or(false, |n| n > STREAM_THRESHOLD);
579 is_sse || is_chunked || is_large
580}
581
582fn parse_status_and_headers(header_bytes: &[u8]) -> Result<Response, String> {
584 let s = std::str::from_utf8(header_bytes)
585 .map_err(|e| format!("non-UTF-8 response headers: {}", e))?;
586 let mut lines = s.lines();
587 let status_line = lines.next().ok_or("empty backend response")?;
588 let mut parts = status_line.splitn(3, ' ');
589 let http_version = parts.next().unwrap_or("HTTP/1.1").to_string();
590 let status_code: i16 = parts
591 .next()
592 .unwrap_or("502")
593 .parse()
594 .map_err(|_| format!("invalid status code in '{}'", status_line))?;
595 let reason_phrase = parts.next().unwrap_or("").trim_end_matches('\r').to_string();
596 let mut headers = Vec::new();
597 for line in lines {
598 let line = line.trim_end_matches('\r');
599 if line.is_empty() { break; }
600 if let Some(colon) = line.find(':') {
601 headers.push(crate::header::Header {
602 name: line[..colon].trim().to_string(),
603 value: line[colon + 1..].trim().to_string(),
604 });
605 }
606 }
607 Ok(Response {
608 http_version,
609 status_code,
610 reason_phrase,
611 headers,
612 content_range_list: vec![],
613 stream_file: None,
614 stream_pipe: None,
615 })
616}
617
618fn read_response_from_partial(
624 stream: &mut TcpStream,
625 header_bytes: Vec<u8>,
626 body_prefix: Vec<u8>,
627 tmp: &mut [u8],
628) -> Result<(Vec<u8>, bool), String> {
629 let header_end = header_bytes.len();
630 let mut buf = header_bytes;
631 buf.extend_from_slice(&body_prefix);
632
633 let header_lower =
634 std::str::from_utf8(&buf[..header_end]).unwrap_or("").to_ascii_lowercase();
635 let connection_close =
636 header_lower.lines().any(|l| l.starts_with("connection:") && l.contains("close"));
637 let is_chunked = header_lower
638 .lines()
639 .any(|l| l.starts_with("transfer-encoding:") && l.contains("chunked"));
640 let content_length: Option<usize> = header_lower.lines().find_map(|l| {
641 l.strip_prefix("content-length:")?.trim().parse().ok()
642 });
643
644 if is_chunked {
645 let decoded = decode_chunked(stream, &buf, header_end, tmp)?;
646 rewrite_as_content_length(&mut buf, header_end, &decoded);
647 Ok((buf, !connection_close))
648 } else if let Some(len) = content_length {
649 while buf.len() < header_end + len {
650 let n = stream.read(tmp).map_err(|e| e.to_string())?;
651 if n == 0 { break; }
652 buf.extend_from_slice(&tmp[..n]);
653 }
654 Ok((buf, !connection_close))
655 } else {
656 loop {
657 let n = stream.read(tmp).map_err(|e| e.to_string())?;
658 if n == 0 { break; }
659 buf.extend_from_slice(&tmp[..n]);
660 }
661 Ok((buf, false))
662 }
663}
664
/// One upstream target parsed from a backend URL.
struct Backend {
    /// Host name or address literal. Bracketed IPv6 literals keep their
    /// brackets so `host:port` formatting stays resolvable.
    host: String,
    /// TCP port; defaults to 443 for TLS schemes and 80 otherwise.
    port: u16,
    /// Whether the scheme requests TLS (`https`, `h2s`, `grpcs`).
    #[cfg_attr(not(feature = "http2"), allow(dead_code))]
    tls: bool,
}

impl Backend {
    /// Parses `scheme://host[:port][/path][?query][#fragment]` or a bare
    /// `host[:port]`.
    ///
    /// Schemes are matched case-insensitively:
    /// - `http`, `h2`, `grpc` — plain, default port 80;
    /// - `https`, `h2s`, `grpcs` — TLS, default port 443.
    ///
    /// Port handling:
    /// - everything after the authority is ignored;
    /// - an empty port (`host:`) falls back to the default;
    /// - a non-numeric port is kept as part of the host, as before.
    ///
    /// Returns `None` when the host is empty.
    fn parse(url: &str) -> Option<Self> {
        const SCHEMES: &[(&str, bool, u16)] = &[
            ("https://", true, 443),
            ("h2s://", true, 443),
            ("grpcs://", true, 443),
            ("http://", false, 80),
            ("h2://", false, 80),
            ("grpc://", false, 80),
        ];

        let (rest, tls, default_port) = SCHEMES
            .iter()
            .find_map(|&(prefix, tls, port)| {
                let head = url.get(..prefix.len())?;
                head.eq_ignore_ascii_case(prefix)
                    .then(|| (&url[prefix.len()..], tls, port))
            })
            .unwrap_or((url, false, 80));

        // The authority ends at the first path, query or fragment delimiter.
        let authority = rest
            .split(|c: char| c == '/' || c == '?' || c == '#')
            .next()
            .unwrap_or(rest);

        let (host, port) = match authority.rfind(':') {
            Some(colon) => {
                let port_str = &authority[colon + 1..];
                if port_str.is_empty() {
                    (&authority[..colon], default_port)
                } else {
                    match port_str.parse::<u16>() {
                        Ok(p) => (&authority[..colon], p),
                        Err(_) => (authority, default_port),
                    }
                }
            }
            None => (authority, default_port),
        };

        if host.is_empty() {
            return None;
        }
        Some(Backend {
            host: host.to_string(),
            port,
            tls,
        })
    }
}
711
/// HTTP/2 variant of `ReverseProxy`.
///
/// It shares the backend list, path-prefix filter, timeouts and round-robin
/// counter, but forwards each request over an h2 connection.
#[cfg(feature = "http2")]
pub struct H2ReverseProxy {
    /// Holds the backends, prefix, timeouts and rotation counter.
    inner: ReverseProxy,
}

#[cfg(feature = "http2")]
impl H2ReverseProxy {
    /// Creates an HTTP/2 proxy for the given backend URLs (parsed like
    /// `ReverseProxy::new`).
    pub fn new<I, S>(backends: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: AsRef<str>,
    {
        let inner = ReverseProxy::new(backends);
        Self { inner }
    }

    /// Only proxies requests whose URI starts with `prefix`.
    pub fn path_prefix(self, prefix: impl Into<String>) -> Self {
        Self {
            inner: self.inner.path_prefix(prefix),
        }
    }

    /// Sets the connect timeout, in milliseconds.
    pub fn connect_timeout_ms(self, ms: u64) -> Self {
        Self {
            inner: self.inner.connect_timeout_ms(ms),
        }
    }

    /// Sets the read timeout, in milliseconds.
    pub fn read_timeout_ms(self, ms: u64) -> Self {
        Self {
            inner: self.inner.read_timeout_ms(ms),
        }
    }
}
767
#[cfg(feature = "http2")]
impl crate::middleware::Middleware for H2ReverseProxy {
    /// Proxies requests under the configured prefix over HTTP/2.
    ///
    /// Each backend is tried at most once, starting at the next round-robin
    /// position, and the first success is returned. If every backend fails, or
    /// none is configured, the result is a `502 Bad Gateway` response.
    /// Requests outside the prefix go to `next`.
    fn handle(
        &self,
        request: &crate::request::Request,
        connection: &crate::server::ConnectionInfo,
        next: &dyn crate::application::Application,
    ) -> Result<crate::response::Response, String> {
        let proxy = &self.inner;
        if let Some(prefix) = proxy.path_prefix.as_deref() {
            if !request.request_uri.starts_with(prefix) {
                return next.execute(request, connection);
            }
        }

        let count = proxy.backends.len();
        if count == 0 {
            return Ok(bad_gateway());
        }
        let first = proxy.counter.fetch_add(1, Ordering::Relaxed);
        let served = (0..count)
            .map(|offset| &proxy.backends[(first + offset) % count])
            .find_map(|backend| {
                try_backend_h2(
                    request,
                    &connection.client.ip,
                    backend,
                    proxy.connect_timeout,
                    proxy.read_timeout,
                )
                .ok()
            });
        Ok(served.unwrap_or_else(bad_gateway))
    }
}
798
/// Forwards `request` to `backend` over HTTP/2 from synchronous middleware code.
///
/// How the async forwarding future is driven depends on the caller's context:
/// - Inside a multi-threaded Tokio runtime, the worker is released with
///   `block_in_place` and the future runs on that runtime.
/// - Inside a current-thread runtime, blocking is not possible:
///   `block_in_place` panics there, and a nested runtime cannot be started on
///   this thread. An error is returned and the caller answers 502.
/// - With no runtime at all, a private single-threaded runtime is built for
///   this call. Previously this case always failed.
///
/// NOTE(review): `_read_timeout` is still not applied to the h2 exchange.
#[cfg(feature = "http2")]
fn try_backend_h2(
    request: &Request,
    client_ip: &str,
    backend: &Backend,
    connect_timeout: Duration,
    _read_timeout: Duration,
) -> Result<Response, String> {
    use tokio::runtime::{Builder, Handle, RuntimeFlavor};

    match Handle::try_current() {
        Ok(handle) if handle.runtime_flavor() == RuntimeFlavor::MultiThread => {
            tokio::task::block_in_place(|| {
                handle.block_on(forward_h2_async(request, client_ip, backend, connect_timeout))
            })
        }
        Ok(_) => Err(
            "h2 proxy: cannot block inside a current-thread tokio runtime; falling back to 502"
                .to_string(),
        ),
        Err(_) => {
            let runtime = Builder::new_current_thread()
                .enable_all()
                .build()
                .map_err(|e| format!("h2 proxy: failed to start a tokio runtime: {}", e))?;
            runtime.block_on(forward_h2_async(request, client_ip, backend, connect_timeout))
        }
    }
}
817
/// Opens a TCP connection to `backend` within `connect_timeout`, wraps it in
/// TLS when `backend.tls` is set, and sends `request` over HTTP/2.
///
/// TLS uses the webpki root store and advertises only `h2` via ALPN.
#[cfg(feature = "http2")]
async fn forward_h2_async(
    request: &Request,
    client_ip: &str,
    backend: &Backend,
    connect_timeout: Duration,
) -> Result<Response, String> {
    let addr = format!("{}:{}", backend.host, backend.port);
    let connecting = tokio::net::TcpStream::connect(&addr);
    let tcp = match tokio::time::timeout(connect_timeout, connecting).await {
        Err(_) => return Err(format!("h2 proxy: connect to {} timed out", addr)),
        Ok(Err(e)) => return Err(format!("h2 proxy: connect to {} failed: {}", addr, e)),
        Ok(Ok(stream)) => stream,
    };

    if !backend.tls {
        return send_h2_request(request, client_ip, backend, tcp).await;
    }

    use rustls::pki_types::ServerName;
    use rustls::ClientConfig;
    use std::sync::Arc;
    use tokio_rustls::TlsConnector;

    let roots =
        rustls::RootCertStore::from_iter(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
    let mut tls_config = ClientConfig::builder()
        .with_root_certificates(roots)
        .with_no_client_auth();
    tls_config.alpn_protocols = vec![b"h2".to_vec()];
    let server_name = ServerName::try_from(backend.host.as_str())
        .map_err(|e| format!("invalid upstream hostname '{}': {}", backend.host, e))?
        .to_owned();
    let tls_stream = TlsConnector::from(Arc::new(tls_config))
        .connect(server_name, tcp)
        .await
        .map_err(|e| format!("h2 proxy: TLS handshake with {} failed: {}", addr, e))?;
    send_h2_request(request, client_ip, backend, tls_stream).await
}
861
/// Sends `request` to `backend` as one HTTP/2 request over `stream` and
/// buffers the complete response.
///
/// Request headers are filtered as on the HTTP/1.1 path:
/// - hop-by-hop fields and `Host` are dropped;
/// - the HTTP/2-forbidden `Proxy-Connection` is dropped;
/// - `TE: trailers` is kept, because HTTP/2 allows exactly that value and gRPC
///   relies on it.
///
/// `Host`, `X-Forwarded-For` and `Via` are then added.
///
/// Received DATA is returned to the peer's flow-control window as it is
/// consumed, so bodies larger than the initial window do not stall.
///
/// Response trailers (e.g. gRPC's `grpc-status`) are appended to the response
/// headers, because the buffered `Response` has no separate trailer section.
///
/// # Errors
/// Handshake, request-building, send or receive failures, as strings.
#[cfg(feature = "http2")]
async fn send_h2_request<T>(
    request: &Request,
    client_ip: &str,
    backend: &Backend,
    stream: T,
) -> Result<Response, String>
where
    T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static,
{
    use bytes::Bytes;
    use http as hc;

    let addr = format!("{}:{}", backend.host, backend.port);

    let (send_req, conn) = h2::client::handshake(stream)
        .await
        .map_err(|e| format!("h2 proxy: handshake with {} failed: {}", addr, e))?;

    // The connection future performs the actual socket I/O and must be polled
    // concurrently with the request below.
    tokio::spawn(async move {
        let _ = conn.await;
    });

    let scheme = if backend.tls { "https" } else { "http" };
    let uri_str = format!("{}://{}{}", scheme, addr, request.request_uri);
    let uri: hc::Uri = uri_str.parse().map_err(|e: hc::uri::InvalidUri| e.to_string())?;
    let method = hc::Method::from_bytes(request.method.as_bytes()).map_err(|e| e.to_string())?;

    let mut builder = hc::Request::builder().method(method).uri(uri);
    builder = builder.header("host", &backend.host);
    for h in &request.headers {
        let lower = h.name.to_lowercase();
        if lower == "te" {
            // HTTP/2 permits only `te: trailers`, which gRPC servers expect.
            if h.value
                .split(',')
                .any(|token| token.trim().eq_ignore_ascii_case("trailers"))
            {
                builder = builder.header("te", "trailers");
            }
            continue;
        }
        if HOP_BY_HOP.contains(&lower.as_str()) || lower == "host" || lower == "proxy-connection" {
            continue;
        }
        builder = builder.header(&h.name, &h.value);
    }
    builder = builder.header("x-forwarded-for", client_ip);
    builder = builder.header("via", "2 rws");

    let body_bytes = Bytes::from(request.body.clone());
    let end_of_stream = body_bytes.is_empty();
    let http_req = builder.body(()).map_err(|e| e.to_string())?;

    let mut send_req = send_req.ready().await.map_err(|e| e.to_string())?;
    let (resp_future, mut req_body) = send_req
        .send_request(http_req, end_of_stream)
        .map_err(|e| e.to_string())?;
    if !end_of_stream {
        req_body.send_data(body_bytes, true).map_err(|e| e.to_string())?;
    }

    let resp = resp_future.await.map_err(|e| e.to_string())?;
    let (parts, mut body) = resp.into_parts();

    let content_type = parts
        .headers
        .get("content-type")
        .and_then(|v| v.to_str().ok())
        .unwrap_or("application/octet-stream")
        .to_string();

    let mut body_bytes: Vec<u8> = Vec::new();
    while let Some(chunk) = body.data().await {
        let chunk = chunk.map_err(|e| e.to_string())?;
        // Hand consumed bytes back to the peer's flow-control window. Without
        // this the upstream stops sending once the initial window (65,535 bytes
        // by default) is used up, and the request hangs.
        let _ = body.flow_control().release_capacity(chunk.len());
        body_bytes.extend_from_slice(&chunk);
    }
    // Best effort: a missing or failed trailer block does not invalidate a body
    // that was already received in full.
    let trailers = body.trailers().await.ok().flatten().unwrap_or_default();

    let mut response = Response::new();
    response.status_code = parts.status.as_u16() as i16;
    response.reason_phrase = parts.status.canonical_reason().unwrap_or("").to_string();

    const H2_HOP: &[&str] = &[
        "connection", "keep-alive", "transfer-encoding", "upgrade", "proxy-connection", "te",
    ];
    for (name, value) in parts.headers.iter().chain(trailers.iter()) {
        let lower = name.as_str().to_lowercase();
        if H2_HOP.contains(&lower.as_str()) {
            continue;
        }
        if let Ok(v) = value.to_str() {
            response.headers.push(crate::header::Header {
                name: name.as_str().to_string(),
                value: v.to_string(),
            });
        }
    }

    if !body_bytes.is_empty() {
        response.content_range_list = vec![Range::get_content_range(body_bytes, content_type)];
    }

    Ok(response)
}
959
/// gRPC-aware proxy.
///
/// Requests whose `Content-Type` starts with `application/grpc` are forwarded
/// through an `H2ReverseProxy`; every other request is passed to the next
/// handler.
#[cfg(feature = "http2")]
pub struct GrpcProxy {
    /// HTTP/2 proxy that performs the actual forwarding.
    inner: H2ReverseProxy,
}

#[cfg(feature = "http2")]
impl GrpcProxy {
    /// Creates a gRPC proxy for the given backend URLs.
    pub fn new<I, S>(backends: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: AsRef<str>,
    {
        let inner = H2ReverseProxy::new(backends);
        Self { inner }
    }

    /// Only proxies gRPC requests whose URI starts with `prefix`.
    pub fn path_prefix(self, prefix: impl Into<String>) -> Self {
        Self {
            inner: self.inner.path_prefix(prefix),
        }
    }
}
1009
#[cfg(feature = "http2")]
impl crate::middleware::Middleware for GrpcProxy {
    /// Routes requests with an `application/grpc*` content type through the
    /// HTTP/2 proxy and everything else to `next`.
    fn handle(
        &self,
        request: &crate::request::Request,
        connection: &crate::server::ConnectionInfo,
        next: &dyn crate::application::Application,
    ) -> Result<crate::response::Response, String> {
        let is_grpc = request
            .get_header("content-type".to_string())
            .map_or(false, |h| h.value.starts_with("application/grpc"));
        if !is_grpc {
            return next.execute(request, connection);
        }
        self.inner.handle(request, connection, next)
    }
}
1029
#[cfg(test)]
mod backend_parse_tests {
    use super::Backend;

    /// Parses `url` and flattens the result into `(host, port, tls)`.
    fn parse(url: &str) -> Option<(String, u16, bool)> {
        let backend = Backend::parse(url)?;
        Some((backend.host, backend.port, backend.tls))
    }

    /// Asserts that `url` parses to the given host, port and TLS flag.
    fn assert_parses(url: &str, host: &str, port: u16, tls: bool) {
        assert_eq!(Some((host.to_string(), port, tls)), parse(url), "url: {}", url);
    }

    #[test]
    fn bare_host_port() {
        assert_parses("api.example.com:8080", "api.example.com", 8080, false);
    }

    #[test]
    fn http_scheme() {
        assert_parses("http://backend:3000", "backend", 3000, false);
    }

    #[test]
    fn h2_scheme_plain() {
        assert_parses("h2://svc:50051", "svc", 50051, false);
    }

    #[test]
    fn grpc_scheme_plain() {
        assert_parses("grpc://svc:50051", "svc", 50051, false);
    }

    #[test]
    fn https_scheme_sets_tls_and_default_port() {
        assert_parses("https://api.example.com", "api.example.com", 443, true);
    }

    #[test]
    fn https_scheme_explicit_port() {
        assert_parses("https://api.example.com:8443", "api.example.com", 8443, true);
    }

    #[test]
    fn h2s_scheme_sets_tls() {
        assert_parses("h2s://svc", "svc", 443, true);
    }

    #[test]
    fn h2s_scheme_explicit_port() {
        assert_parses("h2s://svc:8443", "svc", 8443, true);
    }

    #[test]
    fn grpcs_scheme_sets_tls() {
        assert_parses("grpcs://grpc-svc", "grpc-svc", 443, true);
    }

    #[test]
    fn grpcs_scheme_explicit_port() {
        assert_parses("grpcs://grpc-svc:50052", "grpc-svc", 50052, true);
    }

    #[test]
    fn empty_host_returns_none() {
        assert_eq!(None, parse("https://"));
    }

    #[test]
    fn bare_host_no_port_defaults_to_80() {
        assert_parses("myhost", "myhost", 80, false);
    }
}