1#![doc(html_favicon_url = "https://salvo.rs/favicon-32x32.png")]
38#![doc(html_logo_url = "https://salvo.rs/images/logo.svg")]
39#![cfg_attr(docsrs, feature(doc_cfg))]
40
41use std::convert::Infallible;
42use std::error::Error as StdError;
43use std::fmt::{self, Debug, Formatter};
44#[cfg(test)]
45use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
46
47use hyper::upgrade::OnUpgrade;
48#[cfg(not(test))]
49use local_ip_address::{local_ip, local_ipv6};
50use percent_encoding::{AsciiSet, CONTROLS, utf8_percent_encode};
51use salvo_core::conn::SocketAddr;
52use salvo_core::http::header::{CONNECTION, HOST, HeaderMap, HeaderName, HeaderValue, UPGRADE};
53use salvo_core::http::uri::Uri;
54use salvo_core::http::{ReqBody, ResBody, StatusCode};
55use salvo_core::{BoxedError, Depot, Error, FlowCtrl, Handler, Request, Response, async_trait};
56
// `#[macro_use]` makes the macros of `cfg` available to the rest of this file;
// `cfg_feature!` below is presumably defined there.
#[macro_use]
mod cfg;

// `hyper_client` module and a glob re-export of its public items, gated through
// `cfg_feature!` on the `hyper-client` feature.
cfg_feature! {
    #![feature = "hyper-client"]
    mod hyper_client;
    pub use hyper_client::*;
}
// `reqwest_client` module and a glob re-export of its public items, gated through
// `cfg_feature!` on the `reqwest-client` feature.
cfg_feature! {
    #![feature = "reqwest-client"]
    mod reqwest_client;
    pub use reqwest_client::*;
}

// `unix_sock_client` module, gated on the `unix-sock-client` feature and
// additionally restricted to unix targets by `#[cfg(unix)]`.
cfg_feature! {
    #![feature = "unix-sock-client"]
    #[cfg(unix)]
    mod unix_sock_client;
    #[cfg(unix)]
    pub use unix_sock_client::*;
}
78
/// Request type handed to a [`Client`]: a hyper request carrying a [`ReqBody`].
type HyperRequest = hyper::Request<ReqBody>;
/// Response type produced by a [`Client`]: a hyper response carrying a [`ResBody`].
type HyperResponse = hyper::Response<ResBody>;

/// Name of the header written when client IP forwarding is enabled.
// NOTE(review): the identifier spells "FORWARDER" while the header is
// "x-forwarded-for"; the name is referenced elsewhere in this file, so it is kept.
const X_FORWARDER_FOR_HEADER_NAME: &str = "x-forwarded-for";
83
/// Bytes that are percent-encoded in the query string of the forwarded URL:
/// everything in `CONTROLS` plus space, `"`, `#`, `<`, `>` and `` ` ``.
const QUERY_ENCODE_SET: &AsciiSet = &CONTROLS
    .add(b' ')
    .add(b'"')
    .add(b'#')
    .add(b'<')
    .add(b'>')
    .add(b'`');
/// Bytes that are percent-encoded in each path segment of the forwarded URL:
/// everything in [`QUERY_ENCODE_SET`] plus `?`, `^`, `{` and `}`.
const PATH_ENCODE_SET: &AsciiSet = &QUERY_ENCODE_SET
    .add(b'?')
    .add(b'^')
    // The backtick is already part of QUERY_ENCODE_SET; it is listed here a second time.
    .add(b'`')
    .add(b'{')
    .add(b'}');
