1#![doc(html_favicon_url = "https://salvo.rs/favicon-32x32.png")]
38#![doc(html_logo_url = "https://salvo.rs/images/logo.svg")]
39#![cfg_attr(docsrs, feature(doc_cfg))]
40
41use std::convert::Infallible;
42use std::error::Error as StdError;
43use std::fmt::{self, Debug, Formatter};
44#[cfg(test)]
45use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
46
47use hyper::upgrade::OnUpgrade;
48#[cfg(not(test))]
49use local_ip_address::{local_ip, local_ipv6};
50use percent_encoding::{AsciiSet, CONTROLS, utf8_percent_encode};
51use salvo_core::conn::SocketAddr;
52use salvo_core::http::header::{CONNECTION, HOST, HeaderMap, HeaderName, HeaderValue, UPGRADE};
53use salvo_core::http::uri::Uri;
54use salvo_core::http::{ReqBody, ResBody, StatusCode};
55use salvo_core::routing::normalize_url_path;
56use salvo_core::{BoxedError, Depot, Error, FlowCtrl, Handler, Request, Response, async_trait};
57
58#[macro_use]
59mod cfg;
60
61cfg_feature! {
62 #![feature = "hyper-client"]
63 mod hyper_client;
64 pub use hyper_client::*;
65}
66cfg_feature! {
67 #![feature = "reqwest-client"]
68 mod reqwest_client;
69 pub use reqwest_client::*;
70}
71
72cfg_feature! {
73 #![feature = "unix-sock-client"]
74 #[cfg(unix)]
75 mod unix_sock_client;
76 #[cfg(unix)]
77 pub use unix_sock_client::*;
78}
79
80type HyperRequest = hyper::Request<ReqBody>;
81type HyperResponse = hyper::Response<ResBody>;
82
83const X_FORWARDER_FOR_HEADER_NAME: &str = "x-forwarded-for";
84
85const QUERY_ENCODE_SET: &AsciiSet = &CONTROLS
86 .add(b' ')
87 .add(b'"')
88 .add(b'#')
89 .add(b'<')
90 .add(b'>')
91 .add(b'`');
92const PATH_ENCODE_SET: &AsciiSet = &QUERY_ENCODE_SET
93 .add(b'?')
94 .add(b'^')
95 .add(b'`')
96 .add(b'{')
97 .add(b'}');
98
99#[inline]
101pub(crate) fn encode_url_path(path: &str) -> String {
102 path.split('/')
103 .map(|s| utf8_percent_encode(s, PATH_ENCODE_SET).to_string())
104 .collect::<Vec<_>>()
105 .join("/")
106}
107
/// Transport used by [`Proxy`] to send the rewritten request to the upstream.
pub trait Client: Send + Sync + 'static {
    /// Error returned when the upstream request cannot be executed.
    type Error: StdError + Send + Sync + 'static;

    /// Sends `req` to the upstream and resolves to the upstream response.
    ///
    /// `upgraded` is the `OnUpgrade` handle removed from the incoming request's
    /// extensions by the proxy handler, if one was present. Presumably it lets
    /// the client bridge upgraded connections; confirm against implementations.
    fn execute(
        &self,
        req: HyperRequest,
        upgraded: Option<OnUpgrade>,
    ) -> impl Future<Output = Result<HyperResponse, Self::Error>> + Send;
}
123
/// Source of upstream base URLs for a [`Proxy`].
pub trait Upstreams: Send + Sync + 'static {
    /// Error returned when no upstream can be elected.
    type Error: StdError + Send + Sync + 'static;

    /// Picks the upstream base URL (for example `http://example.com/api`) that
    /// `req` is forwarded to. The proxy appends the request path and query to it.
    fn elect(
        &self,
        req: &Request,
        depot: &Depot,
    ) -> impl Future<Output = Result<&str, Self::Error>> + Send;
}
140impl Upstreams for &'static str {
141 type Error = Infallible;
142
143 async fn elect(&self, _: &Request, _: &Depot) -> Result<&str, Self::Error> {
144 Ok(*self)
145 }
146}
147impl Upstreams for String {
148 type Error = Infallible;
149 async fn elect(&self, _: &Request, _: &Depot) -> Result<&str, Self::Error> {
150 Ok(self.as_str())
151 }
152}
153
154impl<const N: usize> Upstreams for [&'static str; N] {
155 type Error = Error;
156 async fn elect(&self, _: &Request, _: &Depot) -> Result<&str, Self::Error> {
157 if self.is_empty() {
158 return Err(Error::other("upstreams is empty"));
159 }
160 let index = fastrand::usize(..self.len());
161 Ok(self[index])
162 }
163}
164
165impl<T> Upstreams for Vec<T>
166where
167 T: AsRef<str> + Send + Sync + 'static,
168{
169 type Error = Error;
170 async fn elect(&self, _: &Request, _: &Depot) -> Result<&str, Self::Error> {
171 if self.is_empty() {
172 return Err(Error::other("upstreams is empty"));
173 }
174 let index = fastrand::usize(..self.len());
175 Ok(self[index].as_ref())
176 }
177}
178
/// Callback that extracts one part of the forwarded URL (the path or the query)
/// from the incoming request. `None` means the part is absent.
pub type UrlPartGetter = Box<dyn Fn(&Request, &Depot) -> Option<String> + Send + Sync + 'static>;

