salvo_proxy/
lib.rs

1//! Provide HTTP proxy capabilities for the Salvo web framework.
2//!
3//! This crate allows you to easily forward requests to upstream servers,
4//! supporting both HTTP and HTTPS protocols. It's useful for creating API gateways,
5//! load balancers, and reverse proxies.
6//!
7//! # Example
8//!
9//! In this example, requests to different hosts are proxied to different upstream servers:
10//! - Requests to <http://127.0.0.1:8698/> are proxied to <https://www.rust-lang.org>
11//! - Requests to <http://localhost:8698/> are proxied to <https://crates.io>
12//!
13//! ```no_run
14//! use salvo_core::prelude::*;
15//! use salvo_proxy::Proxy;
16//!
17//! #[tokio::main]
18//! async fn main() {
19//!     let router = Router::new()
20//!         .push(
21//!             Router::new()
22//!                 .host("127.0.0.1")
23//!                 .path("{**rest}")
24//!                 .goal(Proxy::use_hyper_client("https://www.rust-lang.org")),
25//!         )
26//!         .push(
27//!             Router::new()
28//!                 .host("localhost")
29//!                 .path("{**rest}")
30//!                 .goal(Proxy::use_hyper_client("https://crates.io")),
31//!         );
32//!
33//!     let acceptor = TcpListener::new("0.0.0.0:8698").bind().await;
34//!     Server::new(acceptor).serve(router).await;
35//! }
36//! ```
37#![doc(html_favicon_url = "https://salvo.rs/favicon-32x32.png")]
38#![doc(html_logo_url = "https://salvo.rs/images/logo.svg")]
39#![cfg_attr(docsrs, feature(doc_cfg))]
40
41use std::convert::Infallible;
42use std::error::Error as StdError;
43use std::fmt::{self, Debug, Formatter};
44
45use hyper::upgrade::OnUpgrade;
46use percent_encoding::{CONTROLS, utf8_percent_encode};
47use salvo_core::http::header::{CONNECTION, HOST, HeaderMap, HeaderName, HeaderValue, UPGRADE};
48use salvo_core::http::uri::Uri;
49use salvo_core::http::{ReqBody, ResBody, StatusCode};
50use salvo_core::{BoxedError, Depot, Error, FlowCtrl, Handler, Request, Response, async_trait};
51
52#[macro_use]
53mod cfg;
54
55cfg_feature! {
56    #![feature = "hyper-client"]
57    mod hyper_client;
58    pub use hyper_client::*;
59}
60cfg_feature! {
61    #![feature = "reqwest-client"]
62    mod reqwest_client;
63    pub use reqwest_client::*;
64}
65
66cfg_feature! {
67    #![feature = "unix-sock-client"]
68    #[cfg(unix)]
69    mod unix_sock_client;
70    #[cfg(unix)]
71    pub use unix_sock_client::*;
72}
73
74type HyperRequest = hyper::Request<ReqBody>;
75type HyperResponse = hyper::Response<ResBody>;
76
77/// Encode url path. This can be used when build your custom url path getter.
78#[inline]
79pub(crate) fn encode_url_path(path: &str) -> String {
80    path.split('/')
81        .map(|s| utf8_percent_encode(s, CONTROLS).to_string())
82        .collect::<Vec<_>>()
83        .join("/")
84}
85
86/// Client trait for implementing different HTTP clients for proxying.
87///
88/// Implement this trait to create custom proxy clients with different
89/// backends or configurations.
90pub trait Client: Send + Sync + 'static {
91    /// Error type returned by the client.
92    type Error: StdError + Send + Sync + 'static;
93
94    /// Execute a request through the proxy client.
95    fn execute(
96        &self,
97        req: HyperRequest,
98        upgraded: Option<OnUpgrade>,
99    ) -> impl Future<Output = Result<HyperResponse, Self::Error>> + Send;
100}
101
102/// Upstreams trait for selecting target servers.
103///
104/// Implement this trait to customize how target servers are selected
105/// for proxying requests. This can be used to implement load balancing,
106/// failover, or other server selection strategies.
107pub trait Upstreams: Send + Sync + 'static {
108    /// Error type returned when selecting a server fails.
109    type Error: StdError + Send + Sync + 'static;
110
111    /// Elect a server to handle the current request.
112    fn elect(&self) -> impl Future<Output = Result<&str, Self::Error>> + Send;
113}
114impl Upstreams for &'static str {
115    type Error = Infallible;
116
117    async fn elect(&self) -> Result<&str, Self::Error> {
118        Ok(*self)
119    }
120}
121impl Upstreams for String {
122    type Error = Infallible;
123    async fn elect(&self) -> Result<&str, Self::Error> {
124        Ok(self.as_str())
125    }
126}
127
128impl<const N: usize> Upstreams for [&'static str; N] {
129    type Error = Error;
130    async fn elect(&self) -> Result<&str, Self::Error> {
131        if self.is_empty() {
132            return Err(Error::other("upstreams is empty"));
133        }
134        let index = fastrand::usize(..self.len());
135        Ok(self[index])
136    }
137}
138
139impl<T> Upstreams for Vec<T>
140where
141    T: AsRef<str> + Send + Sync + 'static,
142{
143    type Error = Error;
144    async fn elect(&self) -> Result<&str, Self::Error> {
145        if self.is_empty() {
146            return Err(Error::other("upstreams is empty"));
147        }
148        let index = fastrand::usize(..self.len());
149        Ok(self[index].as_ref())
150    }
151}
152
153/// Url part getter. You can use this to get the proxied url path or query.
154pub type UrlPartGetter = Box<dyn Fn(&Request, &Depot) -> Option<String> + Send + Sync + 'static>;
155
156/// Default url path getter.
157///
158/// This getter will get the last param as the rest url path from request.
159/// In most case you should use wildcard param, like `{**rest}`, `{*+rest}`.
160pub fn default_url_path_getter(req: &Request, _depot: &Depot) -> Option<String> {
161    req.params().tail().map(encode_url_path)
162}
163/// Default url query getter. This getter just return the query string from request uri.
164pub fn default_url_query_getter(req: &Request, _depot: &Depot) -> Option<String> {
165    req.uri().query().map(Into::into)
166}
167
168/// Handler that can proxy request to other server.
169#[non_exhaustive]
170pub struct Proxy<U, C>
171where
172    U: Upstreams,
173    C: Client,
174{
175    /// Upstreams list.
176    pub upstreams: U,
177    /// [`Client`] for proxy.
178    pub client: C,
179    /// Url path getter.
180    pub url_path_getter: UrlPartGetter,
181    /// Url query getter.
182    pub url_query_getter: UrlPartGetter,
183}
184
185impl<U, C> Debug for Proxy<U, C>
186where
187    U: Upstreams,
188    C: Client,
189{
190    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
191        f.debug_struct("Proxy").finish()
192    }
193}
194
195impl<U, C> Proxy<U, C>
196where
197    U: Upstreams,
198    U::Error: Into<BoxedError>,
199    C: Client,
200{
201    /// Create new `Proxy` with upstreams list.
202    #[must_use]
203    pub fn new(upstreams: U, client: C) -> Self {
204        Self {
205            upstreams,
206            client,
207            url_path_getter: Box::new(default_url_path_getter),
208            url_query_getter: Box::new(default_url_query_getter),
209        }
210    }
211
212    /// Set url path getter.
213    #[inline]
214    #[must_use]
215    pub fn url_path_getter<G>(mut self, url_path_getter: G) -> Self
216    where
217        G: Fn(&Request, &Depot) -> Option<String> + Send + Sync + 'static,
218    {
219        self.url_path_getter = Box::new(url_path_getter);
220        self
221    }
222
223    /// Set url query getter.
224    #[inline]
225    #[must_use]
226    pub fn url_query_getter<G>(mut self, url_query_getter: G) -> Self
227    where
228        G: Fn(&Request, &Depot) -> Option<String> + Send + Sync + 'static,
229    {
230        self.url_query_getter = Box::new(url_query_getter);
231        self
232    }
233
234    /// Get upstreams list.
235    #[inline]
236    pub fn upstreams(&self) -> &U {
237        &self.upstreams
238    }
239    /// Get upstreams mutable list.
240    #[inline]
241    pub fn upstreams_mut(&mut self) -> &mut U {
242        &mut self.upstreams
243    }
244
245    /// Get client reference.
246    #[inline]
247    pub fn client(&self) -> &C {
248        &self.client
249    }
250    /// Get client mutable reference.
251    #[inline]
252    pub fn client_mut(&mut self) -> &mut C {
253        &mut self.client
254    }
255
256    async fn build_proxied_request(
257        &self,
258        req: &mut Request,
259        depot: &Depot,
260    ) -> Result<HyperRequest, Error> {
261        let upstream = self.upstreams.elect().await.map_err(Error::other)?;
262        if upstream.is_empty() {
263            tracing::error!("upstreams is empty");
264            return Err(Error::other("upstreams is empty"));
265        }
266
267        let path = encode_url_path(&(self.url_path_getter)(req, depot).unwrap_or_default());
268        let query = (self.url_query_getter)(req, depot);
269        let rest = if let Some(query) = query {
270            if query.starts_with('?') {
271                format!("{path}{query}")
272            } else {
273                format!("{path}?{query}")
274            }
275        } else {
276            path
277        };
278        let forward_url = if upstream.ends_with('/') && rest.starts_with('/') {
279            format!("{}{}", upstream.trim_end_matches('/'), rest)
280        } else if upstream.ends_with('/') || rest.starts_with('/') {
281            format!("{upstream}{rest}")
282        } else if rest.is_empty() {
283            upstream.to_owned()
284        } else {
285            format!("{upstream}/{rest}")
286        };
287        let forward_url: Uri = TryFrom::try_from(forward_url).map_err(Error::other)?;
288        let mut build = hyper::Request::builder()
289            .method(req.method())
290            .uri(&forward_url);
291        for (key, value) in req.headers() {
292            if key != HOST {
293                build = build.header(key, value);
294            }
295        }
296        if let Some(host) = forward_url
297            .host()
298            .and_then(|host| HeaderValue::from_str(host).ok())
299        {
300            build = build.header(HeaderName::from_static("host"), host);
301        }
302        // let x_forwarded_for_header_name = "x-forwarded-for";
303        // // Add forwarding information in the headers
304        // match request.headers_mut().entry(x_forwarded_for_header_name) {
305        //     Ok(header_entry) => {
306        //         match header_entry {
307        //             hyper::header::Entry::Vacant(entry) => {
308        //                 let addr = format!("{}", client_ip);
309        //                 entry.insert(addr.parse().unwrap());
310        //             },
311        //             hyper::header::Entry::Occupied(mut entry) => {
312        //                 let addr = format!("{}, {}", entry.get().to_str().unwrap(), client_ip);
313        //                 entry.insert(addr.parse().unwrap());
314        //             }
315        //         }
316        //     }
317        //     // shouldn't happen...
318        //     Err(_) => panic!("Invalid header name: {}", x_forwarded_for_header_name),
319        // }
320        build.body(req.take_body()).map_err(Error::other)
321    }
322}
323
324#[async_trait]
325impl<U, C> Handler for Proxy<U, C>
326where
327    U: Upstreams,
328    U::Error: Into<BoxedError>,
329    C: Client,
330{
331    async fn handle(
332        &self,
333        req: &mut Request,
334        depot: &mut Depot,
335        res: &mut Response,
336        _ctrl: &mut FlowCtrl,
337    ) {
338        match self.build_proxied_request(req, depot).await {
339            Ok(proxied_request) => {
340                match self
341                    .client
342                    .execute(proxied_request, req.extensions_mut().remove())
343                    .await
344                {
345                    Ok(response) => {
346                        let (
347                            salvo_core::http::response::Parts {
348                                status,
349                                // version,
350                                headers,
351                                // extensions,
352                                ..
353                            },
354                            body,
355                        ) = response.into_parts();
356                        res.status_code(status);
357                        for name in headers.keys() {
358                            for value in headers.get_all(name) {
359                                res.headers.append(name, value.to_owned());
360                            }
361                        }
362                        res.body(body);
363                    }
364                    Err(e) => {
365                        tracing::error!( error = ?e, uri = ?req.uri(), "get response data failed: {}", e);
366                        res.status_code(StatusCode::INTERNAL_SERVER_ERROR);
367                    }
368                }
369            }
370            Err(e) => {
371                tracing::error!(error = ?e, "build proxied request failed");
372            }
373        }
374    }
375}
376#[inline]
377#[allow(dead_code)]
378fn get_upgrade_type(headers: &HeaderMap) -> Option<&str> {
379    if headers
380        .get(&CONNECTION)
381        .map(|value| {
382            value
383                .to_str()
384                .unwrap_or_default()
385                .split(',')
386                .any(|e| e.trim() == UPGRADE)
387        })
388        .unwrap_or(false)
389    {
390        if let Some(upgrade_value) = headers.get(&UPGRADE) {
391            tracing::debug!(
392                "Found upgrade header with value: {:?}",
393                upgrade_value.to_str()
394            );
395            return upgrade_value.to_str().ok();
396        }
397    }
398
399    None
400}
401
// Unit tests for the crate-private helpers `encode_url_path` and `get_upgrade_type`.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_encode_url_path() {
        let path = "/test/path";
        let encoded_path = encode_url_path(path);
        assert_eq!(encoded_path, "/test/path");
    }

    #[test]
    fn test_encode_url_path_edge_cases() {
        // Empty input stays empty.
        assert_eq!(encode_url_path(""), "");
        // Separators are preserved; non-ASCII is encoded as its UTF-8 bytes.
        assert_eq!(encode_url_path("/caf\u{e9}/x"), "/caf%C3%A9/x");
        // Control characters are encoded.
        assert_eq!(encode_url_path("/a\tb"), "/a%09b");
        // The proxy may encode a path twice, so encoding must be idempotent.
        let once = encode_url_path("/caf\u{e9}/a\nb");
        assert_eq!(encode_url_path(&once), once);
    }

    #[test]
    fn test_get_upgrade_type() {
        let mut headers = HeaderMap::new();
        headers.insert(CONNECTION, HeaderValue::from_static("upgrade"));
        headers.insert(UPGRADE, HeaderValue::from_static("websocket"));
        let upgrade_type = get_upgrade_type(&headers);
        assert_eq!(upgrade_type, Some("websocket"));
    }

    #[test]
    fn test_get_upgrade_type_token_list() {
        // The token may appear in a comma-separated list and in any letter case.
        let mut headers = HeaderMap::new();
        headers.insert(CONNECTION, HeaderValue::from_static("keep-alive, Upgrade"));
        headers.insert(UPGRADE, HeaderValue::from_static("websocket"));
        assert_eq!(get_upgrade_type(&headers), Some("websocket"));
    }

    #[test]
    fn test_get_upgrade_type_absent() {
        // No `Connection` header at all.
        let mut headers = HeaderMap::new();
        headers.insert(UPGRADE, HeaderValue::from_static("websocket"));
        assert_eq!(get_upgrade_type(&headers), None);

        // `Connection` present but without the upgrade token.
        headers.insert(CONNECTION, HeaderValue::from_static("keep-alive"));
        assert_eq!(get_upgrade_type(&headers), None);

        // Upgrade requested, but no `Upgrade` header names a protocol.
        let mut headers = HeaderMap::new();
        headers.insert(CONNECTION, HeaderValue::from_static("upgrade"));
        assert_eq!(get_upgrade_type(&headers), None);
    }
}