97
98#[inline]
100pub(crate) fn encode_url_path(path: &str) -> String {
101 path.split('/')
102 .map(|s| utf8_percent_encode(s, PATH_ENCODE_SET).to_string())
103 .collect::<Vec<_>>()
104 .join("/")
105}
106
/// HTTP client abstraction used by [`Proxy`] to send the proxied request.
pub trait Client: Send + Sync + 'static {
    /// Error produced when executing a request fails.
    type Error: StdError + Send + Sync + 'static;

    /// Sends `req` and resolves to the response that is relayed to the caller.
    ///
    /// `upgraded` is the [`OnUpgrade`] value that [`Proxy`] removes from the
    /// incoming request's extensions, or `None` when the request carried none.
    // NOTE(review): presumably implementations use it to bridge protocol
    // upgrades (e.g. WebSocket) — confirm against the client modules.
    fn execute(
        &self,
        req: HyperRequest,
        upgraded: Option<OnUpgrade>,
    ) -> impl Future<Output = Result<HyperResponse, Self::Error>> + Send;
}
122
/// Source of upstream addresses for a [`Proxy`].
pub trait Upstreams: Send + Sync + 'static {
    /// Error produced when no upstream can be elected.
    type Error: StdError + Send + Sync + 'static;

    /// Elects the upstream for the given request.
    ///
    /// The returned string is used by [`Proxy`] as the base of the forwarded
    /// URL; an empty string is rejected there as an error.
    fn elect(
        &self,
        req: &Request,
        depot: &Depot,
    ) -> impl Future<Output = Result<&str, Self::Error>> + Send;
}
139impl Upstreams for &'static str {
140 type Error = Infallible;
141
142 async fn elect(&self, _: &Request, _: &Depot) -> Result<&str, Self::Error> {
143 Ok(*self)
144 }
145}
146impl Upstreams for String {
147 type Error = Infallible;
148 async fn elect(&self, _: &Request, _: &Depot) -> Result<&str, Self::Error> {
149 Ok(self.as_str())
150 }
151}
152
153impl<const N: usize> Upstreams for [&'static str; N] {
154 type Error = Error;
155 async fn elect(&self, _: &Request, _: &Depot) -> Result<&str, Self::Error> {
156 if self.is_empty() {
157 return Err(Error::other("upstreams is empty"));
158 }
159 let index = fastrand::usize(..self.len());
160 Ok(self[index])
161 }
162}
163
164impl<T> Upstreams for Vec<T>
165where
166 T: AsRef<str> + Send + Sync + 'static,
167{
168 type Error = Error;
169 async fn elect(&self, _: &Request, _: &Depot) -> Result<&str, Self::Error> {
170 if self.is_empty() {
171 return Err(Error::other("upstreams is empty"));
172 }
173 let index = fastrand::usize(..self.len());
174 Ok(self[index].as_ref())
175 }
176}
177
/// Boxed callback that derives one part (path or query) of the forwarded URL
/// from the incoming request and depot; `None` means the part is absent.
pub type UrlPartGetter = Box<dyn Fn(&Request, &Depot) -> Option<String> + Send + Sync + 'static>;

/// Boxed callback that derives the `Host` header value of the forwarded
/// request from the forward URI, the incoming request and the depot; when it
/// returns `None` no `Host` header is set by [`Proxy`].
pub type HostHeaderGetter =
    Box<dyn Fn(&Uri, &Request, &Depot) -> Option<String> + Send + Sync + 'static>;
