1#![doc(html_favicon_url = "https://salvo.rs/favicon-32x32.png")]
38#![doc(html_logo_url = "https://salvo.rs/images/logo.svg")]
39#![cfg_attr(docsrs, feature(doc_cfg))]
40
41use std::convert::Infallible;
42use std::error::Error as StdError;
43use std::fmt::{self, Debug, Formatter};
44
45use hyper::upgrade::OnUpgrade;
46use percent_encoding::{CONTROLS, utf8_percent_encode};
47use salvo_core::conn::SocketAddr;
48use salvo_core::http::header::{CONNECTION, HOST, HeaderMap, HeaderName, HeaderValue, UPGRADE};
49use salvo_core::http::uri::Uri;
50use salvo_core::http::{ReqBody, ResBody, StatusCode};
51use salvo_core::{BoxedError, Depot, Error, FlowCtrl, Handler, Request, Response, async_trait};
52
53#[cfg(test)]
54use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
55
56#[cfg(not(test))]
57use local_ip_address::{local_ip, local_ipv6};
58
59#[macro_use]
60mod cfg;
61
62cfg_feature! {
63 #![feature = "hyper-client"]
64 mod hyper_client;
65 pub use hyper_client::*;
66}
67cfg_feature! {
68 #![feature = "reqwest-client"]
69 mod reqwest_client;
70 pub use reqwest_client::*;
71}
72
73cfg_feature! {
74 #![feature = "unix-sock-client"]
75 #[cfg(unix)]
76 mod unix_sock_client;
77 #[cfg(unix)]
78 pub use unix_sock_client::*;
79}
80
81type HyperRequest = hyper::Request<ReqBody>;
82type HyperResponse = hyper::Response<ResBody>;
83
84const X_FORWARDER_FOR_HEADER_NAME: &str = "x-forwarded-for";
85
86#[inline]
88pub(crate) fn encode_url_path(path: &str) -> String {
89 path.split('/')
90 .map(|s| utf8_percent_encode(s, CONTROLS).to_string())
91 .collect::<Vec<_>>()
92 .join("/")
93}
94
95pub trait Client: Send + Sync + 'static {
100 type Error: StdError + Send + Sync + 'static;
102
103 fn execute(
105 &self,
106 req: HyperRequest,
107 upgraded: Option<OnUpgrade>,
108 ) -> impl Future<Output = Result<HyperResponse, Self::Error>> + Send;
109}
110
111pub trait Upstreams: Send + Sync + 'static {
117 type Error: StdError + Send + Sync + 'static;
119
120 fn elect(
122 &self,
123 req: &Request,
124 depot: &Depot,
125 ) -> impl Future<Output = Result<&str, Self::Error>> + Send;
126}
127impl Upstreams for &'static str {
128 type Error = Infallible;
129
130 async fn elect(&self, _: &Request, _: &Depot) -> Result<&str, Self::Error> {
131 Ok(*self)
132 }
133}
134impl Upstreams for String {
135 type Error = Infallible;
136 async fn elect(&self, _: &Request, _: &Depot) -> Result<&str, Self::Error> {
137 Ok(self.as_str())
138 }
139}
140
141impl<const N: usize> Upstreams for [&'static str; N] {
142 type Error = Error;
143 async fn elect(&self, _: &Request, _: &Depot) -> Result<&str, Self::Error> {
144 if self.is_empty() {
145 return Err(Error::other("upstreams is empty"));
146 }
147 let index = fastrand::usize(..self.len());
148 Ok(self[index])
149 }
150}
151
152impl<T> Upstreams for Vec<T>
153where
154 T: AsRef<str> + Send + Sync + 'static,
155{
156 type Error = Error;
157 async fn elect(&self, _: &Request, _: &Depot) -> Result<&str, Self::Error> {
158 if self.is_empty() {
159 return Err(Error::other("upstreams is empty"));
160 }
161 let index = fastrand::usize(..self.len());
162 Ok(self[index].as_ref())
163 }
164}
165
166pub type UrlPartGetter = Box<dyn Fn(&Request, &Depot) -> Option<String> + Send + Sync + 'static>;
168
169pub type HostHeaderGetter =
171 Box<dyn Fn(&Uri, &Request, &Depot) -> Option<String> + Send + Sync + 'static>;
172
173pub fn default_url_path_getter(req: &Request, _depot: &Depot) -> Option<String> {
178 req.params().tail().map(encode_url_path)
179}
180pub fn default_url_query_getter(req: &Request, _depot: &Depot) -> Option<String> {
182 req.uri().query().map(Into::into)
183}
184
185pub fn default_host_header_getter(
187 forward_uri: &Uri,
188 _req: &Request,
189 _depot: &Depot,
190) -> Option<String> {
191 if let Some(host) = forward_uri.host() {
192 return Some(String::from(host));
193 }
194
195 None
196}
197
198pub fn rfc2616_host_header_getter(
201 forward_uri: &Uri,
202 req: &Request,
203 _depot: &Depot,
204) -> Option<String> {
205 let mut parts: Vec<String> = Vec::with_capacity(2);
206
207 if let Some(host) = forward_uri.host() {
208 parts.push(host.to_owned());
209
210 if let Some(scheme) = forward_uri.scheme_str()
211 && let Some(port) = forward_uri.port_u16()
212 && (scheme == "http" && port != 80 || scheme == "https" && port != 443)
213 {
214 parts.push(port.to_string());
215 }
216 }
217
218 if parts.is_empty() {
219 default_host_header_getter(forward_uri, req, _depot)
220 } else {
221 Some(parts.join(":"))
222 }
223}
224
225pub fn preserve_original_host_header_getter(
227 forward_uri: &Uri,
228 req: &Request,
229 _depot: &Depot,
230) -> Option<String> {
231 if let Some(host_header) = req.headers().get(HOST)
232 && let Ok(host) = String::from_utf8(host_header.as_bytes().to_vec())
233 {
234 return Some(host);
235 }
236
237 default_host_header_getter(forward_uri, req, _depot)
238}
239
240#[non_exhaustive]
242pub struct Proxy<U, C>
243where
244 U: Upstreams,
245 C: Client,
246{
247 pub upstreams: U,
249 pub client: C,
251 pub url_path_getter: UrlPartGetter,
253 pub url_query_getter: UrlPartGetter,
255 pub host_header_getter: HostHeaderGetter,
257 pub client_ip_forwarding_enabled: bool,
259}
260
261impl<U, C> Debug for Proxy<U, C>
262where
263 U: Upstreams,
264 C: Client,
265{
266 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
267 f.debug_struct("Proxy").finish()
268 }
269}
270
271impl<U, C> Proxy<U, C>
272where
273 U: Upstreams,
274 U::Error: Into<BoxedError>,
275 C: Client,
276{
277 #[must_use]
279 pub fn new(upstreams: U, client: C) -> Self {
280 Self {
281 upstreams,
282 client,
283 url_path_getter: Box::new(default_url_path_getter),
284 url_query_getter: Box::new(default_url_query_getter),
285 host_header_getter: Box::new(default_host_header_getter),
286 client_ip_forwarding_enabled: false,
287 }
288 }
289
290 pub fn with_client_ip_forwarding(upstreams: U, client: C) -> Self {
292 Self {
293 upstreams,
294 client,
295 url_path_getter: Box::new(default_url_path_getter),
296 url_query_getter: Box::new(default_url_query_getter),
297 host_header_getter: Box::new(default_host_header_getter),
298 client_ip_forwarding_enabled: true,
299 }
300 }
301
302 #[inline]
304 #[must_use]
305 pub fn url_path_getter<G>(mut self, url_path_getter: G) -> Self
306 where
307 G: Fn(&Request, &Depot) -> Option<String> + Send + Sync + 'static,
308 {
309 self.url_path_getter = Box::new(url_path_getter);
310 self
311 }
312
313 #[inline]
315 #[must_use]
316 pub fn url_query_getter<G>(mut self, url_query_getter: G) -> Self
317 where
318 G: Fn(&Request, &Depot) -> Option<String> + Send + Sync + 'static,
319 {
320 self.url_query_getter = Box::new(url_query_getter);
321 self
322 }
323
324 #[inline]
326 #[must_use]
327 pub fn host_header_getter<G>(mut self, host_header_getter: G) -> Self
328 where
329 G: Fn(&Uri, &Request, &Depot) -> Option<String> + Send + Sync + 'static,
330 {
331 self.host_header_getter = Box::new(host_header_getter);
332 self
333 }
334
335 #[inline]
337 pub fn upstreams(&self) -> &U {
338 &self.upstreams
339 }
340 #[inline]
342 pub fn upstreams_mut(&mut self) -> &mut U {
343 &mut self.upstreams
344 }
345
346 #[inline]
348 pub fn client(&self) -> &C {
349 &self.client
350 }
351 #[inline]
353 pub fn client_mut(&mut self) -> &mut C {
354 &mut self.client
355 }
356
357 #[inline]
359 #[must_use]
360 pub fn client_ip_forwarding(mut self, enable: bool) -> Self {
361 self.client_ip_forwarding_enabled = enable;
362 self
363 }
364
365 async fn build_proxied_request(
366 &self,
367 req: &mut Request,
368 depot: &Depot,
369 ) -> Result<HyperRequest, Error> {
370 let upstream = self
371 .upstreams
372 .elect(req, depot)
373 .await
374 .map_err(Error::other)?;
375
376 if upstream.is_empty() {
377 tracing::error!("upstreams is empty");
378 return Err(Error::other("upstreams is empty"));
379 }
380
381 let path = encode_url_path(&(self.url_path_getter)(req, depot).unwrap_or_default());
382 let query = (self.url_query_getter)(req, depot);
383 let rest = if let Some(query) = query {
384 if query.starts_with('?') {
385 format!("{path}{query}")
386 } else {
387 format!("{path}?{query}")
388 }
389 } else {
390 path
391 };
392 let forward_url = if upstream.ends_with('/') && rest.starts_with('/') {
393 format!("{}{}", upstream.trim_end_matches('/'), rest)
394 } else if upstream.ends_with('/') || rest.starts_with('/') {
395 format!("{upstream}{rest}")
396 } else if rest.is_empty() {
397 upstream.to_owned()
398 } else {
399 format!("{upstream}/{rest}")
400 };
401 let forward_url = url::Url::parse(&forward_url).map_err(|e| {
402 Error::other(format!("url::Url::parse failed for '{forward_url}': {e}"))
403 })?;
404 let forward_url: Uri = forward_url
405 .as_str()
406 .parse()
407 .map_err(|e| Error::other(format!("Uri::parse failed for '{forward_url}': {e}")))?;
408 let mut build = hyper::Request::builder()
409 .method(req.method())
410 .uri(&forward_url);
411 for (key, value) in req.headers() {
412 if key != HOST {
413 build = build.header(key, value);
414 }
415 }
416 if let Some(host_value) = (self.host_header_getter)(&forward_url, req, depot) {
417 match HeaderValue::from_str(&host_value) {
418 Ok(host_value) => {
419 build = build.header(HOST, host_value);
420 }
421 Err(e) => {
422 tracing::error!(error = ?e, "invalid host header value");
423 }
424 }
425 }
426
427 if self.client_ip_forwarding_enabled {
428 let xff_header_name = HeaderName::from_static(X_FORWARDER_FOR_HEADER_NAME);
429 let current_xff = req.headers().get(&xff_header_name);
430
431 #[cfg(test)]
432 let system_ip_addr = match req.remote_addr() {
433 SocketAddr::IPv6(_) => Some(IpAddr::from(Ipv6Addr::new(
434 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8,
435 ))),
436 _ => Some(IpAddr::from(Ipv4Addr::new(101, 102, 103, 104))),
437 };
438
439 #[cfg(not(test))]
440 let system_ip_addr = match req.remote_addr() {
441 SocketAddr::IPv6(_) => local_ipv6().ok(),
442 _ => local_ip().ok(),
443 };
444
445 if let Some(system_ip_addr) = system_ip_addr {
446 let forwarded_addr = system_ip_addr.to_string();
447
448 let xff_value = match current_xff {
449 Some(current_xff) => match current_xff.to_str() {
450 Ok(current_xff) => format!("{forwarded_addr}, {current_xff}"),
451 _ => forwarded_addr.clone(),
452 },
453 None => forwarded_addr.clone(),
454 };
455
456 let xff_header_halue = match HeaderValue::from_str(xff_value.as_str()) {
457 Ok(xff_header_halue) => Some(xff_header_halue),
458 Err(_) => match HeaderValue::from_str(forwarded_addr.as_str()) {
459 Ok(xff_header_halue) => Some(xff_header_halue),
460 Err(e) => {
461 tracing::error!(error = ?e, "invalid x-forwarded-for header value");
462 None
463 }
464 },
465 };
466
467 if let Some(xff) = xff_header_halue
468 && let Some(headers) = build.headers_mut()
469 {
470 headers.insert(&xff_header_name, xff);
471 }
472 }
473 }
474
475 build.body(req.take_body()).map_err(Error::other)
476 }
477}
478
479#[async_trait]
480impl<U, C> Handler for Proxy<U, C>
481where
482 U: Upstreams,
483 U::Error: Into<BoxedError>,
484 C: Client,
485{
486 async fn handle(
487 &self,
488 req: &mut Request,
489 depot: &mut Depot,
490 res: &mut Response,
491 _ctrl: &mut FlowCtrl,
492 ) {
493 match self.build_proxied_request(req, depot).await {
494 Ok(proxied_request) => {
495 match self
496 .client
497 .execute(proxied_request, req.extensions_mut().remove())
498 .await
499 {
500 Ok(response) => {
501 let (
502 salvo_core::http::response::Parts {
503 status,
504 headers,
506 ..
508 },
509 body,
510 ) = response.into_parts();
511 res.status_code(status);
512 for name in headers.keys() {
513 for value in headers.get_all(name) {
514 res.headers.append(name, value.to_owned());
515 }
516 }
517 res.body(body);
518 }
519 Err(e) => {
520 tracing::error!( error = ?e, uri = ?req.uri(), "get response data failed: {}", e);
521 res.status_code(StatusCode::INTERNAL_SERVER_ERROR);
522 }
523 }
524 }
525 Err(e) => {
526 tracing::error!(error = ?e, "build proxied request failed");
527 }
528 }
529 }
530}
531#[inline]
532#[allow(dead_code)]
533fn get_upgrade_type(headers: &HeaderMap) -> Option<&str> {
534 if headers
535 .get(&CONNECTION)
536 .map(|value| {
537 value
538 .to_str()
539 .unwrap_or_default()
540 .split(',')
541 .any(|e| e.trim() == UPGRADE)
542 })
543 .unwrap_or(false)
544 && let Some(upgrade_value) = headers.get(&UPGRADE)
545 {
546 tracing::debug!(
547 "found upgrade header with value: {:?}",
548 upgrade_value.to_str()
549 );
550 return upgrade_value.to_str().ok();
551 }
552
553 None
554}
555
/// Unit tests for path encoding, upgrade detection, host header getters and
/// `x-forwarded-for` handling.
#[cfg(test)]
mod tests {
    use super::*;