/// Callback that computes the `Host` header of the forwarded request from the
/// forward URI and the incoming request. `None` means no `Host` header is set.
pub type HostHeaderGetter =
    Box<dyn Fn(&Uri, &Request, &Depot) -> Option<String> + Send + Sync + 'static>;
185
186pub fn default_url_path_getter(req: &Request, _depot: &Depot) -> Option<String> {
191 req.params().tail().map(str::to_owned)
192}
193pub fn default_url_query_getter(req: &Request, _depot: &Depot) -> Option<String> {
195 req.uri().query().map(Into::into)
196}
197
198pub fn default_host_header_getter(
200 forward_uri: &Uri,
201 _req: &Request,
202 _depot: &Depot,
203) -> Option<String> {
204 if let Some(host) = forward_uri.host() {
205 return Some(String::from(host));
206 }
207
208 None
209}
210
211pub fn rfc2616_host_header_getter(
214 forward_uri: &Uri,
215 req: &Request,
216 _depot: &Depot,
217) -> Option<String> {
218 let mut parts: Vec<String> = Vec::with_capacity(2);
219
220 if let Some(host) = forward_uri.host() {
221 parts.push(host.to_owned());
222
223 if let Some(scheme) = forward_uri.scheme_str()
224 && let Some(port) = forward_uri.port_u16()
225 && (scheme == "http" && port != 80 || scheme == "https" && port != 443)
226 {
227 parts.push(port.to_string());
228 }
229 }
230
231 if parts.is_empty() {
232 default_host_header_getter(forward_uri, req, _depot)
233 } else {
234 Some(parts.join(":"))
235 }
236}
237
238pub fn preserve_original_host_header_getter(
241 forward_uri: &Uri,
242 req: &Request,
243 _depot: &Depot,
244) -> Option<String> {
245 if let Some(host_header) = req.headers().get(HOST)
246 && let Ok(host) = String::from_utf8(host_header.as_bytes().to_vec())
247 {
248 return Some(host);
249 }
250
251 default_host_header_getter(forward_uri, req, _depot)
252}
253
/// Reverse-proxy handler: forwards each incoming request to an upstream elected
/// by `upstreams` and sends it with `client`.
#[non_exhaustive]
pub struct Proxy<U, C>
where
    U: Upstreams,
    C: Client,
{
    /// Upstream selector. `elect` is called once per proxied request.
    pub upstreams: U,
    /// Client used to send the rewritten request upstream.
    pub client: C,
    /// Produces the path that is normalized, encoded and appended to the upstream URL.
    pub url_path_getter: UrlPartGetter,
    /// Produces the query string that is encoded and appended to the forwarded URL.
    pub url_query_getter: UrlPartGetter,
    /// Produces the `Host` header value for the forwarded request.
    pub host_header_getter: HostHeaderGetter,
    /// When `true`, an `X-Forwarded-For` header is set on forwarded requests.
    pub client_ip_forwarding_enabled: bool,
}
274
275impl<U, C> Debug for Proxy<U, C>
276where
277 U: Upstreams,
278 C: Client,
279{
280 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
281 f.debug_struct("Proxy").finish()
282 }
283}
284
285impl<U, C> Proxy<U, C>
286where
287 U: Upstreams,
288 U::Error: Into<BoxedError>,
289 C: Client,
290{
291 #[must_use]
293 pub fn new(upstreams: U, client: C) -> Self {
294 Self {
295 upstreams,
296 client,
297 url_path_getter: Box::new(default_url_path_getter),
298 url_query_getter: Box::new(default_url_query_getter),
299 host_header_getter: Box::new(default_host_header_getter),
300 client_ip_forwarding_enabled: false,
301 }
302 }
303
304 pub fn with_client_ip_forwarding(upstreams: U, client: C) -> Self {
306 Self {
307 upstreams,
308 client,
309 url_path_getter: Box::new(default_url_path_getter),
310 url_query_getter: Box::new(default_url_query_getter),
311 host_header_getter: Box::new(default_host_header_getter),
312 client_ip_forwarding_enabled: true,
313 }
314 }
315
316 #[inline]
318 #[must_use]
319 pub fn url_path_getter<G>(mut self, url_path_getter: G) -> Self
320 where
321 G: Fn(&Request, &Depot) -> Option<String> + Send + Sync + 'static,
322 {
323 self.url_path_getter = Box::new(url_path_getter);
324 self
325 }
326
327 #[inline]
329 #[must_use]
330 pub fn url_query_getter<G>(mut self, url_query_getter: G) -> Self
331 where
332 G: Fn(&Request, &Depot) -> Option<String> + Send + Sync + 'static,
333 {
334 self.url_query_getter = Box::new(url_query_getter);
335 self
336 }
337
338 #[inline]
340 #[must_use]
341 pub fn host_header_getter<G>(mut self, host_header_getter: G) -> Self
342 where
343 G: Fn(&Uri, &Request, &Depot) -> Option<String> + Send + Sync + 'static,
344 {
345 self.host_header_getter = Box::new(host_header_getter);
346 self
347 }
348
349 #[inline]
351 pub fn upstreams(&self) -> &U {
352 &self.upstreams
353 }
354 #[inline]
356 pub fn upstreams_mut(&mut self) -> &mut U {
357 &mut self.upstreams
358 }
359
360 #[inline]
362 pub fn client(&self) -> &C {
363 &self.client
364 }
365 #[inline]
367 pub fn client_mut(&mut self) -> &mut C {
368 &mut self.client
369 }
370
371 #[inline]
373 #[must_use]
374 pub fn client_ip_forwarding(mut self, enable: bool) -> Self {
375 self.client_ip_forwarding_enabled = enable;
376 self
377 }
378
379 async fn build_proxied_request(
380 &self,
381 req: &mut Request,
382 depot: &Depot,
383 ) -> Result<HyperRequest, Error> {
384 let upstream = self
385 .upstreams
386 .elect(req, depot)
387 .await
388 .map_err(Error::other)?;
389
390 if upstream.is_empty() {
391 tracing::error!("upstreams is empty");
392 return Err(Error::other("upstreams is empty"));
393 }
394
395 let path = (self.url_path_getter)(req, depot).unwrap_or_default();
396 let path = encode_url_path(&normalize_url_path(&path));
397 let query = (self.url_query_getter)(req, depot);
398 let rest = if let Some(query) = query {
399 if let Some(stripped) = query.strip_prefix('?') {
400 format!("{path}?{}", utf8_percent_encode(stripped, QUERY_ENCODE_SET))
401 } else {
402 format!("{path}?{}", utf8_percent_encode(&query, QUERY_ENCODE_SET))
403 }
404 } else {
405 path
406 };
407 let forward_url = if upstream.ends_with('/') && rest.starts_with('/') {
408 format!("{}{}", upstream.trim_end_matches('/'), rest)
409 } else if upstream.ends_with('/') || rest.starts_with('/') {
410 format!("{upstream}{rest}")
411 } else if rest.is_empty() {
412 upstream.to_owned()
413 } else {
414 format!("{upstream}/{rest}")
415 };
416 let forward_url: Uri = TryFrom::try_from(forward_url).map_err(Error::other)?;
417 let mut build = hyper::Request::builder()
418 .method(req.method())
419 .uri(&forward_url);
420 for (key, value) in req.headers() {
421 if key != HOST {
422 build = build.header(key, value);
423 }
424 }
425 if let Some(host_value) = (self.host_header_getter)(&forward_url, req, depot) {
426 match HeaderValue::from_str(&host_value) {
427 Ok(host_value) => {
428 build = build.header(HOST, host_value);
429 }
430 Err(e) => {
431 tracing::error!(error = ?e, "invalid host header value");
432 }
433 }
434 }
435
436 if self.client_ip_forwarding_enabled {
437 let xff_header_name = HeaderName::from_static(X_FORWARDER_FOR_HEADER_NAME);
438 let current_xff = req.headers().get(&xff_header_name);
439
440 #[cfg(test)]
441 let system_ip_addr = match req.remote_addr() {
442 SocketAddr::IPv6(_) => Some(IpAddr::from(Ipv6Addr::new(
443 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8,
444 ))),
445 _ => Some(IpAddr::from(Ipv4Addr::new(101, 102, 103, 104))),
446 };
447
448 #[cfg(not(test))]
449 let system_ip_addr = match req.remote_addr() {
450 SocketAddr::IPv6(_) => local_ipv6().ok(),
451 _ => local_ip().ok(),
452 };
453
454 if let Some(system_ip_addr) = system_ip_addr {
455 let forwarded_addr = system_ip_addr.to_string();
456
457 let xff_value = match current_xff {
458 Some(current_xff) => match current_xff.to_str() {
459 Ok(current_xff) => format!("{forwarded_addr}, {current_xff}"),
460 _ => forwarded_addr.clone(),
461 },
462 None => forwarded_addr.clone(),
463 };
464
465 let xff_header_halue = match HeaderValue::from_str(xff_value.as_str()) {
466 Ok(xff_header_halue) => Some(xff_header_halue),
467 Err(_) => match HeaderValue::from_str(forwarded_addr.as_str()) {
468 Ok(xff_header_halue) => Some(xff_header_halue),
469 Err(e) => {
470 tracing::error!(error = ?e, "invalid x-forwarded-for header value");
471 None
472 }
473 },
474 };
475
476 if let Some(xff) = xff_header_halue
477 && let Some(headers) = build.headers_mut()
478 {
479 headers.insert(&xff_header_name, xff);
480 }
481 }
482 }
483
484 build.body(req.take_body()).map_err(Error::other)
485 }
486}
487
488#[async_trait]
489impl<U, C> Handler for Proxy<U, C>
490where
491 U: Upstreams,
492 U::Error: Into<BoxedError>,
493 C: Client,
494{
495 async fn handle(
496 &self,
497 req: &mut Request,
498 depot: &mut Depot,
499 res: &mut Response,
500 _ctrl: &mut FlowCtrl,
501 ) {
502 match self.build_proxied_request(req, depot).await {
503 Ok(proxied_request) => {
504 match self
505 .client
506 .execute(proxied_request, req.extensions_mut().remove())
507 .await
508 {
509 Ok(response) => {
510 let (
511 salvo_core::http::response::Parts {
512 status,
513 headers,
515 ..
517 },
518 body,
519 ) = response.into_parts();
520 res.status_code(status);
521 for name in headers.keys() {
522 for value in headers.get_all(name) {
523 res.headers.append(name, value.to_owned());
524 }
525 }
526 res.body(body);
527 }
528 Err(e) => {
529 tracing::error!( error = ?e, uri = ?req.uri(), "get response data failed: {}", e);
530 res.status_code(StatusCode::INTERNAL_SERVER_ERROR);
531 }
532 }
533 }
534 Err(e) => {
535 tracing::error!(error = ?e, "build proxied request failed");
536 res.status_code(StatusCode::BAD_REQUEST);
537 }
538 }
539 }
540}
541#[inline]
542#[allow(dead_code)]
543fn get_upgrade_type(headers: &HeaderMap) -> Option<&str> {
544 if headers
545 .get(&CONNECTION)
546 .map(|value| {
547 value
548 .to_str()
549 .unwrap_or_default()
550 .split(',')
551 .any(|e| e.trim() == UPGRADE)
552 })
553 .unwrap_or(false)
554 && let Some(upgrade_value) = headers.get(&UPGRADE)
555 {
556 tracing::debug!(
557 "found upgrade header with value: {:?}",
558 upgrade_value.to_str()
559 );
560 return upgrade_value.to_str().ok();
561 }
562
563 None
564}
565
#[cfg(test)]
mod tests {
    use std::net::{Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddrV6};
    use std::str::FromStr;

