1#![doc(html_favicon_url = "https://salvo.rs/favicon-32x32.png")]
38#![doc(html_logo_url = "https://salvo.rs/images/logo.svg")]
39#![cfg_attr(docsrs, feature(doc_cfg))]
40
41use std::convert::Infallible;
42use std::error::Error as StdError;
43use std::fmt::{self, Debug, Formatter};
44#[cfg(test)]
45use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
46
47use hyper::upgrade::OnUpgrade;
48#[cfg(not(test))]
49use local_ip_address::{local_ip, local_ipv6};
50use percent_encoding::{AsciiSet, CONTROLS, utf8_percent_encode};
51use salvo_core::conn::SocketAddr;
52use salvo_core::http::header::{CONNECTION, HOST, HeaderMap, HeaderName, HeaderValue, UPGRADE};
53use salvo_core::http::uri::Uri;
54use salvo_core::http::{ReqBody, ResBody, StatusCode};
55use salvo_core::{BoxedError, Depot, Error, FlowCtrl, Handler, Request, Response, async_trait};
56
57#[macro_use]
58mod cfg;
59
60cfg_feature! {
61 #![feature = "hyper-client"]
62 mod hyper_client;
63 pub use hyper_client::*;
64}
65cfg_feature! {
66 #![feature = "reqwest-client"]
67 mod reqwest_client;
68 pub use reqwest_client::*;
69}
70
71cfg_feature! {
72 #![feature = "unix-sock-client"]
73 #[cfg(unix)]
74 mod unix_sock_client;
75 #[cfg(unix)]
76 pub use unix_sock_client::*;
77}
78
79type HyperRequest = hyper::Request<ReqBody>;
80type HyperResponse = hyper::Response<ResBody>;
81
82const X_FORWARDER_FOR_HEADER_NAME: &str = "x-forwarded-for";
83
84const QUERY_ENCODE_SET: &AsciiSet = &CONTROLS
85 .add(b' ')
86 .add(b'"')
87 .add(b'#')
88 .add(b'<')
89 .add(b'>')
90 .add(b'`');
91const PATH_ENCODE_SET: &AsciiSet = &QUERY_ENCODE_SET
92 .add(b'?')
93 .add(b'^')
94 .add(b'`')
95 .add(b'{')
96 .add(b'}');
97
98#[inline]
100pub(crate) fn encode_url_path(path: &str) -> String {
101 path.split('/')
102 .map(|s| utf8_percent_encode(s, PATH_ENCODE_SET).to_string())
103 .collect::<Vec<_>>()
104 .join("/")
105}
106
107pub trait Client: Send + Sync + 'static {
112 type Error: StdError + Send + Sync + 'static;
114
115 fn execute(
117 &self,
118 req: HyperRequest,
119 upgraded: Option<OnUpgrade>,
120 ) -> impl Future<Output = Result<HyperResponse, Self::Error>> + Send;
121}
122
123pub trait Upstreams: Send + Sync + 'static {
129 type Error: StdError + Send + Sync + 'static;
131
132 fn elect(
134 &self,
135 req: &Request,
136 depot: &Depot,
137 ) -> impl Future<Output = Result<&str, Self::Error>> + Send;
138}
139impl Upstreams for &'static str {
140 type Error = Infallible;
141
142 async fn elect(&self, _: &Request, _: &Depot) -> Result<&str, Self::Error> {
143 Ok(*self)
144 }
145}
146impl Upstreams for String {
147 type Error = Infallible;
148 async fn elect(&self, _: &Request, _: &Depot) -> Result<&str, Self::Error> {
149 Ok(self.as_str())
150 }
151}
152
153impl<const N: usize> Upstreams for [&'static str; N] {
154 type Error = Error;
155 async fn elect(&self, _: &Request, _: &Depot) -> Result<&str, Self::Error> {
156 if self.is_empty() {
157 return Err(Error::other("upstreams is empty"));
158 }
159 let index = fastrand::usize(..self.len());
160 Ok(self[index])
161 }
162}
163
164impl<T> Upstreams for Vec<T>
165where
166 T: AsRef<str> + Send + Sync + 'static,
167{
168 type Error = Error;
169 async fn elect(&self, _: &Request, _: &Depot) -> Result<&str, Self::Error> {
170 if self.is_empty() {
171 return Err(Error::other("upstreams is empty"));
172 }
173 let index = fastrand::usize(..self.len());
174 Ok(self[index].as_ref())
175 }
176}
177
178pub type UrlPartGetter = Box<dyn Fn(&Request, &Depot) -> Option<String> + Send + Sync + 'static>;
180
181pub type HostHeaderGetter =
183 Box<dyn Fn(&Uri, &Request, &Depot) -> Option<String> + Send + Sync + 'static>;
184
185pub fn default_url_path_getter(req: &Request, _depot: &Depot) -> Option<String> {
190 req.params().tail().map(encode_url_path)
191}
192pub fn default_url_query_getter(req: &Request, _depot: &Depot) -> Option<String> {
194 req.uri().query().map(Into::into)
195}
196
197pub fn default_host_header_getter(
199 forward_uri: &Uri,
200 _req: &Request,
201 _depot: &Depot,
202) -> Option<String> {
203 if let Some(host) = forward_uri.host() {
204 return Some(String::from(host));
205 }
206
207 None
208}
209
210pub fn rfc2616_host_header_getter(
213 forward_uri: &Uri,
214 req: &Request,
215 _depot: &Depot,
216) -> Option<String> {
217 let mut parts: Vec<String> = Vec::with_capacity(2);
218
219 if let Some(host) = forward_uri.host() {
220 parts.push(host.to_owned());
221
222 if let Some(scheme) = forward_uri.scheme_str()
223 && let Some(port) = forward_uri.port_u16()
224 && (scheme == "http" && port != 80 || scheme == "https" && port != 443)
225 {
226 parts.push(port.to_string());
227 }
228 }
229
230 if parts.is_empty() {
231 default_host_header_getter(forward_uri, req, _depot)
232 } else {
233 Some(parts.join(":"))
234 }
235}
236
237pub fn preserve_original_host_header_getter(
240 forward_uri: &Uri,
241 req: &Request,
242 _depot: &Depot,
243) -> Option<String> {
244 if let Some(host_header) = req.headers().get(HOST)
245 && let Ok(host) = String::from_utf8(host_header.as_bytes().to_vec())
246 {
247 return Some(host);
248 }
249
250 default_host_header_getter(forward_uri, req, _depot)
251}
252
253#[non_exhaustive]
255pub struct Proxy<U, C>
256where
257 U: Upstreams,
258 C: Client,
259{
260 pub upstreams: U,
262 pub client: C,
264 pub url_path_getter: UrlPartGetter,
266 pub url_query_getter: UrlPartGetter,
268 pub host_header_getter: HostHeaderGetter,
270 pub client_ip_forwarding_enabled: bool,
272}
273
274impl<U, C> Debug for Proxy<U, C>
275where
276 U: Upstreams,
277 C: Client,
278{
279 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
280 f.debug_struct("Proxy").finish()
281 }
282}
283
284impl<U, C> Proxy<U, C>
285where
286 U: Upstreams,
287 U::Error: Into<BoxedError>,
288 C: Client,
289{
290 #[must_use]
292 pub fn new(upstreams: U, client: C) -> Self {
293 Self {
294 upstreams,
295 client,
296 url_path_getter: Box::new(default_url_path_getter),
297 url_query_getter: Box::new(default_url_query_getter),
298 host_header_getter: Box::new(default_host_header_getter),
299 client_ip_forwarding_enabled: false,
300 }
301 }
302
303 pub fn with_client_ip_forwarding(upstreams: U, client: C) -> Self {
305 Self {
306 upstreams,
307 client,
308 url_path_getter: Box::new(default_url_path_getter),
309 url_query_getter: Box::new(default_url_query_getter),
310 host_header_getter: Box::new(default_host_header_getter),
311 client_ip_forwarding_enabled: true,
312 }
313 }
314
315 #[inline]
317 #[must_use]
318 pub fn url_path_getter<G>(mut self, url_path_getter: G) -> Self
319 where
320 G: Fn(&Request, &Depot) -> Option<String> + Send + Sync + 'static,
321 {
322 self.url_path_getter = Box::new(url_path_getter);
323 self
324 }
325
326 #[inline]
328 #[must_use]
329 pub fn url_query_getter<G>(mut self, url_query_getter: G) -> Self
330 where
331 G: Fn(&Request, &Depot) -> Option<String> + Send + Sync + 'static,
332 {
333 self.url_query_getter = Box::new(url_query_getter);
334 self
335 }
336
337 #[inline]
339 #[must_use]
340 pub fn host_header_getter<G>(mut self, host_header_getter: G) -> Self
341 where
342 G: Fn(&Uri, &Request, &Depot) -> Option<String> + Send + Sync + 'static,
343 {
344 self.host_header_getter = Box::new(host_header_getter);
345 self
346 }
347
348 #[inline]
350 pub fn upstreams(&self) -> &U {
351 &self.upstreams
352 }
353 #[inline]
355 pub fn upstreams_mut(&mut self) -> &mut U {
356 &mut self.upstreams
357 }
358
359 #[inline]
361 pub fn client(&self) -> &C {
362 &self.client
363 }
364 #[inline]
366 pub fn client_mut(&mut self) -> &mut C {
367 &mut self.client
368 }
369
370 #[inline]
372 #[must_use]
373 pub fn client_ip_forwarding(mut self, enable: bool) -> Self {
374 self.client_ip_forwarding_enabled = enable;
375 self
376 }
377
378 async fn build_proxied_request(
379 &self,
380 req: &mut Request,
381 depot: &Depot,
382 ) -> Result<HyperRequest, Error> {
383 let upstream = self
384 .upstreams
385 .elect(req, depot)
386 .await
387 .map_err(Error::other)?;
388
389 if upstream.is_empty() {
390 tracing::error!("upstreams is empty");
391 return Err(Error::other("upstreams is empty"));
392 }
393
394 let path = encode_url_path(&(self.url_path_getter)(req, depot).unwrap_or_default());
395 let query = (self.url_query_getter)(req, depot);
396 let rest = if let Some(query) = query {
397 if query.starts_with('?') {
398 format!(
399 "{path}?{}",
400 utf8_percent_encode(&query[1..], QUERY_ENCODE_SET)
401 )
402 } else {
403 format!("{path}?{}", utf8_percent_encode(&query, QUERY_ENCODE_SET))
404 }
405 } else {
406 path
407 };
408 let forward_url = if upstream.ends_with('/') && rest.starts_with('/') {
409 format!("{}{}", upstream.trim_end_matches('/'), rest)
410 } else if upstream.ends_with('/') || rest.starts_with('/') {
411 format!("{upstream}{rest}")
412 } else if rest.is_empty() {
413 upstream.to_owned()
414 } else {
415 format!("{upstream}/{rest}")
416 };
417 let forward_url: Uri = TryFrom::try_from(forward_url).map_err(Error::other)?;
418 let mut build = hyper::Request::builder()
419 .method(req.method())
420 .uri(&forward_url);
421 for (key, value) in req.headers() {
422 if key != HOST {
423 build = build.header(key, value);
424 }
425 }
426 if let Some(host_value) = (self.host_header_getter)(&forward_url, req, depot) {
427 match HeaderValue::from_str(&host_value) {
428 Ok(host_value) => {
429 build = build.header(HOST, host_value);
430 }
431 Err(e) => {
432 tracing::error!(error = ?e, "invalid host header value");
433 }
434 }
435 }
436
437 if self.client_ip_forwarding_enabled {
438 let xff_header_name = HeaderName::from_static(X_FORWARDER_FOR_HEADER_NAME);
439 let current_xff = req.headers().get(&xff_header_name);
440
441 #[cfg(test)]
442 let system_ip_addr = match req.remote_addr() {
443 SocketAddr::IPv6(_) => Some(IpAddr::from(Ipv6Addr::new(
444 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8,
445 ))),
446 _ => Some(IpAddr::from(Ipv4Addr::new(101, 102, 103, 104))),
447 };
448
449 #[cfg(not(test))]
450 let system_ip_addr = match req.remote_addr() {
451 SocketAddr::IPv6(_) => local_ipv6().ok(),
452 _ => local_ip().ok(),
453 };
454
455 if let Some(system_ip_addr) = system_ip_addr {
456 let forwarded_addr = system_ip_addr.to_string();
457
458 let xff_value = match current_xff {
459 Some(current_xff) => match current_xff.to_str() {
460 Ok(current_xff) => format!("{forwarded_addr}, {current_xff}"),
461 _ => forwarded_addr.clone(),
462 },
463 None => forwarded_addr.clone(),
464 };
465
466 let xff_header_halue = match HeaderValue::from_str(xff_value.as_str()) {
467 Ok(xff_header_halue) => Some(xff_header_halue),
468 Err(_) => match HeaderValue::from_str(forwarded_addr.as_str()) {
469 Ok(xff_header_halue) => Some(xff_header_halue),
470 Err(e) => {
471 tracing::error!(error = ?e, "invalid x-forwarded-for header value");
472 None
473 }
474 },
475 };
476
477 if let Some(xff) = xff_header_halue
478 && let Some(headers) = build.headers_mut()
479 {
480 headers.insert(&xff_header_name, xff);
481 }
482 }
483 }
484
485 build.body(req.take_body()).map_err(Error::other)
486 }
487}
488
489#[async_trait]
490impl<U, C> Handler for Proxy<U, C>
491where
492 U: Upstreams,
493 U::Error: Into<BoxedError>,
494 C: Client,
495{
496 async fn handle(
497 &self,
498 req: &mut Request,
499 depot: &mut Depot,
500 res: &mut Response,
501 _ctrl: &mut FlowCtrl,
502 ) {
503 match self.build_proxied_request(req, depot).await {
504 Ok(proxied_request) => {
505 match self
506 .client
507 .execute(proxied_request, req.extensions_mut().remove())
508 .await
509 {
510 Ok(response) => {
511 let (
512 salvo_core::http::response::Parts {
513 status,
514 headers,
516 ..
518 },
519 body,
520 ) = response.into_parts();
521 res.status_code(status);
522 for name in headers.keys() {
523 for value in headers.get_all(name) {
524 res.headers.append(name, value.to_owned());
525 }
526 }
527 res.body(body);
528 }
529 Err(e) => {
530 tracing::error!( error = ?e, uri = ?req.uri(), "get response data failed: {}", e);
531 res.status_code(StatusCode::INTERNAL_SERVER_ERROR);
532 }
533 }
534 }
535 Err(e) => {
536 tracing::error!(error = ?e, "build proxied request failed");
537 }
538 }
539 }
540}
541#[inline]
542#[allow(dead_code)]
543fn get_upgrade_type(headers: &HeaderMap) -> Option<&str> {
544 if headers
545 .get(&CONNECTION)
546 .map(|value| {
547 value
548 .to_str()
549 .unwrap_or_default()
550 .split(',')
551 .any(|e| e.trim() == UPGRADE)
552 })
553 .unwrap_or(false)
554 && let Some(upgrade_value) = headers.get(&UPGRADE)
555 {
556 tracing::debug!(
557 "found upgrade header with value: {:?}",
558 upgrade_value.to_str()
559 );
560 return upgrade_value.to_str().ok();
561 }
562
563 None
564}
565
#[cfg(test)]
mod tests {
    use std::net::{Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddrV6};
    use std::str::FromStr;

