1#![doc(html_favicon_url = "https://salvo.rs/favicon-32x32.png")]
38#![doc(html_logo_url = "https://salvo.rs/images/logo.svg")]
39#![cfg_attr(docsrs, feature(doc_cfg))]
40
41use std::convert::Infallible;
42use std::error::Error as StdError;
43
44use hyper::upgrade::OnUpgrade;
45use percent_encoding::{CONTROLS, utf8_percent_encode};
46use salvo_core::http::header::{CONNECTION, HOST, HeaderMap, HeaderName, HeaderValue, UPGRADE};
47use salvo_core::http::uri::Uri;
48use salvo_core::http::{ReqBody, ResBody, StatusCode};
49use salvo_core::{BoxedError, Depot, Error, FlowCtrl, Handler, Request, Response, async_trait};
50
51#[macro_use]
52mod cfg;
53
54cfg_feature! {
55 #![feature = "hyper-client"]
56 mod hyper_client;
57 pub use hyper_client::*;
58}
59cfg_feature! {
60 #![feature = "reqwest-client"]
61 mod reqwest_client;
62 pub use reqwest_client::*;
63}
64
65cfg_feature! {
66 #![feature = "unix-sock-client"]
67 #[cfg(unix)]
68 mod unix_sock_client;
69 #[cfg(unix)]
70 pub use unix_sock_client::*;
71}
72
/// Request type that [`Proxy`] builds and hands to [`Client::execute`].
type HyperRequest = hyper::Request<ReqBody>;
/// Response type that [`Client::execute`] resolves to for a proxied request.
type HyperResponse = hyper::Response<ResBody>;
75
76#[inline]
78pub(crate) fn encode_url_path(path: &str) -> String {
79 path.split('/')
80 .map(|s| utf8_percent_encode(s, CONTROLS).to_string())
81 .collect::<Vec<_>>()
82 .join("/")
83}
84
/// HTTP client that [`Proxy`] uses to send the rewritten request to the upstream.
pub trait Client: Send + Sync + 'static {
    /// Error produced when executing a request fails.
    type Error: StdError + Send + Sync + 'static;

    /// Sends `req` upstream and resolves to the upstream's response.
    ///
    /// When called by [`Proxy`], `upgraded` is the `OnUpgrade` value removed from the
    /// incoming request's extensions, or `None` if none was stored there.
    fn execute(
        &self,
        req: HyperRequest,
        upgraded: Option<OnUpgrade>,
    ) -> impl Future<Output = Result<HyperResponse, Self::Error>> + Send;
}
100
/// Source of upstream base URLs that [`Proxy`] forwards requests to.
pub trait Upstreams: Send + Sync + 'static {
    /// Error returned when no upstream can be elected.
    type Error: StdError + Send + Sync + 'static;

    /// Chooses the upstream base URL; [`Proxy`] calls this once per proxied request.
    fn elect(&self) -> impl Future<Output = Result<&str, Self::Error>> + Send;
}
113impl Upstreams for &'static str {
114 type Error = Infallible;
115
116 async fn elect(&self) -> Result<&str, Self::Error> {
117 Ok(*self)
118 }
119}
120impl Upstreams for String {
121 type Error = Infallible;
122 async fn elect(&self) -> Result<&str, Self::Error> {
123 Ok(self.as_str())
124 }
125}
126
127impl<const N: usize> Upstreams for [&'static str; N] {
128 type Error = Error;
129 async fn elect(&self) -> Result<&str, Self::Error> {
130 if self.is_empty() {
131 return Err(Error::other("upstreams is empty"));
132 }
133 let index = fastrand::usize(..self.len());
134 Ok(self[index])
135 }
136}
137
138impl<T> Upstreams for Vec<T>
139where
140 T: AsRef<str> + Send + Sync + 'static,
141{
142 type Error = Error;
143 async fn elect(&self) -> Result<&str, Self::Error> {
144 if self.is_empty() {
145 return Err(Error::other("upstreams is empty"));
146 }
147 let index = fastrand::usize(..self.len());
148 Ok(self[index].as_ref())
149 }
150}
151
/// Callback deriving one part (path or query) of the forwarded URL from the incoming
/// request and depot. [`Proxy`] treats `None` as an empty path or as "no query".
pub type UrlPartGetter = Box<dyn Fn(&Request, &Depot) -> Option<String> + Send + Sync + 'static>;
154
155pub fn default_url_path_getter(req: &Request, _depot: &Depot) -> Option<String> {
160 req.params().tail().map(encode_url_path)
161}
162pub fn default_url_query_getter(req: &Request, _depot: &Depot) -> Option<String> {
164 req.uri().query().map(Into::into)
165}
166
/// Salvo [`Handler`] that forwards incoming requests to an upstream elected from `U`,
/// sending them through the client `C`.
#[non_exhaustive]
pub struct Proxy<U, C>
where
    U: Upstreams,
    C: Client,
{
    /// Upstream base URLs; one is elected per proxied request.
    pub upstreams: U,
    /// Client that executes the forwarded requests.
    pub client: C,
    /// Produces the path appended to the elected upstream. The result is passed through
    /// `encode_url_path`, and `None` is treated as an empty path.
    pub url_path_getter: UrlPartGetter,
    /// Produces the query string of the forwarded URL. A leading `?` is optional, and
    /// `None` means no query.
    pub url_query_getter: UrlPartGetter,
}
183
184impl<U, C> Proxy<U, C>
185where
186 U: Upstreams,
187 U::Error: Into<BoxedError>,
188 C: Client,
189{
190 pub fn new(upstreams: U, client: C) -> Self {
192 Proxy {
193 upstreams,
194 client,
195 url_path_getter: Box::new(default_url_path_getter),
196 url_query_getter: Box::new(default_url_query_getter),
197 }
198 }
199
200 #[inline]
202 pub fn url_path_getter<G>(mut self, url_path_getter: G) -> Self
203 where
204 G: Fn(&Request, &Depot) -> Option<String> + Send + Sync + 'static,
205 {
206 self.url_path_getter = Box::new(url_path_getter);
207 self
208 }
209
210 #[inline]
212 pub fn url_query_getter<G>(mut self, url_query_getter: G) -> Self
213 where
214 G: Fn(&Request, &Depot) -> Option<String> + Send + Sync + 'static,
215 {
216 self.url_query_getter = Box::new(url_query_getter);
217 self
218 }
219
220 #[inline]
222 pub fn upstreams(&self) -> &U {
223 &self.upstreams
224 }
225 #[inline]
227 pub fn upstreams_mut(&mut self) -> &mut U {
228 &mut self.upstreams
229 }
230
231 #[inline]
233 pub fn client(&self) -> &C {
234 &self.client
235 }
236 #[inline]
238 pub fn client_mut(&mut self) -> &mut C {
239 &mut self.client
240 }
241
242 async fn build_proxied_request(
243 &self,
244 req: &mut Request,
245 depot: &Depot,
246 ) -> Result<HyperRequest, Error> {
247 let upstream = self.upstreams.elect().await.map_err(Error::other)?;
248 if upstream.is_empty() {
249 tracing::error!("upstreams is empty");
250 return Err(Error::other("upstreams is empty"));
251 }
252
253 let path = encode_url_path(&(self.url_path_getter)(req, depot).unwrap_or_default());
254 let query = (self.url_query_getter)(req, depot);
255 let rest = if let Some(query) = query {
256 if query.starts_with('?') {
257 format!("{path}{query}")
258 } else {
259 format!("{path}?{query}")
260 }
261 } else {
262 path
263 };
264 let forward_url = if upstream.ends_with('/') && rest.starts_with('/') {
265 format!("{}{}", upstream.trim_end_matches('/'), rest)
266 } else if upstream.ends_with('/') || rest.starts_with('/') {
267 format!("{upstream}{rest}")
268 } else if rest.is_empty() {
269 upstream.to_string()
270 } else {
271 format!("{upstream}/{rest}")
272 };
273 let forward_url: Uri = TryFrom::try_from(forward_url).map_err(Error::other)?;
274 let mut build = hyper::Request::builder()
275 .method(req.method())
276 .uri(&forward_url);
277 for (key, value) in req.headers() {
278 if key != HOST {
279 build = build.header(key, value);
280 }
281 }
282 if let Some(host) = forward_url
283 .host()
284 .and_then(|host| HeaderValue::from_str(host).ok())
285 {
286 build = build.header(HeaderName::from_static("host"), host);
287 }
288 build.body(req.take_body()).map_err(Error::other)
307 }
308}
309
310#[async_trait]
311impl<U, C> Handler for Proxy<U, C>
312where
313 U: Upstreams,
314 U::Error: Into<BoxedError>,
315 C: Client,
316{
317 async fn handle(
318 &self,
319 req: &mut Request,
320 depot: &mut Depot,
321 res: &mut Response,
322 _ctrl: &mut FlowCtrl,
323 ) {
324 match self.build_proxied_request(req, depot).await {
325 Ok(proxied_request) => {
326 match self
327 .client
328 .execute(proxied_request, req.extensions_mut().remove())
329 .await
330 {
331 Ok(response) => {
332 let (
333 salvo_core::http::response::Parts {
334 status,
335 headers,
337 ..
339 },
340 body,
341 ) = response.into_parts();
342 res.status_code(status);
343 for name in headers.keys() {
344 for value in headers.get_all(name) {
345 res.headers.append(name, value.to_owned());
346 }
347 }
348 res.body(body);
349 }
350 Err(e) => {
351 tracing::error!( error = ?e, uri = ?req.uri(), "get response data failed: {}", e);
352 res.status_code(StatusCode::INTERNAL_SERVER_ERROR);
353 }
354 }
355 }
356 Err(e) => {
357 tracing::error!(error = ?e, "build proxied request failed");
358 }
359 }
360 }
361}
362#[inline]
363#[allow(dead_code)]
364fn get_upgrade_type(headers: &HeaderMap) -> Option<&str> {
365 if headers
366 .get(&CONNECTION)
367 .map(|value| {
368 value
369 .to_str()
370 .unwrap_or_default()
371 .split(',')
372 .any(|e| e.trim() == UPGRADE)
373 })
374 .unwrap_or(false)
375 {
376 if let Some(upgrade_value) = headers.get(&UPGRADE) {
377 tracing::debug!(
378 "Found upgrade header with value: {:?}",
379 upgrade_value.to_str()
380 );
381 return upgrade_value.to_str().ok();
382 }
383 }
384
385 None
386}
387
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_encode_url_path() {
        let path = "/test/path";
        let encoded_path = encode_url_path(path);
        assert_eq!(encoded_path, "/test/path");
    }

    #[test]
    fn test_encode_url_path_edge_cases() {
        // Empty input stays empty; separators (including trailing ones) are kept.
        assert_eq!(encode_url_path(""), "");
        assert_eq!(encode_url_path("/a/b/"), "/a/b/");
        // Control characters and non-ASCII bytes are percent-encoded per segment.
        assert_eq!(encode_url_path("/a\u{1}b/c"), "/a%01b/c");
        assert_eq!(encode_url_path("/caf\u{e9}"), "/caf%C3%A9");
    }

    #[test]
    fn test_get_upgrade_type() {
        let mut headers = HeaderMap::new();
        headers.insert(CONNECTION, HeaderValue::from_static("upgrade"));
        headers.insert(UPGRADE, HeaderValue::from_static("websocket"));
        let upgrade_type = get_upgrade_type(&headers);
        assert_eq!(upgrade_type, Some("websocket"));
    }

    #[test]
    fn test_get_upgrade_type_with_connection_token_list() {
        let mut headers = HeaderMap::new();
        headers.insert(CONNECTION, HeaderValue::from_static("keep-alive, Upgrade"));
        headers.insert(UPGRADE, HeaderValue::from_static("websocket"));
        assert_eq!(get_upgrade_type(&headers), Some("websocket"));
    }

    #[test]
    fn test_get_upgrade_type_requires_connection_upgrade() {
        let mut headers = HeaderMap::new();
        headers.insert(UPGRADE, HeaderValue::from_static("websocket"));
        assert_eq!(get_upgrade_type(&headers), None);

        headers.insert(CONNECTION, HeaderValue::from_static("keep-alive"));
        assert_eq!(get_upgrade_type(&headers), None);
    }

    #[test]
    fn test_get_upgrade_type_without_upgrade_header() {
        let mut headers = HeaderMap::new();
        headers.insert(CONNECTION, HeaderValue::from_static("upgrade"));
        assert_eq!(get_upgrade_type(&headers), None);
    }
}
408}