184
185pub fn default_url_path_getter(req: &Request, _depot: &Depot) -> Option<String> {
190 req.params().tail().map(encode_url_path)
191}
192pub fn default_url_query_getter(req: &Request, _depot: &Depot) -> Option<String> {
194 req.uri().query().map(Into::into)
195}
196
197pub fn default_host_header_getter(
199 forward_uri: &Uri,
200 _req: &Request,
201 _depot: &Depot,
202) -> Option<String> {
203 if let Some(host) = forward_uri.host() {
204 return Some(String::from(host));
205 }
206
207 None
208}
209
210pub fn rfc2616_host_header_getter(
213 forward_uri: &Uri,
214 req: &Request,
215 _depot: &Depot,
216) -> Option<String> {
217 let mut parts: Vec<String> = Vec::with_capacity(2);
218
219 if let Some(host) = forward_uri.host() {
220 parts.push(host.to_owned());
221
222 if let Some(scheme) = forward_uri.scheme_str()
223 && let Some(port) = forward_uri.port_u16()
224 && (scheme == "http" && port != 80 || scheme == "https" && port != 443)
225 {
226 parts.push(port.to_string());
227 }
228 }
229
230 if parts.is_empty() {
231 default_host_header_getter(forward_uri, req, _depot)
232 } else {
233 Some(parts.join(":"))
234 }
235}
236
237pub fn preserve_original_host_header_getter(
240 forward_uri: &Uri,
241 req: &Request,
242 _depot: &Depot,
243) -> Option<String> {
244 if let Some(host_header) = req.headers().get(HOST)
245 && let Ok(host) = String::from_utf8(host_header.as_bytes().to_vec())
246 {
247 return Some(host);
248 }
249
250 default_host_header_getter(forward_uri, req, _depot)
251}
252
/// Handler that forwards incoming requests to an upstream elected from `U`,
/// using the client `C`, and relays the upstream response.
#[non_exhaustive]
pub struct Proxy<U, C>
where
    U: Upstreams,
    C: Client,
{
    /// Upstreams the forward target is elected from.
    pub upstreams: U,
    /// Client used to execute the proxied request.
    pub client: C,
    /// Callback producing the path of the forwarded URL.
    pub url_path_getter: UrlPartGetter,
    /// Callback producing the query string of the forwarded URL.
    pub url_query_getter: UrlPartGetter,
    /// Callback producing the `Host` header of the forwarded request.
    pub host_header_getter: HostHeaderGetter,
    /// Whether an `x-forwarded-for` header is written on the forwarded request.
    pub client_ip_forwarding_enabled: bool,
}
273
274impl<U, C> Debug for Proxy<U, C>
275where
276 U: Upstreams,
277 C: Client,
278{
279 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
280 f.debug_struct("Proxy").finish()
281 }
282}
283
284impl<U, C> Proxy<U, C>
285where
286 U: Upstreams,
287 U::Error: Into<BoxedError>,
288 C: Client,
289{
290 #[must_use]
292 pub fn new(upstreams: U, client: C) -> Self {
293 Self {
294 upstreams,
295 client,
296 url_path_getter: Box::new(default_url_path_getter),
297 url_query_getter: Box::new(default_url_query_getter),
298 host_header_getter: Box::new(default_host_header_getter),
299 client_ip_forwarding_enabled: false,
300 }
301 }
302
303 pub fn with_client_ip_forwarding(upstreams: U, client: C) -> Self {
305 Self {
306 upstreams,
307 client,
308 url_path_getter: Box::new(default_url_path_getter),
309 url_query_getter: Box::new(default_url_query_getter),
310 host_header_getter: Box::new(default_host_header_getter),
311 client_ip_forwarding_enabled: true,
312 }
313 }
314
315 #[inline]
317 #[must_use]
318 pub fn url_path_getter<G>(mut self, url_path_getter: G) -> Self
319 where
320 G: Fn(&Request, &Depot) -> Option<String> + Send + Sync + 'static,
321 {
322 self.url_path_getter = Box::new(url_path_getter);
323 self
324 }
325
326 #[inline]
328 #[must_use]
329 pub fn url_query_getter<G>(mut self, url_query_getter: G) -> Self
330 where
331 G: Fn(&Request, &Depot) -> Option<String> + Send + Sync + 'static,
332 {
333 self.url_query_getter = Box::new(url_query_getter);
334 self
335 }
336
337 #[inline]
339 #[must_use]
340 pub fn host_header_getter<G>(mut self, host_header_getter: G) -> Self
341 where
342 G: Fn(&Uri, &Request, &Depot) -> Option<String> + Send + Sync + 'static,
343 {
344 self.host_header_getter = Box::new(host_header_getter);
345 self
346 }
347
348 #[inline]
350 pub fn upstreams(&self) -> &U {
351 &self.upstreams
352 }
353 #[inline]
355 pub fn upstreams_mut(&mut self) -> &mut U {
356 &mut self.upstreams
357 }
358
359 #[inline]
361 pub fn client(&self) -> &C {
362 &self.client
363 }
364 #[inline]
366 pub fn client_mut(&mut self) -> &mut C {
367 &mut self.client
368 }
369
370 #[inline]
372 #[must_use]
373 pub fn client_ip_forwarding(mut self, enable: bool) -> Self {
374 self.client_ip_forwarding_enabled = enable;
375 self
376 }
377
378 async fn build_proxied_request(
379 &self,
380 req: &mut Request,
381 depot: &Depot,
382 ) -> Result<HyperRequest, Error> {
383 let upstream = self
384 .upstreams
385 .elect(req, depot)
386 .await
387 .map_err(Error::other)?;
388
389 if upstream.is_empty() {
390 tracing::error!("upstreams is empty");
391 return Err(Error::other("upstreams is empty"));
392 }
393
394 let path = encode_url_path(&(self.url_path_getter)(req, depot).unwrap_or_default());
395 let query = (self.url_query_getter)(req, depot);
396 let rest = if let Some(query) = query {
397 if let Some(stripped) = query.strip_prefix('?') {
398 format!("{path}?{}", utf8_percent_encode(stripped, QUERY_ENCODE_SET))
399 } else {
400 format!("{path}?{}", utf8_percent_encode(&query, QUERY_ENCODE_SET))
401 }
402 } else {
403 path
404 };
405 let forward_url = if upstream.ends_with('/') && rest.starts_with('/') {
406 format!("{}{}", upstream.trim_end_matches('/'), rest)
407 } else if upstream.ends_with('/') || rest.starts_with('/') {
408 format!("{upstream}{rest}")
409 } else if rest.is_empty() {
410 upstream.to_owned()
411 } else {
412 format!("{upstream}/{rest}")
413 };
414 let forward_url: Uri = TryFrom::try_from(forward_url).map_err(Error::other)?;
415 let mut build = hyper::Request::builder()
416 .method(req.method())
417 .uri(&forward_url);
418 for (key, value) in req.headers() {
419 if key != HOST {
420 build = build.header(key, value);
421 }
422 }
423 if let Some(host_value) = (self.host_header_getter)(&forward_url, req, depot) {
424 match HeaderValue::from_str(&host_value) {
425 Ok(host_value) => {
426 build = build.header(HOST, host_value);
427 }
428 Err(e) => {
429 tracing::error!(error = ?e, "invalid host header value");
430 }
431 }
432 }
433
434 if self.client_ip_forwarding_enabled {
435 let xff_header_name = HeaderName::from_static(X_FORWARDER_FOR_HEADER_NAME);
436 let current_xff = req.headers().get(&xff_header_name);
437
438 #[cfg(test)]
439 let system_ip_addr = match req.remote_addr() {
440 SocketAddr::IPv6(_) => Some(IpAddr::from(Ipv6Addr::new(
441 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8,
442 ))),
443 _ => Some(IpAddr::from(Ipv4Addr::new(101, 102, 103, 104))),
444 };
445
446 #[cfg(not(test))]
447 let system_ip_addr = match req.remote_addr() {
448 SocketAddr::IPv6(_) => local_ipv6().ok(),
449 _ => local_ip().ok(),
450 };
451
452 if let Some(system_ip_addr) = system_ip_addr {
453 let forwarded_addr = system_ip_addr.to_string();
454
455 let xff_value = match current_xff {
456 Some(current_xff) => match current_xff.to_str() {
457 Ok(current_xff) => format!("{forwarded_addr}, {current_xff}"),
458 _ => forwarded_addr.clone(),
459 },
460 None => forwarded_addr.clone(),
461 };
462
463 let xff_header_halue = match HeaderValue::from_str(xff_value.as_str()) {
464 Ok(xff_header_halue) => Some(xff_header_halue),
465 Err(_) => match HeaderValue::from_str(forwarded_addr.as_str()) {
466 Ok(xff_header_halue) => Some(xff_header_halue),
467 Err(e) => {
468 tracing::error!(error = ?e, "invalid x-forwarded-for header value");
469 None
470 }
471 },
472 };
473
474 if let Some(xff) = xff_header_halue
475 && let Some(headers) = build.headers_mut()
476 {
477 headers.insert(&xff_header_name, xff);
478 }
479 }
480 }
481
482 build.body(req.take_body()).map_err(Error::other)
483 }
484}
485
486#[async_trait]
487impl<U, C> Handler for Proxy<U, C>
488where
489 U: Upstreams,
490 U::Error: Into<BoxedError>,
491 C: Client,
492{
493 async fn handle(
494 &self,
495 req: &mut Request,
496 depot: &mut Depot,
497 res: &mut Response,
498 _ctrl: &mut FlowCtrl,
499 ) {
500 match self.build_proxied_request(req, depot).await {
501 Ok(proxied_request) => {
502 match self
503 .client
504 .execute(proxied_request, req.extensions_mut().remove())
505 .await
506 {
507 Ok(response) => {
508 let (
509 salvo_core::http::response::Parts {
510 status,
511 headers,
513 ..
515 },
516 body,
517 ) = response.into_parts();
518 res.status_code(status);
519 for name in headers.keys() {
520 for value in headers.get_all(name) {
521 res.headers.append(name, value.to_owned());
522 }
523 }
524 res.body(body);
525 }
526 Err(e) => {
527 tracing::error!( error = ?e, uri = ?req.uri(), "get response data failed: {}", e);
528 res.status_code(StatusCode::INTERNAL_SERVER_ERROR);
529 }
530 }
531 }
532 Err(e) => {
533 tracing::error!(error = ?e, "build proxied request failed");
534 }
535 }
536 }
537}
538#[inline]
539#[allow(dead_code)]
540fn get_upgrade_type(headers: &HeaderMap) -> Option<&str> {
541 if headers
542 .get(&CONNECTION)
543 .map(|value| {
544 value
545 .to_str()
546 .unwrap_or_default()
547 .split(',')
548 .any(|e| e.trim() == UPGRADE)
549 })
550 .unwrap_or(false)
551 && let Some(upgrade_value) = headers.get(&UPGRADE)
552 {
553 tracing::debug!(
554 "found upgrade header with value: {:?}",
555 upgrade_value.to_str()
556 );
557 return upgrade_value.to_str().ok();
558 }
559
560 None
561}
562
#[cfg(test)]
mod tests {
    use std::net::{Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddrV6};
    use std::str::FromStr;

