Skip to main content

salvo_proxy/
lib.rs

1//! Provide HTTP proxy capabilities for the Salvo web framework.
2//!
3//! This crate allows you to easily forward requests to upstream servers,
4//! supporting both HTTP and HTTPS protocols. It's useful for creating API gateways,
5//! load balancers, and reverse proxies.
6//!
7//! # Example
8//!
9//! In this example, requests to different hosts are proxied to different upstream servers:
10//! - Requests to <http://127.0.0.1:8698/> are proxied to <https://www.rust-lang.org>
11//! - Requests to <http://localhost:8698/> are proxied to <https://crates.io>
12//!
13//! ```no_run
14//! use salvo_core::prelude::*;
15//! use salvo_proxy::Proxy;
16//!
17//! #[tokio::main]
18//! async fn main() {
19//!     let router = Router::new()
20//!         .push(
21//!             Router::new()
22//!                 .host("127.0.0.1")
23//!                 .path("{**rest}")
24//!                 .goal(Proxy::use_hyper_client("https://www.rust-lang.org")),
25//!         )
26//!         .push(
27//!             Router::new()
28//!                 .host("localhost")
29//!                 .path("{**rest}")
30//!                 .goal(Proxy::use_hyper_client("https://crates.io")),
31//!         );
32//!
33//!     let acceptor = TcpListener::new("0.0.0.0:8698").bind().await;
34//!     Server::new(acceptor).serve(router).await;
35//! }
36//! ```
37#![doc(html_favicon_url = "https://salvo.rs/favicon-32x32.png")]
38#![doc(html_logo_url = "https://salvo.rs/images/logo.svg")]
39#![cfg_attr(docsrs, feature(doc_cfg))]
40
41use std::convert::Infallible;
42use std::error::Error as StdError;
43use std::fmt::{self, Debug, Formatter};
44#[cfg(test)]
45use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
46
47use hyper::upgrade::OnUpgrade;
48#[cfg(not(test))]
49use local_ip_address::{local_ip, local_ipv6};
50use percent_encoding::{AsciiSet, CONTROLS, utf8_percent_encode};
51use salvo_core::conn::SocketAddr;
52use salvo_core::http::header::{CONNECTION, HOST, HeaderMap, HeaderName, HeaderValue, UPGRADE};
53use salvo_core::http::uri::Uri;
54use salvo_core::http::{ReqBody, ResBody, StatusCode};
55use salvo_core::{BoxedError, Depot, Error, FlowCtrl, Handler, Request, Response, async_trait};
56
57#[macro_use]
58mod cfg;
59
60cfg_feature! {
61    #![feature = "hyper-client"]
62    mod hyper_client;
63    pub use hyper_client::*;
64}
65cfg_feature! {
66    #![feature = "reqwest-client"]
67    mod reqwest_client;
68    pub use reqwest_client::*;
69}
70
71cfg_feature! {
72    #![feature = "unix-sock-client"]
73    #[cfg(unix)]
74    mod unix_sock_client;
75    #[cfg(unix)]
76    pub use unix_sock_client::*;
77}
78
79type HyperRequest = hyper::Request<ReqBody>;
80type HyperResponse = hyper::Response<ResBody>;
81
82const X_FORWARDER_FOR_HEADER_NAME: &str = "x-forwarded-for";
83
84const QUERY_ENCODE_SET: &AsciiSet = &CONTROLS
85    .add(b' ')
86    .add(b'"')
87    .add(b'#')
88    .add(b'<')
89    .add(b'>')
90    .add(b'`');
91const PATH_ENCODE_SET: &AsciiSet = &QUERY_ENCODE_SET
92    .add(b'?')
93    .add(b'^')
94    .add(b'`')
95    .add(b'{')
96    .add(b'}');
97
98/// Encode url path. This can be used when build your custom url path getter.
99#[inline]
100pub(crate) fn encode_url_path(path: &str) -> String {
101    path.split('/')
102        .map(|s| utf8_percent_encode(s, PATH_ENCODE_SET).to_string())
103        .collect::<Vec<_>>()
104        .join("/")
105}
106
/// Client trait for implementing different HTTP clients for proxying.
///
/// Implement this trait to create custom proxy clients with different
/// backends or configurations.
pub trait Client: Send + Sync + 'static {
    /// Error type returned by the client.
    type Error: StdError + Send + Sync + 'static;

    /// Execute a request through the proxy client.
    ///
    /// * `req` - the already-built request for the upstream (forward URL, copied headers,
    ///   computed `Host` header and the incoming body).
    /// * `upgraded` - the [`OnUpgrade`] that the `Proxy` handler removes from the incoming
    ///   request's extensions, if one was present; presumably used by implementations that
    ///   support protocol upgrades (e.g. WebSocket) — confirm per implementation.
    ///
    /// On success, the `Proxy` handler copies the returned response's status, headers and
    /// body into the Salvo response; on error it logs and answers `500 Internal Server Error`.
    fn execute(
        &self,
        req: HyperRequest,
        upgraded: Option<OnUpgrade>,
    ) -> impl Future<Output = Result<HyperResponse, Self::Error>> + Send;
}
122
/// Upstreams trait for selecting target servers.
///
/// Implement this trait to customize how target servers are selected
/// for proxying requests. This can be used to implement load balancing,
/// failover, or other server selection strategies.
pub trait Upstreams: Send + Sync + 'static {
    /// Error type returned when selecting a server fails.
    type Error: StdError + Send + Sync + 'static;

    /// Elect a server to handle the current request.
    ///
    /// The returned string is used as the base of the forward URL; the encoded path and
    /// query are appended to it. `Proxy` treats an empty string as an error.
    fn elect(
        &self,
        req: &Request,
        depot: &Depot,
    ) -> impl Future<Output = Result<&str, Self::Error>> + Send;
}
139impl Upstreams for &'static str {
140    type Error = Infallible;
141
142    async fn elect(&self, _: &Request, _: &Depot) -> Result<&str, Self::Error> {
143        Ok(*self)
144    }
145}
146impl Upstreams for String {
147    type Error = Infallible;
148    async fn elect(&self, _: &Request, _: &Depot) -> Result<&str, Self::Error> {
149        Ok(self.as_str())
150    }
151}
152
153impl<const N: usize> Upstreams for [&'static str; N] {
154    type Error = Error;
155    async fn elect(&self, _: &Request, _: &Depot) -> Result<&str, Self::Error> {
156        if self.is_empty() {
157            return Err(Error::other("upstreams is empty"));
158        }
159        let index = fastrand::usize(..self.len());
160        Ok(self[index])
161    }
162}
163
164impl<T> Upstreams for Vec<T>
165where
166    T: AsRef<str> + Send + Sync + 'static,
167{
168    type Error = Error;
169    async fn elect(&self, _: &Request, _: &Depot) -> Result<&str, Self::Error> {
170        if self.is_empty() {
171            return Err(Error::other("upstreams is empty"));
172        }
173        let index = fastrand::usize(..self.len());
174        Ok(self[index].as_ref())
175    }
176}
177
/// Url part getter. You can use this to get the proxied url path or query.
///
/// Called with the incoming request and depot. For the path getter, `None` means an empty
/// path; for the query getter, `None` means no `?query` is appended to the forward URL.
pub type UrlPartGetter = Box<dyn Fn(&Request, &Depot) -> Option<String> + Send + Sync + 'static>;