    use super::*;

    #[test]
    fn test_encode_url_path() {
        let path = "/test/path";
        let encoded_path = encode_url_path(path);
        assert_eq!(encoded_path, "/test/path");
    }

    #[test]
    fn test_get_upgrade_type() {
        let mut headers = HeaderMap::new();
        headers.insert(CONNECTION, HeaderValue::from_static("upgrade"));
        headers.insert(UPGRADE, HeaderValue::from_static("websocket"));
        let upgrade_type = get_upgrade_type(&headers);
        assert_eq!(upgrade_type, Some("websocket"));
    }

    #[test]
    fn test_host_header_handling() {
        let _ = rustls::crypto::aws_lc_rs::default_provider().install_default();
        let uri = Uri::from_str("http://host.tld/test").unwrap();
        let mut req = Request::new();
        let depot = Depot::new();

        assert_eq!(
            default_host_header_getter(&uri, &req, &depot),
            Some("host.tld".to_string())
        );

        let uri_with_port = Uri::from_str("http://host.tld:8080/test").unwrap();
        assert_eq!(
            rfc2616_host_header_getter(&uri_with_port, &req, &depot),
            Some("host.tld:8080".to_string())
        );

        let uri_with_http_port = Uri::from_str("http://host.tld:80/test").unwrap();
        assert_eq!(
            rfc2616_host_header_getter(&uri_with_http_port, &req, &depot),
            Some("host.tld".to_string())
        );

        let uri_with_https_port = Uri::from_str("https://host.tld:443/test").unwrap();
        assert_eq!(
            rfc2616_host_header_getter(&uri_with_https_port, &req, &depot),
            Some("host.tld".to_string())
        );

        let uri_with_non_https_scheme_and_https_port =
            Uri::from_str("http://host.tld:443/test").unwrap();
        assert_eq!(
            rfc2616_host_header_getter(&uri_with_non_https_scheme_and_https_port, &req, &depot),
            Some("host.tld:443".to_string())
        );

        req.headers_mut()
            .insert(HOST, HeaderValue::from_static("test.host.tld"));
        assert_eq!(
            preserve_original_host_header_getter(&uri, &req, &depot),
            Some("test.host.tld".to_string())
        );
    }