    use super::*;

    #[test]
    fn test_encode_url_path() {
        let path = "/test/path";
        let encoded_path = encode_url_path(path);
        assert_eq!(encoded_path, "/test/path");
        // Reserved bytes inside a segment are escaped; separators are kept.
        assert_eq!(encode_url_path("/a b/{c}"), "/a%20b/%7Bc%7D");
    }

    #[test]
    fn test_default_url_path_getter_uses_raw_tail() {
        let mut request = Request::new();
        request
            .params_mut()
            .insert("**rest", "guide/../index.html".to_owned());
        let depot = Depot::new();

        assert_eq!(
            default_url_path_getter(&request, &depot).as_deref(),
            Some("guide/../index.html")
        );
    }

    #[test]
    fn test_get_upgrade_type() {
        let mut headers = HeaderMap::new();
        headers.insert(CONNECTION, HeaderValue::from_static("upgrade"));
        headers.insert(UPGRADE, HeaderValue::from_static("websocket"));
        let upgrade_type = get_upgrade_type(&headers);
        assert_eq!(upgrade_type, Some("websocket"));
    }

    #[test]
    fn test_host_header_handling() {
        let _ = rustls::crypto::aws_lc_rs::default_provider().install_default();
        let uri = Uri::from_str("http://host.tld/test").unwrap();
        let mut req = Request::new();
        let depot = Depot::new();

        assert_eq!(
            default_host_header_getter(&uri, &req, &depot),
            Some("host.tld".to_string())
        );

        let uri_with_port = Uri::from_str("http://host.tld:8080/test").unwrap();
        assert_eq!(
            rfc2616_host_header_getter(&uri_with_port, &req, &depot),
            Some("host.tld:8080".to_string())
        );

        let uri_with_http_port = Uri::from_str("http://host.tld:80/test").unwrap();
        assert_eq!(
            rfc2616_host_header_getter(&uri_with_http_port, &req, &depot),
            Some("host.tld".to_string())
        );

        let uri_with_https_port = Uri::from_str("https://host.tld:443/test").unwrap();
        assert_eq!(
            rfc2616_host_header_getter(&uri_with_https_port, &req, &depot),
            Some("host.tld".to_string())
        );

        let uri_with_non_https_scheme_and_https_port =
            Uri::from_str("http://host.tld:443/test").unwrap();
        assert_eq!(
            rfc2616_host_header_getter(&uri_with_non_https_scheme_and_https_port, &req, &depot),
            Some("host.tld:443".to_string())
        );

        req.headers_mut()
            .insert(HOST, HeaderValue::from_static("test.host.tld"));
        assert_eq!(
            preserve_original_host_header_getter(&uri, &req, &depot),
            Some("test.host.tld".to_string())
        );
    }