/// Host header getter. You can use this to get the host header for the proxied request.
///
/// Called with the already-built forward URI, the incoming request and depot. `Some(value)`
/// becomes the proxied request's `Host` header (logged and skipped if it is not a valid
/// header value). The incoming `Host` header is never copied, so with `None` the proxy
/// sets no `Host` header itself.
pub type HostHeaderGetter =
    Box<dyn Fn(&Uri, &Request, &Depot) -> Option<String> + Send + Sync + 'static>;
184
185/// Default url path getter.
186///
187/// This getter will get the last param as the rest url path from request.
188/// In most case you should use wildcard param, like `{**rest}`, `{*+rest}`.
189pub fn default_url_path_getter(req: &Request, _depot: &Depot) -> Option<String> {
190    req.params().tail().map(encode_url_path)
191}
192/// Default url query getter. This getter just return the query string from request uri.
193pub fn default_url_query_getter(req: &Request, _depot: &Depot) -> Option<String> {
194    req.uri().query().map(Into::into)
195}
196
197/// Default host header getter. This getter will get the host header from request uri
198pub fn default_host_header_getter(
199    forward_uri: &Uri,
200    _req: &Request,
201    _depot: &Depot,
202) -> Option<String> {
203    if let Some(host) = forward_uri.host() {
204        return Some(String::from(host));
205    }
206
207    None
208}
209
210/// RFC2616 complieant host header getter. This getter will get the host header from request uri,
211/// and add port if it's not default port. Falls back to default upon any forward URI parse error.
212pub fn rfc2616_host_header_getter(
213    forward_uri: &Uri,
214    req: &Request,
215    _depot: &Depot,
216) -> Option<String> {
217    let mut parts: Vec<String> = Vec::with_capacity(2);
218
219    if let Some(host) = forward_uri.host() {
220        parts.push(host.to_owned());
221
222        if let Some(scheme) = forward_uri.scheme_str()
223            && let Some(port) = forward_uri.port_u16()
224            && (scheme == "http" && port != 80 || scheme == "https" && port != 443)
225        {
226            parts.push(port.to_string());
227        }
228    }
229
230    if parts.is_empty() {
231        default_host_header_getter(forward_uri, req, _depot)
232    } else {
233        Some(parts.join(":"))
234    }
235}
236
237/// Preserve original host header getter. Propagates the original request host header to the proxied
238/// request.
239pub fn preserve_original_host_header_getter(
240    forward_uri: &Uri,
241    req: &Request,
242    _depot: &Depot,
243) -> Option<String> {
244    if let Some(host_header) = req.headers().get(HOST)
245        && let Ok(host) = String::from_utf8(host_header.as_bytes().to_vec())
246    {
247        return Some(host);
248    }
249
250    default_host_header_getter(forward_uri, req, _depot)
251}
252
/// Handler that can proxy request to other server.
///
/// For each request, an upstream is elected from `upstreams`, and the forward URL is built
/// from it plus the url path/query getters. The request is then executed through `client`,
/// and the upstream response's status, headers and body are copied into the Salvo response.
///
/// Marked `#[non_exhaustive]`, so outside this crate it must be built with [`Proxy::new`] or
/// [`Proxy::with_client_ip_forwarding`] rather than a struct literal.
#[non_exhaustive]
pub struct Proxy<U, C>
where
    U: Upstreams,
    C: Client,
{
    /// Upstreams list.
    pub upstreams: U,
    /// [`Client`] for proxy.
    pub client: C,
    /// Url path getter. Its output is percent-encoded and appended to the elected upstream.
    pub url_path_getter: UrlPartGetter,
    /// Url query getter. A leading `?` in its output is optional; it is stripped before encoding.
    pub url_query_getter: UrlPartGetter,
    /// Host header getter. The incoming `Host` header is never copied; this decides the new one.
    pub host_header_getter: HostHeaderGetter,
    /// Flag to enable x-forwarded-for header.
    pub client_ip_forwarding_enabled: bool,
}
273
274impl<U, C> Debug for Proxy<U, C>
275where
276    U: Upstreams,
277    C: Client,
278{
279    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
280        f.debug_struct("Proxy").finish()
281    }
282}
283
284impl<U, C> Proxy<U, C>
285where
286    U: Upstreams,
287    U::Error: Into<BoxedError>,
288    C: Client,
289{
290    /// Create new `Proxy` with upstreams list.
291    #[must_use]
292    pub fn new(upstreams: U, client: C) -> Self {
293        Self {
294            upstreams,
295            client,
296            url_path_getter: Box::new(default_url_path_getter),
297            url_query_getter: Box::new(default_url_query_getter),
298            host_header_getter: Box::new(default_host_header_getter),
299            client_ip_forwarding_enabled: false,
300        }
301    }
302
303    /// Create new `Proxy` with upstreams list and enable x-forwarded-for header.
304    pub fn with_client_ip_forwarding(upstreams: U, client: C) -> Self {
305        Self {
306            upstreams,
307            client,
308            url_path_getter: Box::new(default_url_path_getter),
309            url_query_getter: Box::new(default_url_query_getter),
310            host_header_getter: Box::new(default_host_header_getter),
311            client_ip_forwarding_enabled: true,
312        }
313    }
314
315    /// Set url path getter.
316    #[inline]
317    #[must_use]
318    pub fn url_path_getter<G>(mut self, url_path_getter: G) -> Self
319    where
320        G: Fn(&Request, &Depot) -> Option<String> + Send + Sync + 'static,
321    {
322        self.url_path_getter = Box::new(url_path_getter);
323        self
324    }
325
326    /// Set url query getter.
327    #[inline]
328    #[must_use]
329    pub fn url_query_getter<G>(mut self, url_query_getter: G) -> Self
330    where
331        G: Fn(&Request, &Depot) -> Option<String> + Send + Sync + 'static,
332    {
333        self.url_query_getter = Box::new(url_query_getter);
334        self
335    }
336
337    /// Set host header query getter.
338    #[inline]
339    #[must_use]
340    pub fn host_header_getter<G>(mut self, host_header_getter: G) -> Self
341    where
342        G: Fn(&Uri, &Request, &Depot) -> Option<String> + Send + Sync + 'static,
343    {
344        self.host_header_getter = Box::new(host_header_getter);
345        self
346    }
347
348    /// Get upstreams list.
349    #[inline]
350    pub fn upstreams(&self) -> &U {
351        &self.upstreams
352    }
353    /// Get upstreams mutable list.
354    #[inline]
355    pub fn upstreams_mut(&mut self) -> &mut U {
356        &mut self.upstreams
357    }
358
359    /// Get client reference.
360    #[inline]
361    pub fn client(&self) -> &C {
362        &self.client
363    }
364    /// Get client mutable reference.
365    #[inline]
366    pub fn client_mut(&mut self) -> &mut C {
367        &mut self.client
368    }
369
370    /// Enable x-forwarded-for header prepending.
371    #[inline]
372    #[must_use]
373    pub fn client_ip_forwarding(mut self, enable: bool) -> Self {
374        self.client_ip_forwarding_enabled = enable;
375        self
376    }
377
378    async fn build_proxied_request(
379        &self,
380        req: &mut Request,
381        depot: &Depot,
382    ) -> Result<HyperRequest, Error> {
383        let upstream = self
384            .upstreams
385            .elect(req, depot)
386            .await
387            .map_err(Error::other)?;
388
389        if upstream.is_empty() {
390            tracing::error!("upstreams is empty");
391            return Err(Error::other("upstreams is empty"));
392        }
393
394        let path = encode_url_path(&(self.url_path_getter)(req, depot).unwrap_or_default());
395        let query = (self.url_query_getter)(req, depot);
396        let rest = if let Some(query) = query {
397            if let Some(stripped) = query.strip_prefix('?') {
398                format!("{path}?{}", utf8_percent_encode(stripped, QUERY_ENCODE_SET))
399            } else {
400                format!("{path}?{}", utf8_percent_encode(&query, QUERY_ENCODE_SET))
401            }
402        } else {
403            path
404        };
405        let forward_url = if upstream.ends_with('/') && rest.starts_with('/') {
406            format!("{}{}", upstream.trim_end_matches('/'), rest)
407        } else if upstream.ends_with('/') || rest.starts_with('/') {
408            format!("{upstream}{rest}")
409        } else if rest.is_empty() {
410            upstream.to_owned()
411        } else {
412            format!("{upstream}/{rest}")
413        };
414        let forward_url: Uri = TryFrom::try_from(forward_url).map_err(Error::other)?;
415        let mut build = hyper::Request::builder()
416            .method(req.method())
417            .uri(&forward_url);
418        for (key, value) in req.headers() {
419            if key != HOST {
420                build = build.header(key, value);
421            }
422        }
423        if let Some(host_value) = (self.host_header_getter)(&forward_url, req, depot) {
424            match HeaderValue::from_str(&host_value) {
425                Ok(host_value) => {
426                    build = build.header(HOST, host_value);
427                }
428                Err(e) => {
429                    tracing::error!(error = ?e, "invalid host header value");
430                }
431            }
432        }
433
434        if self.client_ip_forwarding_enabled {
435            let xff_header_name = HeaderName::from_static(X_FORWARDER_FOR_HEADER_NAME);
436            let current_xff = req.headers().get(&xff_header_name);
437
438            #[cfg(test)]
439            let system_ip_addr = match req.remote_addr() {
440                SocketAddr::IPv6(_) => Some(IpAddr::from(Ipv6Addr::new(
441                    0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8,
442                ))),
443                _ => Some(IpAddr::from(Ipv4Addr::new(101, 102, 103, 104))),
444            };
445
446            #[cfg(not(test))]
447            let system_ip_addr = match req.remote_addr() {
448                SocketAddr::IPv6(_) => local_ipv6().ok(),
449                _ => local_ip().ok(),
450            };
451
452            if let Some(system_ip_addr) = system_ip_addr {
453                let forwarded_addr = system_ip_addr.to_string();
454
455                let xff_value = match current_xff {
456                    Some(current_xff) => match current_xff.to_str() {
457                        Ok(current_xff) => format!("{forwarded_addr}, {current_xff}"),
458                        _ => forwarded_addr.clone(),
459                    },
460                    None => forwarded_addr.clone(),
461                };
462
463                let xff_header_halue = match HeaderValue::from_str(xff_value.as_str()) {
464                    Ok(xff_header_halue) => Some(xff_header_halue),
465                    Err(_) => match HeaderValue::from_str(forwarded_addr.as_str()) {
466                        Ok(xff_header_halue) => Some(xff_header_halue),
467                        Err(e) => {
468                            tracing::error!(error = ?e, "invalid x-forwarded-for header value");
469                            None
470                        }
471                    },
472                };
473
474                if let Some(xff) = xff_header_halue
475                    && let Some(headers) = build.headers_mut()
476                {
477                    headers.insert(&xff_header_name, xff);
478                }
479            }
480        }
481
482        build.body(req.take_body()).map_err(Error::other)
483    }
484}
485
486#[async_trait]
487impl<U, C> Handler for Proxy<U, C>
488where
489    U: Upstreams,
490    U::Error: Into<BoxedError>,
491    C: Client,
492{
493    async fn handle(
494        &self,
495        req: &mut Request,
496        depot: &mut Depot,
497        res: &mut Response,
498        _ctrl: &mut FlowCtrl,
499    ) {
500        match self.build_proxied_request(req, depot).await {
501            Ok(proxied_request) => {
502                match self
503                    .client
504                    .execute(proxied_request, req.extensions_mut().remove())
505                    .await
506                {
507                    Ok(response) => {
508                        let (
509                            salvo_core::http::response::Parts {
510                                status,
511                                // version,
512                                headers,
513                                // extensions,
514                                ..
515                            },
516                            body,
517                        ) = response.into_parts();
518                        res.status_code(status);
519                        for name in headers.keys() {
520                            for value in headers.get_all(name) {
521                                res.headers.append(name, value.to_owned());
522                            }
523                        }
524                        res.body(body);
525                    }
526                    Err(e) => {
527                        tracing::error!( error = ?e, uri = ?req.uri(), "get response data failed: {}", e);
528                        res.status_code(StatusCode::INTERNAL_SERVER_ERROR);
529                    }
530                }
531            }
532            Err(e) => {
533                tracing::error!(error = ?e, "build proxied request failed");
534            }
535        }
536    }
537}
538#[inline]
539#[allow(dead_code)]
540fn get_upgrade_type(headers: &HeaderMap) -> Option<&str> {
541    if headers
542        .get(&CONNECTION)
543        .map(|value| {
544            value
545                .to_str()
546                .unwrap_or_default()
547                .split(',')
548                .any(|e| e.trim() == UPGRADE)
549        })
550        .unwrap_or(false)
551        && let Some(upgrade_value) = headers.get(&UPGRADE)
552    {
553        tracing::debug!(
554            "found upgrade header with value: {:?}",
555            upgrade_value.to_str()
556        );
557        return upgrade_value.to_str().ok();
558    }
559
560    None
561}
562
563// Unit tests for Proxy
#[cfg(test)]
mod tests {
    use std::net::{Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddrV6};
    use std::str::FromStr;

