salvo_proxy/
lib.rs

1//! Provide HTTP proxy capabilities for the Salvo web framework.
2//!
3//! This crate allows you to easily forward requests to upstream servers,
4//! supporting both HTTP and HTTPS protocols. It's useful for creating API gateways,
5//! load balancers, and reverse proxies.
6//!
7//! # Example
8//!
9//! In this example, requests to different hosts are proxied to different upstream servers:
//! - Requests to `http://127.0.0.1:5800/` are proxied to <https://www.rust-lang.org>
//! - Requests to `http://localhost:5800/` are proxied to <https://crates.io>
12//!
13//! ```no_run
14//! use salvo_core::prelude::*;
15//! use salvo_proxy::Proxy;
16//!
17//! #[tokio::main]
18//! async fn main() {
19//!     let router = Router::new()
20//!         .push(
21//!             Router::new()
22//!                 .host("127.0.0.1")
23//!                 .path("{**rest}")
24//!                 .goal(Proxy::use_hyper_client("https://www.rust-lang.org")),
25//!         )
26//!         .push(
27//!             Router::new()
28//!                 .host("localhost")
29//!                 .path("{**rest}")
30//!                 .goal(Proxy::use_hyper_client("https://crates.io")),
31//!         );
32//!
33//!     let acceptor = TcpListener::new("0.0.0.0:5800").bind().await;
34//!     Server::new(acceptor).serve(router).await;
35//! }
36//! ```
37#![doc(html_favicon_url = "https://salvo.rs/favicon-32x32.png")]
38#![doc(html_logo_url = "https://salvo.rs/images/logo.svg")]
39#![cfg_attr(docsrs, feature(doc_cfg))]
40
41use std::convert::Infallible;
42use std::error::Error as StdError;
43
44use hyper::upgrade::OnUpgrade;
45use percent_encoding::{CONTROLS, utf8_percent_encode};
46use salvo_core::http::header::{CONNECTION, HOST, HeaderMap, HeaderName, HeaderValue, UPGRADE};
47use salvo_core::http::uri::Uri;
48use salvo_core::http::{ReqBody, ResBody, StatusCode};
49use salvo_core::{BoxedError, Depot, Error, FlowCtrl, Handler, Request, Response, async_trait};
50
51#[macro_use]
52mod cfg;
53
54cfg_feature! {
55    #![feature = "hyper-client"]
56    mod hyper_client;
57    pub use hyper_client::*;
58}
59cfg_feature! {
60    #![feature = "reqwest-client"]
61    mod reqwest_client;
62    pub use reqwest_client::*;
63}
64
/// Request type handed to [`Client::execute`]: a hyper request whose body is Salvo's [`ReqBody`].
type HyperRequest = hyper::Request<ReqBody>;
/// Response type returned by [`Client::execute`]: a hyper response whose body is Salvo's [`ResBody`].
type HyperResponse = hyper::Response<ResBody>;
67
68/// Encode url path. This can be used when build your custom url path getter.
69#[inline]
70pub(crate) fn encode_url_path(path: &str) -> String {
71    path.split('/')
72        .map(|s| utf8_percent_encode(s, CONTROLS).to_string())
73        .collect::<Vec<_>>()
74        .join("/")
75}
76
/// Client trait for implementing different HTTP clients for proxying.
///
/// Implement this trait to create custom proxy clients with different
/// backends or configurations. Implementors must be `Send + Sync + 'static`
/// because a single client is shared by every request the `Proxy` handles.
pub trait Client: Send + Sync + 'static {
    /// Error type returned by the client.
    ///
    /// `Proxy` logs it and answers with `500 Internal Server Error`.
    type Error: StdError + Send + Sync + 'static;

    /// Execute a request through the proxy client.
    ///
    /// `req` is the request already rewritten for the elected upstream (url, `Host`
    /// header and body). `upgraded` is the incoming request's pending
    /// [`OnUpgrade`], if any; `Proxy` removes it from the request extensions before
    /// calling this. The returned future must be `Send`.
    fn execute(
        &self,
        req: HyperRequest,
        upgraded: Option<OnUpgrade>,
    ) -> impl Future<Output = Result<HyperResponse, Self::Error>> + Send;
}
92
/// Upstreams trait for selecting target servers.
///
/// Implement this trait to customize how target servers are selected
/// for proxying requests. This can be used to implement load balancing,
/// failover, or other server selection strategies.
pub trait Upstreams: Send + Sync + 'static {
    /// Error type returned when selecting a server fails.
    type Error: StdError + Send + Sync + 'static;

    /// Elect a server to handle the current request.
    ///
    /// The returned `&str` borrows from `self`. It is used as the base url to
    /// which the proxied path and query are appended; `Proxy` treats an empty
    /// string as an error.
    fn elect(&self) -> impl Future<Output = Result<&str, Self::Error>> + Send;
}
105impl Upstreams for &'static str {
106    type Error = Infallible;
107
108    async fn elect(&self) -> Result<&str, Self::Error> {
109        Ok(*self)
110    }
111}
112impl Upstreams for String {
113    type Error = Infallible;
114    async fn elect(&self) -> Result<&str, Self::Error> {
115        Ok(self.as_str())
116    }
117}
118
119impl<const N: usize> Upstreams for [&'static str; N] {
120    type Error = Error;
121    async fn elect(&self) -> Result<&str, Self::Error> {
122        if self.is_empty() {
123            return Err(Error::other("upstreams is empty"));
124        }
125        let index = fastrand::usize(..self.len());
126        Ok(self[index])
127    }
128}
129
130impl<T> Upstreams for Vec<T>
131where
132    T: AsRef<str> + Send + Sync + 'static,
133{
134    type Error = Error;
135    async fn elect(&self) -> Result<&str, Self::Error> {
136        if self.is_empty() {
137            return Err(Error::other("upstreams is empty"));
138        }
139        let index = fastrand::usize(..self.len());
140        Ok(self[index].as_ref())
141    }
142}
143
/// Url part getter. You can use this to get the proxied url path or query.
///
/// It is called with the incoming request and its depot. As the path getter,
/// `None` is treated as an empty path; as the query getter, `None` means no
/// query string is forwarded.
pub type UrlPartGetter = Box<dyn Fn(&Request, &Depot) -> Option<String> + Send + Sync + 'static>;
146
147/// Default url path getter.
148///
149/// This getter will get the last param as the rest url path from request.
150/// In most case you should use wildcard param, like `{**rest}`, `{*+rest}`.
151pub fn default_url_path_getter(req: &Request, _depot: &Depot) -> Option<String> {
152    req.params().tail().map(encode_url_path)
153}
154/// Default url query getter. This getter just return the query string from request uri.
155pub fn default_url_query_getter(req: &Request, _depot: &Depot) -> Option<String> {
156    req.uri().query().map(Into::into)
157}
158
/// Handler that can proxy request to other server.
///
/// Handling a request involves three steps:
/// 1. Elect an upstream from `upstreams`.
/// 2. Build the forwarded url from that upstream plus the output of
///    `url_path_getter` and `url_query_getter`.
/// 3. Send the request with `client`.
///
/// The struct is `#[non_exhaustive]`, so code outside this crate cannot build it
/// with a struct literal and should use `Proxy::new`.
#[non_exhaustive]
pub struct Proxy<U, C>
where
    U: Upstreams,
    C: Client,
{
    /// Upstreams list.
    pub upstreams: U,
    /// [`Client`] for proxy.
    pub client: C,
    /// Url path getter.
    ///
    /// Its output is percent-encoded and joined to the elected upstream; `None` is
    /// treated as an empty path.
    pub url_path_getter: UrlPartGetter,
    /// Url query getter.
    ///
    /// Its output (with or without a leading `?`) is appended as the query; `None`
    /// forwards no query.
    pub url_query_getter: UrlPartGetter,
}
175
176impl<U, C> Proxy<U, C>
177where
178    U: Upstreams,
179    U::Error: Into<BoxedError>,
180    C: Client,
181{
182    /// Create new `Proxy` with upstreams list.
183    pub fn new(upstreams: U, client: C) -> Self {
184        Proxy {
185            upstreams,
186            client,
187            url_path_getter: Box::new(default_url_path_getter),
188            url_query_getter: Box::new(default_url_query_getter),
189        }
190    }
191
192    /// Set url path getter.
193    #[inline]
194    pub fn url_path_getter<G>(mut self, url_path_getter: G) -> Self
195    where
196        G: Fn(&Request, &Depot) -> Option<String> + Send + Sync + 'static,
197    {
198        self.url_path_getter = Box::new(url_path_getter);
199        self
200    }
201
202    /// Set url query getter.
203    #[inline]
204    pub fn url_query_getter<G>(mut self, url_query_getter: G) -> Self
205    where
206        G: Fn(&Request, &Depot) -> Option<String> + Send + Sync + 'static,
207    {
208        self.url_query_getter = Box::new(url_query_getter);
209        self
210    }
211
212    /// Get upstreams list.
213    #[inline]
214    pub fn upstreams(&self) -> &U {
215        &self.upstreams
216    }
217    /// Get upstreams mutable list.
218    #[inline]
219    pub fn upstreams_mut(&mut self) -> &mut U {
220        &mut self.upstreams
221    }
222
223    /// Get client reference.
224    #[inline]
225    pub fn client(&self) -> &C {
226        &self.client
227    }
228    /// Get client mutable reference.
229    #[inline]
230    pub fn client_mut(&mut self) -> &mut C {
231        &mut self.client
232    }
233
234    async fn build_proxied_request(
235        &self,
236        req: &mut Request,
237        depot: &Depot,
238    ) -> Result<HyperRequest, Error> {
239        let upstream = self.upstreams.elect().await.map_err(Error::other)?;
240        if upstream.is_empty() {
241            tracing::error!("upstreams is empty");
242            return Err(Error::other("upstreams is empty"));
243        }
244
245        let path = encode_url_path(&(self.url_path_getter)(req, depot).unwrap_or_default());
246        let query = (self.url_query_getter)(req, depot);
247        let rest = if let Some(query) = query {
248            if query.starts_with('?') {
249                format!("{}{}", path, query)
250            } else {
251                format!("{}?{}", path, query)
252            }
253        } else {
254            path
255        };
256        let forward_url = if upstream.ends_with('/') && rest.starts_with('/') {
257            format!("{}{}", upstream.trim_end_matches('/'), rest)
258        } else if upstream.ends_with('/') || rest.starts_with('/') {
259            format!("{}{}", upstream, rest)
260        } else if rest.is_empty() {
261            upstream.to_string()
262        } else {
263            format!("{}/{}", upstream, rest)
264        };
265        let forward_url: Uri = TryFrom::try_from(forward_url).map_err(Error::other)?;
266        let mut build = hyper::Request::builder()
267            .method(req.method())
268            .uri(&forward_url);
269        for (key, value) in req.headers() {
270            if key != HOST {
271                build = build.header(key, value);
272            }
273        }
274        if let Some(host) = forward_url
275            .host()
276            .and_then(|host| HeaderValue::from_str(host).ok())
277        {
278            build = build.header(HeaderName::from_static("host"), host);
279        }
280        // let x_forwarded_for_header_name = "x-forwarded-for";
281        // // Add forwarding information in the headers
282        // match request.headers_mut().entry(x_forwarded_for_header_name) {
283        //     Ok(header_entry) => {
284        //         match header_entry {
285        //             hyper::header::Entry::Vacant(entry) => {
286        //                 let addr = format!("{}", client_ip);
287        //                 entry.insert(addr.parse().unwrap());
288        //             },
289        //             hyper::header::Entry::Occupied(mut entry) => {
290        //                 let addr = format!("{}, {}", entry.get().to_str().unwrap(), client_ip);
291        //                 entry.insert(addr.parse().unwrap());
292        //             }
293        //         }
294        //     }
295        //     // shouldn't happen...
296        //     Err(_) => panic!("Invalid header name: {}", x_forwarded_for_header_name),
297        // }
298        build.body(req.take_body()).map_err(Error::other)
299    }
300}
301
302#[async_trait]
303impl<U, C> Handler for Proxy<U, C>
304where
305    U: Upstreams,
306    U::Error: Into<BoxedError>,
307    C: Client,
308{
309    async fn handle(
310        &self,
311        req: &mut Request,
312        depot: &mut Depot,
313        res: &mut Response,
314        _ctrl: &mut FlowCtrl,
315    ) {
316        match self.build_proxied_request(req, depot).await {
317            Ok(proxied_request) => {
318                match self
319                    .client
320                    .execute(proxied_request, req.extensions_mut().remove())
321                    .await
322                {
323                    Ok(response) => {
324                        let (
325                            salvo_core::http::response::Parts {
326                                status,
327                                // version,
328                                headers,
329                                // extensions,
330                                ..
331                            },
332                            body,
333                        ) = response.into_parts();
334                        res.status_code(status);
335                        for name in headers.keys() {
336                            for value in headers.get_all(name) {
337                                res.headers.append(name, value.to_owned());
338                            }
339                        }
340                        res.body(body);
341                    }
342                    Err(e) => {
343                        tracing::error!( error = ?e, uri = ?req.uri(), "get response data failed: {}", e);
344                        res.status_code(StatusCode::INTERNAL_SERVER_ERROR);
345                    }
346                }
347            }
348            Err(e) => {
349                tracing::error!(error = ?e, "build proxied request failed");
350            }
351        }
352    }
353}
354#[inline]
355#[allow(dead_code)]
356fn get_upgrade_type(headers: &HeaderMap) -> Option<&str> {
357    if headers
358        .get(&CONNECTION)
359        .map(|value| {
360            value
361                .to_str()
362                .unwrap_or_default()
363                .split(',')
364                .any(|e| e.trim() == UPGRADE)
365        })
366        .unwrap_or(false)
367    {
368        if let Some(upgrade_value) = headers.get(&UPGRADE) {
369            tracing::debug!(
370                "Found upgrade header with value: {:?}",
371                upgrade_value.to_str()
372            );
373            return upgrade_value.to_str().ok();
374        }
375    }
376
377    None
378}
379
// Unit tests for the private helpers `encode_url_path` and `get_upgrade_type`.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_encode_url_path() {
        let path = "/test/path";
        let encoded_path = encode_url_path(path);
        assert_eq!(encoded_path, "/test/path");
    }

    #[test]
    fn test_encode_url_path_escapes_controls_and_non_ascii() {
        // Separators are kept; only bytes inside segments are escaped.
        assert_eq!(encode_url_path("a\tb/caf\u{e9}"), "a%09b/caf%C3%A9");
        assert_eq!(encode_url_path(""), "");
    }

    #[test]
    fn test_get_upgrade_type() {
        let mut headers = HeaderMap::new();
        headers.insert(CONNECTION, HeaderValue::from_static("upgrade"));
        headers.insert(UPGRADE, HeaderValue::from_static("websocket"));
        let upgrade_type = get_upgrade_type(&headers);
        assert_eq!(upgrade_type, Some("websocket"));
    }

    #[test]
    fn test_get_upgrade_type_token_list_is_case_insensitive() {
        let mut headers = HeaderMap::new();
        headers.insert(CONNECTION, HeaderValue::from_static("keep-alive, Upgrade"));
        headers.insert(UPGRADE, HeaderValue::from_static("websocket"));
        assert_eq!(get_upgrade_type(&headers), Some("websocket"));
    }

    #[test]
    fn test_get_upgrade_type_requires_connection_upgrade() {
        let mut headers = HeaderMap::new();
        headers.insert(UPGRADE, HeaderValue::from_static("websocket"));
        assert_eq!(get_upgrade_type(&headers), None);

        headers.insert(CONNECTION, HeaderValue::from_static("keep-alive"));
        assert_eq!(get_upgrade_type(&headers), None);
    }

    #[test]
    fn test_get_upgrade_type_without_upgrade_header() {
        let mut headers = HeaderMap::new();
        headers.insert(CONNECTION, HeaderValue::from_static("upgrade"));
        assert_eq!(get_upgrade_type(&headers), None);
    }
}