    use std::net::{Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddrV6};
    use std::str::FromStr;

    #[test]
    fn test_encode_url_path() {
        let path = "/test/path";
        let encoded_path = encode_url_path(path);
        assert_eq!(encoded_path, "/test/path");
    }

    #[test]
    fn test_get_upgrade_type() {
        let mut headers = HeaderMap::new();
        headers.insert(CONNECTION, HeaderValue::from_static("upgrade"));
        headers.insert(UPGRADE, HeaderValue::from_static("websocket"));
        let upgrade_type = get_upgrade_type(&headers);
        assert_eq!(upgrade_type, Some("websocket"));
    }

    #[test]
    fn test_host_header_handling() {
        let _ = rustls::crypto::aws_lc_rs::default_provider().install_default();
        let uri = Uri::from_str("http://host.tld/test").unwrap();
        let mut req = Request::new();
        let depot = Depot::new();

        assert_eq!(
            default_host_header_getter(&uri, &req, &depot),
            Some("host.tld".to_string())
        );

        // A non-default port is kept.
        let uri_with_port = Uri::from_str("http://host.tld:8080/test").unwrap();
        assert_eq!(
            rfc2616_host_header_getter(&uri_with_port, &req, &depot),
            Some("host.tld:8080".to_string())
        );

        // The default port of each scheme is dropped.
        let uri_with_http_port = Uri::from_str("http://host.tld:80/test").unwrap();
        assert_eq!(
            rfc2616_host_header_getter(&uri_with_http_port, &req, &depot),
            Some("host.tld".to_string())
        );

        let uri_with_https_port = Uri::from_str("https://host.tld:443/test").unwrap();
        assert_eq!(
            rfc2616_host_header_getter(&uri_with_https_port, &req, &depot),
            Some("host.tld".to_string())
        );

        // 443 is not the default for http, so it is kept.
        let uri_with_non_https_scheme_and_https_port =
            Uri::from_str("http://host.tld:443/test").unwrap();
        assert_eq!(
            rfc2616_host_header_getter(&uri_with_non_https_scheme_and_https_port, &req, &depot),
            Some("host.tld:443".to_string())
        );

        req.headers_mut()
            .insert(HOST, HeaderValue::from_static("test.host.tld"));
        assert_eq!(
            preserve_original_host_header_getter(&uri, &req, &depot),
            Some("test.host.tld".to_string())
        );
    }