    use super::*;

    /// Builds the proxied request for `request` and returns its `x-forwarded-for` header.
    async fn proxied_xff<U, C>(proxy: &Proxy<U, C>, request: &mut Request) -> Option<HeaderValue>
    where
        U: Upstreams,
        U::Error: Into<BoxedError>,
        C: Client,
    {
        let depot = Depot::new();
        let proxied = proxy
            .build_proxied_request(request, &depot)
            .await
            .expect("failed to build proxied request");
        proxied.headers().get(X_FORWARDER_FOR_HEADER_NAME).cloned()
    }

    #[test]
    fn test_encode_url_path() {
        assert_eq!(encode_url_path("/test/path"), "/test/path");
        assert_eq!(encode_url_path(""), "");
        assert_eq!(encode_url_path("a b/c?d"), "a%20b/c%3Fd");
        assert_eq!(encode_url_path("/é"), "/%C3%A9");
        // Re-encoding encoded output is a no-op.
        assert_eq!(encode_url_path("a%20b"), "a%20b");
    }

    #[test]
    fn test_get_upgrade_type() {
        let mut headers = HeaderMap::new();
        headers.insert(CONNECTION, HeaderValue::from_static("upgrade"));
        headers.insert(UPGRADE, HeaderValue::from_static("websocket"));
        let upgrade_type = get_upgrade_type(&headers);
        assert_eq!(upgrade_type, Some("websocket"));
    }