    use super::*;

    /// Plain segments pass through untouched; reserved bytes are encoded per
    /// segment while the `/` separators are kept.
    #[test]
    fn test_encode_url_path() {
        assert_eq!(encode_url_path("/test/path"), "/test/path");
        assert_eq!(encode_url_path("/a b/c?d"), "/a%20b/c%3Fd");
    }

    /// The upgrade type is reported only when `Connection` asks for an upgrade.
    #[test]
    fn test_get_upgrade_type() {
        let mut headers = HeaderMap::new();
        headers.insert(UPGRADE, HeaderValue::from_static("websocket"));
        assert_eq!(get_upgrade_type(&headers), None);

        headers.insert(CONNECTION, HeaderValue::from_static("upgrade"));
        assert_eq!(get_upgrade_type(&headers), Some("websocket"));
    }

    /// Exercises the three provided `Host` header getters.
    #[test]
    fn test_host_header_handling() {
        let _ = rustls::crypto::aws_lc_rs::default_provider().install_default();
        let uri = Uri::from_str("http://host.tld/test").unwrap();
        let mut req = Request::new();
        let depot = Depot::new();

        assert_eq!(
            default_host_header_getter(&uri, &req, &depot),
            Some("host.tld".to_string())
        );

        // (forward URI, expected `Host` value)
        let rfc2616_cases = [
            ("http://host.tld:8080/test", "host.tld:8080"),
            ("http://host.tld:80/test", "host.tld"),
            ("https://host.tld:443/test", "host.tld"),
            ("http://host.tld:443/test", "host.tld:443"),
        ];
        for (forward, expected) in rfc2616_cases {
            let forward_uri = Uri::from_str(forward).unwrap();
            assert_eq!(
                rfc2616_host_header_getter(&forward_uri, &req, &depot).as_deref(),
                Some(expected),
                "forward uri: {forward}"
            );
        }

        req.headers_mut()
            .insert(HOST, HeaderValue::from_static("test.host.tld"));
        assert_eq!(
            preserve_original_host_header_getter(&uri, &req, &depot),
            Some("test.host.tld".to_string())
        );
    }

