1#![doc(html_favicon_url = "https://salvo.rs/favicon-32x32.png")]
38#![doc(html_logo_url = "https://salvo.rs/images/logo.svg")]
39#![cfg_attr(docsrs, feature(doc_cfg))]
40
41use std::convert::Infallible;
42use std::error::Error as StdError;
43use std::fmt::{self, Debug, Formatter};
44
45use hyper::upgrade::OnUpgrade;
46use percent_encoding::{CONTROLS, utf8_percent_encode};
47use salvo_core::http::header::{CONNECTION, HOST, HeaderMap, HeaderName, HeaderValue, UPGRADE};
48use salvo_core::http::uri::Uri;
49use salvo_core::http::{ReqBody, ResBody, StatusCode};
50use salvo_core::{BoxedError, Depot, Error, FlowCtrl, Handler, Request, Response, async_trait};
51
52#[macro_use]
53mod cfg;
54
55cfg_feature! {
56 #![feature = "hyper-client"]
57 mod hyper_client;
58 pub use hyper_client::*;
59}
60cfg_feature! {
61 #![feature = "reqwest-client"]
62 mod reqwest_client;
63 pub use reqwest_client::*;
64}
65
66cfg_feature! {
67 #![feature = "unix-sock-client"]
68 #[cfg(unix)]
69 mod unix_sock_client;
70 #[cfg(unix)]
71 pub use unix_sock_client::*;
72}
73
/// Request type passed to [`Client::execute`]: a hyper request whose body is salvo's [`ReqBody`].
type HyperRequest = hyper::Request<ReqBody>;
/// Response type produced by a [`Client`]: a hyper response whose body is salvo's [`ResBody`].
type HyperResponse = hyper::Response<ResBody>;
76
77#[inline]
79pub(crate) fn encode_url_path(path: &str) -> String {
80 path.split('/')
81 .map(|s| utf8_percent_encode(s, CONTROLS).to_string())
82 .collect::<Vec<_>>()
83 .join("/")
84}
85
86pub trait Client: Send + Sync + 'static {
91 type Error: StdError + Send + Sync + 'static;
93
94 fn execute(
96 &self,
97 req: HyperRequest,
98 upgraded: Option<OnUpgrade>,
99 ) -> impl Future<Output = Result<HyperResponse, Self::Error>> + Send;
100}
101
/// Source of upstream base URLs that a [`Proxy`] forwards requests to.
pub trait Upstreams: Send + Sync + 'static {
    /// Error returned when no upstream can be elected.
    type Error: StdError + Send + Sync + 'static;

    /// Picks the upstream base URL to use for the current request.
    fn elect(&self) -> impl std::future::Future<Output = Result<&str, Self::Error>> + Send;
}
114impl Upstreams for &'static str {
115 type Error = Infallible;
116
117 async fn elect(&self) -> Result<&str, Self::Error> {
118 Ok(*self)
119 }
120}
121impl Upstreams for String {
122 type Error = Infallible;
123 async fn elect(&self) -> Result<&str, Self::Error> {
124 Ok(self.as_str())
125 }
126}
127
128impl<const N: usize> Upstreams for [&'static str; N] {
129 type Error = Error;
130 async fn elect(&self) -> Result<&str, Self::Error> {
131 if self.is_empty() {
132 return Err(Error::other("upstreams is empty"));
133 }
134 let index = fastrand::usize(..self.len());
135 Ok(self[index])
136 }
137}
138
139impl<T> Upstreams for Vec<T>
140where
141 T: AsRef<str> + Send + Sync + 'static,
142{
143 type Error = Error;
144 async fn elect(&self) -> Result<&str, Self::Error> {
145 if self.is_empty() {
146 return Err(Error::other("upstreams is empty"));
147 }
148 let index = fastrand::usize(..self.len());
149 Ok(self[index].as_ref())
150 }
151}
152
153pub type UrlPartGetter = Box<dyn Fn(&Request, &Depot) -> Option<String> + Send + Sync + 'static>;
155
156pub fn default_url_path_getter(req: &Request, _depot: &Depot) -> Option<String> {
161 req.params().tail().map(encode_url_path)
162}
163pub fn default_url_query_getter(req: &Request, _depot: &Depot) -> Option<String> {
165 req.uri().query().map(Into::into)
166}
167
168#[non_exhaustive]
170pub struct Proxy<U, C>
171where
172 U: Upstreams,
173 C: Client,
174{
175 pub upstreams: U,
177 pub client: C,
179 pub url_path_getter: UrlPartGetter,
181 pub url_query_getter: UrlPartGetter,
183}
184
185impl<U, C> Debug for Proxy<U, C>
186where
187 U: Upstreams,
188 C: Client,
189{
190 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
191 f.debug_struct("Proxy").finish()
192 }
193}
194
195impl<U, C> Proxy<U, C>
196where
197 U: Upstreams,
198 U::Error: Into<BoxedError>,
199 C: Client,
200{
201 #[must_use]
203 pub fn new(upstreams: U, client: C) -> Self {
204 Self {
205 upstreams,
206 client,
207 url_path_getter: Box::new(default_url_path_getter),
208 url_query_getter: Box::new(default_url_query_getter),
209 }
210 }
211
212 #[inline]
214 #[must_use]
215 pub fn url_path_getter<G>(mut self, url_path_getter: G) -> Self
216 where
217 G: Fn(&Request, &Depot) -> Option<String> + Send + Sync + 'static,
218 {
219 self.url_path_getter = Box::new(url_path_getter);
220 self
221 }
222
223 #[inline]
225 #[must_use]
226 pub fn url_query_getter<G>(mut self, url_query_getter: G) -> Self
227 where
228 G: Fn(&Request, &Depot) -> Option<String> + Send + Sync + 'static,
229 {
230 self.url_query_getter = Box::new(url_query_getter);
231 self
232 }
233
234 #[inline]
236 pub fn upstreams(&self) -> &U {
237 &self.upstreams
238 }
239 #[inline]
241 pub fn upstreams_mut(&mut self) -> &mut U {
242 &mut self.upstreams
243 }
244
245 #[inline]
247 pub fn client(&self) -> &C {
248 &self.client
249 }
250 #[inline]
252 pub fn client_mut(&mut self) -> &mut C {
253 &mut self.client
254 }
255
256 async fn build_proxied_request(
257 &self,
258 req: &mut Request,
259 depot: &Depot,
260 ) -> Result<HyperRequest, Error> {
261 let upstream = self.upstreams.elect().await.map_err(Error::other)?;
262 if upstream.is_empty() {
263 tracing::error!("upstreams is empty");
264 return Err(Error::other("upstreams is empty"));
265 }
266
267 let path = encode_url_path(&(self.url_path_getter)(req, depot).unwrap_or_default());
268 let query = (self.url_query_getter)(req, depot);
269 let rest = if let Some(query) = query {
270 if query.starts_with('?') {
271 format!("{path}{query}")
272 } else {
273 format!("{path}?{query}")
274 }
275 } else {
276 path
277 };
278 let forward_url = if upstream.ends_with('/') && rest.starts_with('/') {
279 format!("{}{}", upstream.trim_end_matches('/'), rest)
280 } else if upstream.ends_with('/') || rest.starts_with('/') {
281 format!("{upstream}{rest}")
282 } else if rest.is_empty() {
283 upstream.to_owned()
284 } else {
285 format!("{upstream}/{rest}")
286 };
287 let forward_url: Uri = TryFrom::try_from(forward_url).map_err(Error::other)?;
288 let mut build = hyper::Request::builder()
289 .method(req.method())
290 .uri(&forward_url);
291 for (key, value) in req.headers() {
292 if key != HOST {
293 build = build.header(key, value);
294 }
295 }
296 if let Some(host) = forward_url
297 .host()
298 .and_then(|host| HeaderValue::from_str(host).ok())
299 {
300 build = build.header(HeaderName::from_static("host"), host);
301 }
302 build.body(req.take_body()).map_err(Error::other)
321 }
322}
323
324#[async_trait]
325impl<U, C> Handler for Proxy<U, C>
326where
327 U: Upstreams,
328 U::Error: Into<BoxedError>,
329 C: Client,
330{
331 async fn handle(
332 &self,
333 req: &mut Request,
334 depot: &mut Depot,
335 res: &mut Response,
336 _ctrl: &mut FlowCtrl,
337 ) {
338 match self.build_proxied_request(req, depot).await {
339 Ok(proxied_request) => {
340 match self
341 .client
342 .execute(proxied_request, req.extensions_mut().remove())
343 .await
344 {
345 Ok(response) => {
346 let (
347 salvo_core::http::response::Parts {
348 status,
349 headers,
351 ..
353 },
354 body,
355 ) = response.into_parts();
356 res.status_code(status);
357 for name in headers.keys() {
358 for value in headers.get_all(name) {
359 res.headers.append(name, value.to_owned());
360 }
361 }
362 res.body(body);
363 }
364 Err(e) => {
365 tracing::error!( error = ?e, uri = ?req.uri(), "get response data failed: {}", e);
366 res.status_code(StatusCode::INTERNAL_SERVER_ERROR);
367 }
368 }
369 }
370 Err(e) => {
371 tracing::error!(error = ?e, "build proxied request failed");
372 }
373 }
374 }
375}
376#[inline]
377#[allow(dead_code)]
378fn get_upgrade_type(headers: &HeaderMap) -> Option<&str> {
379 if headers
380 .get(&CONNECTION)
381 .map(|value| {
382 value
383 .to_str()
384 .unwrap_or_default()
385 .split(',')
386 .any(|e| e.trim() == UPGRADE)
387 })
388 .unwrap_or(false)
389 {
390 if let Some(upgrade_value) = headers.get(&UPGRADE) {
391 tracing::debug!(
392 "Found upgrade header with value: {:?}",
393 upgrade_value.to_str()
394 );
395 return upgrade_value.to_str().ok();
396 }
397 }
398
399 None
400}
401
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_encode_url_path() {
        let path = "/test/path";
        let encoded_path = encode_url_path(path);
        assert_eq!(encoded_path, "/test/path");
    }

    #[test]
    fn test_encode_url_path_escapes_controls_and_non_ascii() {
        assert_eq!(encode_url_path(""), "");
        assert_eq!(encode_url_path("/a\tb/c"), "/a%09b/c");
        assert_eq!(encode_url_path("/é"), "/%C3%A9");
    }

    #[test]
    fn test_encode_url_path_is_idempotent() {
        // The proxy encodes the getter's output, which the default getter has
        // already encoded, so a second pass must not change it.
        let once = encode_url_path("/a\tb/é");
        assert_eq!(encode_url_path(&once), once);
    }

    #[test]
    fn test_get_upgrade_type() {
        let mut headers = HeaderMap::new();
        headers.insert(CONNECTION, HeaderValue::from_static("upgrade"));
        headers.insert(UPGRADE, HeaderValue::from_static("websocket"));
        let upgrade_type = get_upgrade_type(&headers);
        assert_eq!(upgrade_type, Some("websocket"));
    }

    #[test]
    fn test_get_upgrade_type_token_list() {
        let mut headers = HeaderMap::new();
        headers.insert(CONNECTION, HeaderValue::from_static("keep-alive, Upgrade"));
        headers.insert(UPGRADE, HeaderValue::from_static("websocket"));
        assert_eq!(get_upgrade_type(&headers), Some("websocket"));
    }

    #[test]
    fn test_get_upgrade_type_none() {
        let mut headers = HeaderMap::new();
        assert_eq!(get_upgrade_type(&headers), None);

        // `Upgrade` alone is ignored without a `Connection: upgrade` token.
        headers.insert(UPGRADE, HeaderValue::from_static("websocket"));
        assert_eq!(get_upgrade_type(&headers), None);
        headers.insert(CONNECTION, HeaderValue::from_static("keep-alive"));
        assert_eq!(get_upgrade_type(&headers), None);

        // The token is present but there is no `Upgrade` header to read.
        headers.insert(CONNECTION, HeaderValue::from_static("upgrade"));
        headers.remove(&UPGRADE);
        assert_eq!(get_upgrade_type(&headers), None);
    }
}