    #[test]
    fn test_host_header_handling() {
        let _ = rustls::crypto::aws_lc_rs::default_provider().install_default();
        let uri = Uri::from_str("http://host.tld/test").unwrap();
        let mut req = Request::new();
        let depot = Depot::new();

        assert_eq!(
            default_host_header_getter(&uri, &req, &depot),
            Some("host.tld".to_string())
        );

        let uri_with_port = Uri::from_str("http://host.tld:8080/test").unwrap();
        assert_eq!(
            rfc2616_host_header_getter(&uri_with_port, &req, &depot),
            Some("host.tld:8080".to_string())
        );

        let uri_with_http_port = Uri::from_str("http://host.tld:80/test").unwrap();
        assert_eq!(
            rfc2616_host_header_getter(&uri_with_http_port, &req, &depot),
            Some("host.tld".to_string())
        );

        let uri_with_https_port = Uri::from_str("https://host.tld:443/test").unwrap();
        assert_eq!(
            rfc2616_host_header_getter(&uri_with_https_port, &req, &depot),
            Some("host.tld".to_string())
        );

        let uri_with_non_https_scheme_and_https_port =
            Uri::from_str("http://host.tld:443/test").unwrap();
        assert_eq!(
            rfc2616_host_header_getter(&uri_with_non_https_scheme_and_https_port, &req, &depot),
            Some("host.tld:443".to_string())
        );

        req.headers_mut()
            .insert(HOST, HeaderValue::from_static("test.host.tld"));
        assert_eq!(
            preserve_original_host_header_getter(&uri, &req, &depot),
            Some("test.host.tld".to_string())
        );
    }