    /// Checks the forwarding flag plumbing and the resulting
    /// `x-forwarded-for` header for each kind of remote address.
    #[tokio::test]
    async fn test_client_ip_forwarding() {
        let xff_header_name = HeaderName::from_static(X_FORWARDER_FOR_HEADER_NAME);

        let mut request = Request::new();
        let depot = Depot::new();

        let proxy_without_forwarding =
            Proxy::new(vec!["http://example.com"], HyperClient::default());
        assert!(!proxy_without_forwarding.client_ip_forwarding_enabled);

        let proxy_with_forwarding = proxy_without_forwarding.client_ip_forwarding(true);
        assert!(proxy_with_forwarding.client_ip_forwarding_enabled);

        let proxy =
            Proxy::with_client_ip_forwarding(vec!["http://example.com"], HyperClient::default());
        assert!(proxy.client_ip_forwarding_enabled);

        // Default remote address: the IPv4 test address is used.
        let proxied = proxy
            .build_proxied_request(&mut request, &depot)
            .await
            .expect("proxied request should build");
        assert_eq!(
            proxied.headers().get(&xff_header_name),
            Some(&HeaderValue::from_static("101.102.103.104"))
        );

        // IPv6 remote address: the IPv6 test address is used.
        *request.remote_addr_mut() =
            SocketAddr::from(SocketAddrV6::new(Ipv6Addr::LOCALHOST, 12345, 0, 0));
        let proxied = proxy
            .build_proxied_request(&mut request, &depot)
            .await
            .expect("proxied request should build");
        assert_eq!(
            proxied.headers().get(&xff_header_name),
            Some(&HeaderValue::from_static("1:2:3:4:5:6:7:8"))
        );

        // Unknown remote address: falls back to the IPv4 test address.
        *request.remote_addr_mut() = SocketAddr::Unknown;
        let proxied = proxy
            .build_proxied_request(&mut request, &depot)
            .await
            .expect("proxied request should build");
        assert_eq!(
            proxied.headers().get(&xff_header_name),
            Some(&HeaderValue::from_static("101.102.103.104"))
        );

        // An existing header value is kept after the newly added address.
        request.headers_mut().insert(
            &xff_header_name,
            HeaderValue::from_static("10.72.0.1, 127.0.0.1"),
        );
        *request.remote_addr_mut() =
            SocketAddr::from(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 12345));
        let proxied = proxy
            .build_proxied_request(&mut request, &depot)
            .await
            .expect("proxied request should build");
        assert_eq!(
            proxied.headers().get(&xff_header_name),
            Some(&HeaderValue::from_static(
                "101.102.103.104, 10.72.0.1, 127.0.0.1"
            ))
        );
    }
}