    #[tokio::test]
    async fn test_client_ip_forwarding() {
        let xff_header_name = HeaderName::from_static(X_FORWARDER_FOR_HEADER_NAME);

        let mut request = Request::new();
        let depot = Depot::new();

        let proxy_without_forwarding =
            Proxy::new(vec!["http://example.com"], HyperClient::default());
        assert!(!proxy_without_forwarding.client_ip_forwarding_enabled);

        let proxy_with_forwarding = proxy_without_forwarding.client_ip_forwarding(true);
        assert!(proxy_with_forwarding.client_ip_forwarding_enabled);

        let proxy =
            Proxy::with_client_ip_forwarding(vec!["http://example.com"], HyperClient::default());
        assert!(proxy.client_ip_forwarding_enabled);

        // Default remote address: the fixed IPv4 test address is used.
        let proxied = proxy
            .build_proxied_request(&mut request, &depot)
            .await
            .expect("building the proxied request should succeed");
        assert_eq!(
            proxied.headers().get(&xff_header_name),
            Some(&HeaderValue::from_static("101.102.103.104"))
        );

        // IPv6 remote address: the fixed IPv6 test address is used.
        *request.remote_addr_mut() =
            SocketAddr::from(SocketAddrV6::new(Ipv6Addr::LOCALHOST, 12345, 0, 0));
        let proxied = proxy
            .build_proxied_request(&mut request, &depot)
            .await
            .expect("building the proxied request should succeed");
        assert_eq!(
            proxied.headers().get(&xff_header_name),
            Some(&HeaderValue::from_static("1:2:3:4:5:6:7:8"))
        );

        // Unknown remote address falls back to the IPv4 test address.
        *request.remote_addr_mut() = SocketAddr::Unknown;
        let proxied = proxy
            .build_proxied_request(&mut request, &depot)
            .await
            .expect("building the proxied request should succeed");
        assert_eq!(
            proxied.headers().get(&xff_header_name),
            Some(&HeaderValue::from_static("101.102.103.104"))
        );

        // An existing header value is kept, with the new address placed in front of it.
        request.headers_mut().insert(
            &xff_header_name,
            HeaderValue::from_static("10.72.0.1, 127.0.0.1"),
        );
        *request.remote_addr_mut() =
            SocketAddr::from(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 12345));
        let proxied = proxy
            .build_proxied_request(&mut request, &depot)
            .await
            .expect("building the proxied request should succeed");
        assert_eq!(
            proxied.headers().get(&xff_header_name),
            Some(&HeaderValue::from_static(
                "101.102.103.104, 10.72.0.1, 127.0.0.1"
            ))
        );
    }
}