1#![doc(html_favicon_url = "https://salvo.rs/favicon-32x32.png")]
38#![doc(html_logo_url = "https://salvo.rs/images/logo.svg")]
39#![cfg_attr(docsrs, feature(doc_cfg))]
40
41use std::convert::Infallible;
42use std::error::Error as StdError;
43
44use hyper::upgrade::OnUpgrade;
45use percent_encoding::{CONTROLS, utf8_percent_encode};
46use salvo_core::http::header::{CONNECTION, HOST, HeaderMap, HeaderName, HeaderValue, UPGRADE};
47use salvo_core::http::uri::Uri;
48use salvo_core::http::{ReqBody, ResBody, StatusCode};
49use salvo_core::{BoxedError, Depot, Error, FlowCtrl, Handler, Request, Response, async_trait};
50
// `#[macro_use]` makes the macros defined in `cfg` (including `cfg_feature!`,
// used right below) available to the rest of this crate.
#[macro_use]
mod cfg;

// hyper-based `Client` implementation, gated via `cfg_feature!` on the
// `hyper-client` cargo feature; its public items are re-exported here.
cfg_feature! {
    #![feature = "hyper-client"]
    mod hyper_client;
    pub use hyper_client::*;
}
// reqwest-based `Client` implementation, gated via `cfg_feature!` on the
// `reqwest-client` cargo feature; its public items are re-exported here.
cfg_feature! {
    #![feature = "reqwest-client"]
    mod reqwest_client;
    pub use reqwest_client::*;
}
64
/// Request type handed to [`Client::execute`]: a hyper request carrying Salvo's request body.
type HyperRequest = hyper::Request<ReqBody>;
/// Response type produced by [`Client::execute`]: a hyper response carrying Salvo's response body.
type HyperResponse = hyper::Response<ResBody>;
67
68#[inline]
70pub(crate) fn encode_url_path(path: &str) -> String {
71 path.split('/')
72 .map(|s| utf8_percent_encode(s, CONTROLS).to_string())
73 .collect::<Vec<_>>()
74 .join("/")
75}
76
/// HTTP client used by [`Proxy`] to send the proxied request to the elected upstream.
pub trait Client: Send + Sync + 'static {
    /// Error produced when executing a proxied request fails.
    type Error: StdError + Send + Sync + 'static;

    /// Sends `req` to the upstream and resolves to its response.
    ///
    /// `upgraded` is the pending connection upgrade taken out of the incoming request's
    /// extensions by the proxy handler (`None` if the request had none).
    /// NOTE(review): presumably implementations use it to tunnel upgraded connections
    /// such as WebSocket — confirm against the client modules.
    fn execute(
        &self,
        req: HyperRequest,
        upgraded: Option<OnUpgrade>,
    ) -> impl Future<Output = Result<HyperResponse, Self::Error>> + Send;
}
92
/// Source of upstream base URLs; one is elected for every proxied request.
pub trait Upstreams: Send + Sync + 'static {
    /// Error produced when no upstream can be elected.
    type Error: StdError + Send + Sync + 'static;

    /// Elects the upstream base URL (e.g. `"http://host:port/prefix"`) for the next request.
    fn elect(&self) -> impl Future<Output = Result<&str, Self::Error>> + Send;
}
105impl Upstreams for &'static str {
106 type Error = Infallible;
107
108 async fn elect(&self) -> Result<&str, Self::Error> {
109 Ok(*self)
110 }
111}
112impl Upstreams for String {
113 type Error = Infallible;
114 async fn elect(&self) -> Result<&str, Self::Error> {
115 Ok(self.as_str())
116 }
117}
118
119impl<const N: usize> Upstreams for [&'static str; N] {
120 type Error = Error;
121 async fn elect(&self) -> Result<&str, Self::Error> {
122 if self.is_empty() {
123 return Err(Error::other("upstreams is empty"));
124 }
125 let index = fastrand::usize(..self.len());
126 Ok(self[index])
127 }
128}
129
130impl<T> Upstreams for Vec<T>
131where
132 T: AsRef<str> + Send + Sync + 'static,
133{
134 type Error = Error;
135 async fn elect(&self) -> Result<&str, Self::Error> {
136 if self.is_empty() {
137 return Err(Error::other("upstreams is empty"));
138 }
139 let index = fastrand::usize(..self.len());
140 Ok(self[index].as_ref())
141 }
142}
143
/// Callback deriving one part of the forwarded URL (its path or its query) from the
/// incoming request and depot. Returning `None` leaves that part out: a missing path is
/// treated as empty, and a missing query adds no `?`.
pub type UrlPartGetter = Box<dyn Fn(&Request, &Depot) -> Option<String> + Send + Sync + 'static>;
146
147pub fn default_url_path_getter(req: &Request, _depot: &Depot) -> Option<String> {
152 req.params().tail().map(encode_url_path)
153}
154pub fn default_url_query_getter(req: &Request, _depot: &Depot) -> Option<String> {
156 req.uri().query().map(Into::into)
157}
158
/// Handler that forwards incoming requests to an upstream elected from `upstreams`,
/// sending them through `client` and copying the upstream response back.
#[non_exhaustive]
pub struct Proxy<U, C>
where
    U: Upstreams,
    C: Client,
{
    /// Upstream base URLs; one is elected per forwarded request.
    pub upstreams: U,
    /// Client that sends the proxied request to the elected upstream.
    pub client: C,
    /// Produces the path appended to the elected upstream URL.
    pub url_path_getter: UrlPartGetter,
    /// Produces the query string appended to the forwarded URL.
    pub url_query_getter: UrlPartGetter,
}
175
176impl<U, C> Proxy<U, C>
177where
178 U: Upstreams,
179 U::Error: Into<BoxedError>,
180 C: Client,
181{
182 pub fn new(upstreams: U, client: C) -> Self {
184 Proxy {
185 upstreams,
186 client,
187 url_path_getter: Box::new(default_url_path_getter),
188 url_query_getter: Box::new(default_url_query_getter),
189 }
190 }
191
192 #[inline]
194 pub fn url_path_getter<G>(mut self, url_path_getter: G) -> Self
195 where
196 G: Fn(&Request, &Depot) -> Option<String> + Send + Sync + 'static,
197 {
198 self.url_path_getter = Box::new(url_path_getter);
199 self
200 }
201
202 #[inline]
204 pub fn url_query_getter<G>(mut self, url_query_getter: G) -> Self
205 where
206 G: Fn(&Request, &Depot) -> Option<String> + Send + Sync + 'static,
207 {
208 self.url_query_getter = Box::new(url_query_getter);
209 self
210 }
211
212 #[inline]
214 pub fn upstreams(&self) -> &U {
215 &self.upstreams
216 }
217 #[inline]
219 pub fn upstreams_mut(&mut self) -> &mut U {
220 &mut self.upstreams
221 }
222
223 #[inline]
225 pub fn client(&self) -> &C {
226 &self.client
227 }
228 #[inline]
230 pub fn client_mut(&mut self) -> &mut C {
231 &mut self.client
232 }
233
234 async fn build_proxied_request(
235 &self,
236 req: &mut Request,
237 depot: &Depot,
238 ) -> Result<HyperRequest, Error> {
239 let upstream = self.upstreams.elect().await.map_err(Error::other)?;
240 if upstream.is_empty() {
241 tracing::error!("upstreams is empty");
242 return Err(Error::other("upstreams is empty"));
243 }
244
245 let path = encode_url_path(&(self.url_path_getter)(req, depot).unwrap_or_default());
246 let query = (self.url_query_getter)(req, depot);
247 let rest = if let Some(query) = query {
248 if query.starts_with('?') {
249 format!("{}{}", path, query)
250 } else {
251 format!("{}?{}", path, query)
252 }
253 } else {
254 path
255 };
256 let forward_url = if upstream.ends_with('/') && rest.starts_with('/') {
257 format!("{}{}", upstream.trim_end_matches('/'), rest)
258 } else if upstream.ends_with('/') || rest.starts_with('/') {
259 format!("{}{}", upstream, rest)
260 } else if rest.is_empty() {
261 upstream.to_string()
262 } else {
263 format!("{}/{}", upstream, rest)
264 };
265 let forward_url: Uri = TryFrom::try_from(forward_url).map_err(Error::other)?;
266 let mut build = hyper::Request::builder()
267 .method(req.method())
268 .uri(&forward_url);
269 for (key, value) in req.headers() {
270 if key != HOST {
271 build = build.header(key, value);
272 }
273 }
274 if let Some(host) = forward_url
275 .host()
276 .and_then(|host| HeaderValue::from_str(host).ok())
277 {
278 build = build.header(HeaderName::from_static("host"), host);
279 }
280 build.body(req.take_body()).map_err(Error::other)
299 }
300}
301
302#[async_trait]
303impl<U, C> Handler for Proxy<U, C>
304where
305 U: Upstreams,
306 U::Error: Into<BoxedError>,
307 C: Client,
308{
309 async fn handle(
310 &self,
311 req: &mut Request,
312 depot: &mut Depot,
313 res: &mut Response,
314 _ctrl: &mut FlowCtrl,
315 ) {
316 match self.build_proxied_request(req, depot).await {
317 Ok(proxied_request) => {
318 match self
319 .client
320 .execute(proxied_request, req.extensions_mut().remove())
321 .await
322 {
323 Ok(response) => {
324 let (
325 salvo_core::http::response::Parts {
326 status,
327 headers,
329 ..
331 },
332 body,
333 ) = response.into_parts();
334 res.status_code(status);
335 for name in headers.keys() {
336 for value in headers.get_all(name) {
337 res.headers.append(name, value.to_owned());
338 }
339 }
340 res.body(body);
341 }
342 Err(e) => {
343 tracing::error!( error = ?e, uri = ?req.uri(), "get response data failed: {}", e);
344 res.status_code(StatusCode::INTERNAL_SERVER_ERROR);
345 }
346 }
347 }
348 Err(e) => {
349 tracing::error!(error = ?e, "build proxied request failed");
350 }
351 }
352 }
353}
354#[inline]
355#[allow(dead_code)]
356fn get_upgrade_type(headers: &HeaderMap) -> Option<&str> {
357 if headers
358 .get(&CONNECTION)
359 .map(|value| {
360 value
361 .to_str()
362 .unwrap_or_default()
363 .split(',')
364 .any(|e| e.trim() == UPGRADE)
365 })
366 .unwrap_or(false)
367 {
368 if let Some(upgrade_value) = headers.get(&UPGRADE) {
369 tracing::debug!(
370 "Found upgrade header with value: {:?}",
371 upgrade_value.to_str()
372 );
373 return upgrade_value.to_str().ok();
374 }
375 }
376
377 None
378}
379
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_encode_url_path() {
        let path = "/test/path";
        let encoded_path = encode_url_path(path);
        assert_eq!(encoded_path, "/test/path");
    }

    #[test]
    fn test_encode_url_path_escapes_controls_and_non_ascii() {
        assert_eq!(encode_url_path(""), "");
        assert_eq!(encode_url_path("/a\nb"), "/a%0Ab");
        assert_eq!(encode_url_path("/caf\u{e9}/"), "/caf%C3%A9/");
    }

    #[test]
    fn test_encode_url_path_keeps_existing_escapes() {
        // The path is encoded twice on the default route, so `%` must survive untouched.
        assert_eq!(encode_url_path("/a%0Ab"), "/a%0Ab");
    }

    #[test]
    fn test_get_upgrade_type() {
        let mut headers = HeaderMap::new();
        headers.insert(CONNECTION, HeaderValue::from_static("upgrade"));
        headers.insert(UPGRADE, HeaderValue::from_static("websocket"));
        let upgrade_type = get_upgrade_type(&headers);
        assert_eq!(upgrade_type, Some("websocket"));
    }

    #[test]
    fn test_get_upgrade_type_token_list_is_case_insensitive() {
        let mut headers = HeaderMap::new();
        headers.insert(CONNECTION, HeaderValue::from_static("keep-alive, Upgrade"));
        headers.insert(UPGRADE, HeaderValue::from_static("websocket"));
        assert_eq!(get_upgrade_type(&headers), Some("websocket"));
    }

    #[test]
    fn test_get_upgrade_type_requires_connection_upgrade() {
        let mut headers = HeaderMap::new();
        headers.insert(UPGRADE, HeaderValue::from_static("websocket"));
        assert_eq!(get_upgrade_type(&headers), None);
        headers.insert(CONNECTION, HeaderValue::from_static("keep-alive"));
        assert_eq!(get_upgrade_type(&headers), None);
    }

    #[test]
    fn test_get_upgrade_type_without_upgrade_header() {
        let mut headers = HeaderMap::new();
        headers.insert(CONNECTION, HeaderValue::from_static("upgrade"));
        assert_eq!(get_upgrade_type(&headers), None);
    }
}