    #[tokio::test]
    async fn test_client_ip_forwarding() {
        let xff_header_name = HeaderName::from_static(X_FORWARDER_FOR_HEADER_NAME);

        let mut request = Request::new();
        let depot = Depot::new();

        // Forwarding is off by default and can be toggled with the builder method.
        let proxy_without_forwarding =
            Proxy::new(vec!["http://example.com"], HyperClient::default());
        assert!(!proxy_without_forwarding.client_ip_forwarding_enabled);

        let proxy_with_forwarding = proxy_without_forwarding.client_ip_forwarding(true);
        assert!(proxy_with_forwarding.client_ip_forwarding_enabled);

        let proxy =
            Proxy::with_client_ip_forwarding(vec!["http://example.com"], HyperClient::default());
        assert!(proxy.client_ip_forwarding_enabled);

        // Default (unknown) remote address: the IPv4 test address is used.
        let proxied = proxy
            .build_proxied_request(&mut request, &depot)
            .await
            .expect("proxied request should build for the default remote address");
        assert_eq!(
            proxied.headers().get(&xff_header_name),
            Some(&HeaderValue::from_static("101.102.103.104"))
        );

        // IPv6 remote address: the IPv6 test address is used.
        *request.remote_addr_mut() =
            SocketAddr::from(SocketAddrV6::new(Ipv6Addr::LOCALHOST, 12345, 0, 0));
        let proxied = proxy
            .build_proxied_request(&mut request, &depot)
            .await
            .expect("proxied request should build for an IPv6 remote address");
        assert_eq!(
            proxied.headers().get(&xff_header_name),
            Some(&HeaderValue::from_static("1:2:3:4:5:6:7:8"))
        );

        *request.remote_addr_mut() = SocketAddr::Unknown;
        let proxied = proxy
            .build_proxied_request(&mut request, &depot)
            .await
            .expect("proxied request should build for an unknown remote address");
        assert_eq!(
            proxied.headers().get(&xff_header_name),
            Some(&HeaderValue::from_static("101.102.103.104"))
        );

        // An existing x-forwarded-for chain is kept after the prepended address.
        request.headers_mut().insert(
            &xff_header_name,
            HeaderValue::from_static("10.72.0.1, 127.0.0.1"),
        );
        *request.remote_addr_mut() =
            SocketAddr::from(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 12345));
        let proxied = proxy
            .build_proxied_request(&mut request, &depot)
            .await
            .expect("proxied request should build with an existing x-forwarded-for header");
        assert_eq!(
            proxied.headers().get(&xff_header_name),
            Some(&HeaderValue::from_static(
                "101.102.103.104, 10.72.0.1, 127.0.0.1"
            ))
        );
    }
}