    #[tokio::test]
    async fn test_client_ip_forwarding() {
        let xff_header_name = HeaderName::from_static(X_FORWARDER_FOR_HEADER_NAME);

        let mut request = Request::new();
        let depot = Depot::new();

        let proxy_without_forwarding =
            Proxy::new(vec!["http://example.com"], HyperClient::default());
        assert!(!proxy_without_forwarding.client_ip_forwarding_enabled);

        let proxy_with_forwarding = proxy_without_forwarding.client_ip_forwarding(true);
        assert!(proxy_with_forwarding.client_ip_forwarding_enabled);

        let proxy =
            Proxy::with_client_ip_forwarding(vec!["http://example.com"], HyperClient::default());
        assert!(proxy.client_ip_forwarding_enabled);

        // Default remote address: the IPv4 test address is used.
        let req = proxy
            .build_proxied_request(&mut request, &depot)
            .await
            .expect("proxied request should build");
        assert_eq!(
            req.headers().get(&xff_header_name),
            Some(&HeaderValue::from_static("101.102.103.104"))
        );

        // IPv6 remote address: the IPv6 test address is used.
        *request.remote_addr_mut() =
            SocketAddr::from(SocketAddrV6::new(Ipv6Addr::LOCALHOST, 12345, 0, 0));
        let req = proxy
            .build_proxied_request(&mut request, &depot)
            .await
            .expect("proxied request should build");
        assert_eq!(
            req.headers().get(&xff_header_name),
            Some(&HeaderValue::from_static("1:2:3:4:5:6:7:8"))
        );

        // Unknown remote address falls back to the IPv4 test address.
        *request.remote_addr_mut() = SocketAddr::Unknown;
        let req = proxy
            .build_proxied_request(&mut request, &depot)
            .await
            .expect("proxied request should build");
        assert_eq!(
            req.headers().get(&xff_header_name),
            Some(&HeaderValue::from_static("101.102.103.104"))
        );

        // An existing X-Forwarded-For chain is kept after the added address.
        request.headers_mut().insert(
            &xff_header_name,
            HeaderValue::from_static("10.72.0.1, 127.0.0.1"),
        );
        *request.remote_addr_mut() =
            SocketAddr::from(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 12345));
        let req = proxy
            .build_proxied_request(&mut request, &depot)
            .await
            .expect("proxied request should build");
        assert_eq!(
            req.headers().get(&xff_header_name),
            Some(&HeaderValue::from_static(
                "101.102.103.104, 10.72.0.1, 127.0.0.1"
            ))
        );
    }

    #[tokio::test]
    async fn test_build_proxied_request_unsafe_tail() {
        let mut request = Request::new();
        request.params_mut().insert("**rest", "../admin".to_owned());
        let depot = Depot::new();
        let proxy = Proxy::new(vec!["http://example.com/api"], HyperClient::default());

        let req = proxy
            .build_proxied_request(&mut request, &depot)
            .await
            .expect("proxied request should build");
        assert_eq!(req.uri().to_string(), "http://example.com/api/admin");
    }

    #[tokio::test]
    async fn test_build_proxied_request_normalizes_safe_tail() {
        let mut request = Request::new();
        request
            .params_mut()
            .insert("**rest", "guide\\index.html".to_owned());
        let depot = Depot::new();
        let proxy = Proxy::new(vec!["http://example.com/api"], HyperClient::default());

        let proxied_request = proxy
            .build_proxied_request(&mut request, &depot)
            .await
            .expect("proxied request should build");
        assert_eq!(
            proxied_request.uri().to_string(),
            "http://example.com/api/guide/index.html"
        );
    }
}
751}