    #[tokio::test]
    async fn test_client_ip_forwarding() {
        // `HyperClient::default()` presumably needs a process-wide rustls crypto provider.
        // Install it here as well so this test does not depend on test execution order; the
        // result is ignored because installing fails harmlessly when one is already present.
        let _ = rustls::crypto::aws_lc_rs::default_provider().install_default();

        let mut request = Request::new();

        // The flag defaults to off and can be toggled by the builder or the constructor.
        let proxy_without_forwarding =
            Proxy::new(vec!["http://example.com"], HyperClient::default());
        assert!(!proxy_without_forwarding.client_ip_forwarding_enabled);

        let proxy_with_forwarding = proxy_without_forwarding.client_ip_forwarding(true);
        assert!(proxy_with_forwarding.client_ip_forwarding_enabled);

        let proxy =
            Proxy::with_client_ip_forwarding(vec!["http://example.com"], HyperClient::default());
        assert!(proxy.client_ip_forwarding_enabled);

        // A fresh request's remote address is not IPv6, so the IPv4 address is used.
        assert_eq!(
            proxied_xff(&proxy, &mut request).await,
            Some(HeaderValue::from_static("101.102.103.104"))
        );

        // The IP version follows the remote address.
        *request.remote_addr_mut() =
            SocketAddr::from(SocketAddrV6::new(Ipv6Addr::LOCALHOST, 12345, 0, 0));
        assert_eq!(
            proxied_xff(&proxy, &mut request).await,
            Some(HeaderValue::from_static("1:2:3:4:5:6:7:8"))
        );

        *request.remote_addr_mut() = SocketAddr::Unknown;
        assert_eq!(
            proxied_xff(&proxy, &mut request).await,
            Some(HeaderValue::from_static("101.102.103.104"))
        );

        // The address is prepended when the incoming request already carries x-forwarded-for.
        request.headers_mut().insert(
            HeaderName::from_static(X_FORWARDER_FOR_HEADER_NAME),
            HeaderValue::from_static("10.72.0.1, 127.0.0.1"),
        );
        *request.remote_addr_mut() =
            SocketAddr::from(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 12345));
        assert_eq!(
            proxied_xff(&proxy, &mut request).await,
            Some(HeaderValue::from_static(
                "101.102.103.104, 10.72.0.1, 127.0.0.1"
            ))
        );
    }
}