#![doc(html_favicon_url = "https://salvo.rs/favicon-32x32.png")]
#![doc(html_logo_url = "https://salvo.rs/images/logo.svg")]
#![cfg_attr(docsrs, feature(doc_cfg))]
use std::convert::Infallible;
use std::error::Error as StdError;
use std::future::Future;
use hyper::upgrade::OnUpgrade;
use percent_encoding::{utf8_percent_encode, CONTROLS};
use salvo_core::http::header::{HeaderMap, HeaderName, HeaderValue, CONNECTION, HOST, UPGRADE};
use salvo_core::http::uri::Uri;
use salvo_core::http::{ReqBody, ResBody, StatusCode};
use salvo_core::{async_trait, BoxedError, Depot, Error, FlowCtrl, Handler, Request, Response};
#[macro_use]
mod cfg;
cfg_feature! {
#![feature = "hyper-client"]
mod hyper_client;
pub use hyper_client::*;
}
cfg_feature! {
#![feature = "reqwest-client"]
mod reqwest_client;
pub use reqwest_client::*;
}
/// Request type handed to a [`Client`]: a hyper request whose body is salvo's [`ReqBody`].
type HyperRequest = hyper::Request<ReqBody>;
/// Response type produced by a [`Client`]: a hyper response whose body is salvo's [`ResBody`].
type HyperResponse = hyper::Response<ResBody>;
#[inline]
pub(crate) fn encode_url_path(path: &str) -> String {
path.split('/')
.map(|s| utf8_percent_encode(s, CONTROLS).to_string())
.collect::<Vec<_>>()
.join("/")
}
/// Transport used by [`Proxy`] to send the forwarded request to the upstream server.
pub trait Client: Send + Sync + 'static {
    /// Error produced when sending the request to the upstream fails.
    type Error: StdError + Send + Sync + 'static;
    /// Sends `req` to the upstream and resolves to its response.
    ///
    /// `upgraded` is the `OnUpgrade` handle removed from the incoming request's
    /// extensions, if one was present. Presumably implementations use it to bridge
    /// protocol upgrades such as WebSocket; confirm against the client implementations.
    fn execute(
        &self,
        req: HyperRequest,
        upgraded: Option<OnUpgrade>,
    ) -> impl Future<Output = Result<HyperResponse, Self::Error>> + Send;
}
/// Source of upstream base URLs for a [`Proxy`].
pub trait Upstreams: Send + Sync + 'static {
    /// Error produced when no upstream can be elected.
    type Error: StdError + Send + Sync + 'static;
    /// Chooses the upstream base URL for the current request.
    ///
    /// [`Proxy`] appends the forwarded path and query to the returned string. An empty
    /// string is rejected by [`Proxy`] as an error.
    fn elect(&self) -> impl Future<Output = Result<&str, Self::Error>> + Send;
}
/// A single fixed upstream: election always yields this string and never fails.
impl Upstreams for &'static str {
    type Error = Infallible;

    async fn elect(&self) -> Result<&str, Self::Error> {
        // Deref-coerce `&&'static str` down to the borrowed `&str`.
        let elected: &str = self;
        Ok(elected)
    }
}
/// A single owned upstream: election always yields this string and never fails.
impl Upstreams for String {
    type Error = Infallible;

    async fn elect(&self) -> Result<&str, Self::Error> {
        Ok(&self[..])
    }
}
impl<const N: usize> Upstreams for [&'static str; N] {
type Error = Error;
async fn elect(&self) -> Result<&str, Self::Error> {
if self.is_empty() {
return Err(Error::other("upstreams is empty"));
}
let index = fastrand::usize(..self.len());
Ok(self[index])
}
}
impl<T> Upstreams for Vec<T>
where
T: AsRef<str> + Send + Sync + 'static,
{
type Error = Error;
async fn elect(&self) -> Result<&str, Self::Error> {
if self.is_empty() {
return Err(Error::other("upstreams is empty"));
}
let index = fastrand::usize(..self.len());
Ok(self[index].as_ref())
}
}
/// Callback that computes one part (path or query) of the forwarded URL from the incoming
/// request and depot. `None` means the part is absent: an empty path or no query string.
pub type UrlPartGetter = Box<dyn Fn(&Request, &Depot) -> Option<String> + Send + Sync + 'static>;
pub fn default_url_path_getter(req: &Request, _depot: &Depot) -> Option<String> {
req.params().tail().map(encode_url_path)
}
/// Default query getter: the incoming URI's query string, without the leading `?`.
///
/// Returns `None` when the request URI carries no query.
pub fn default_url_query_getter(req: &Request, _depot: &Depot) -> Option<String> {
    req.uri().query().map(str::to_owned)
}
/// Handler that forwards incoming requests to an upstream server and relays the response.
///
/// Marked `#[non_exhaustive]`, so code outside this crate cannot build it with a struct
/// literal; use [`Proxy::new`] and the builder-style setters instead.
#[non_exhaustive]
pub struct Proxy<U, C>
where
    U: Upstreams,
    C: Client,
{
    /// Source of upstream base URLs; one is elected for every handled request.
    pub upstreams: U,
    /// Transport used to send the forwarded request.
    pub client: C,
    /// Computes the path appended to the elected upstream URL.
    pub url_path_getter: UrlPartGetter,
    /// Computes the query string appended to the forwarded URL.
    pub url_query_getter: UrlPartGetter,
}
impl<U, C> Proxy<U, C>
where
    U: Upstreams,
    U::Error: Into<BoxedError>,
    C: Client,
{
    /// Creates a proxy that forwards requests to `upstreams` using `client`.
    ///
    /// The forwarded path and query default to [`default_url_path_getter`] and
    /// [`default_url_query_getter`].
    pub fn new(upstreams: U, client: C) -> Self {
        Proxy {
            upstreams,
            client,
            url_path_getter: Box::new(default_url_path_getter),
            url_query_getter: Box::new(default_url_query_getter),
        }
    }

    /// Replaces the function that computes the forwarded URL path.
    ///
    /// A `None` result forwards an empty path. The returned path is percent-encoded
    /// segment by segment before being joined to the upstream URL.
    #[inline]
    pub fn url_path_getter<G>(mut self, url_path_getter: G) -> Self
    where
        G: Fn(&Request, &Depot) -> Option<String> + Send + Sync + 'static,
    {
        self.url_path_getter = Box::new(url_path_getter);
        self
    }

    /// Replaces the function that computes the forwarded query string.
    ///
    /// The returned query may include or omit the leading `?`. A `None` result forwards
    /// no query.
    #[inline]
    pub fn url_query_getter<G>(mut self, url_query_getter: G) -> Self
    where
        G: Fn(&Request, &Depot) -> Option<String> + Send + Sync + 'static,
    {
        self.url_query_getter = Box::new(url_query_getter);
        self
    }

    /// Returns a shared reference to the upstreams.
    #[inline]
    pub fn upstreams(&self) -> &U {
        &self.upstreams
    }

    /// Returns a mutable reference to the upstreams.
    #[inline]
    pub fn upstreams_mut(&mut self) -> &mut U {
        &mut self.upstreams
    }

    /// Returns a shared reference to the client.
    #[inline]
    pub fn client(&self) -> &C {
        &self.client
    }

    /// Returns a mutable reference to the client.
    #[inline]
    pub fn client_mut(&mut self) -> &mut C {
        &mut self.client
    }

    /// Builds the request sent to the elected upstream.
    ///
    /// The request is assembled as follows:
    /// - The forwarded URL is the upstream base URL followed by the encoded path and
    ///   query produced by the getters.
    /// - The incoming method and all headers except `Host` are copied.
    /// - `Host` is set from the forwarded URL's host and explicit port.
    /// - The incoming body is moved into the new request.
    ///
    /// # Errors
    /// Fails in any of these cases:
    /// - the election fails or yields an empty string;
    /// - the joined URL is not a valid [`Uri`];
    /// - the request cannot be assembled.
    async fn build_proxied_request(
        &self,
        req: &mut Request,
        depot: &Depot,
    ) -> Result<HyperRequest, Error> {
        let upstream = self.upstreams.elect().await.map_err(Error::other)?;
        if upstream.is_empty() {
            tracing::error!("upstreams is empty");
            return Err(Error::other("upstreams is empty"));
        }

        let path = encode_url_path(&(self.url_path_getter)(req, depot).unwrap_or_default());
        let query = (self.url_query_getter)(req, depot);
        let rest = if let Some(query) = query {
            if query.starts_with('?') {
                format!("{}{}", path, query)
            } else {
                format!("{}?{}", path, query)
            }
        } else {
            path
        };

        // Avoid doubling or omitting the `/` between the upstream and the rest of the URL.
        let forward_url = if upstream.ends_with('/') && rest.starts_with('/') {
            format!("{}{}", upstream.trim_end_matches('/'), rest)
        } else if upstream.ends_with('/') || rest.starts_with('/') {
            format!("{}{}", upstream, rest)
        } else if rest.is_empty() {
            upstream.to_string()
        } else {
            format!("{}/{}", upstream, rest)
        };
        let forward_url: Uri = TryFrom::try_from(forward_url).map_err(Error::other)?;

        let mut build = hyper::Request::builder()
            .method(req.method())
            .uri(&forward_url);
        // The incoming `Host` names this proxy, not the upstream; it is replaced below.
        for (key, value) in req.headers() {
            if key != HOST {
                build = build.header(key, value);
            }
        }
        // `Host` carries the authority's host *and* port (RFC 9110 §7.2). Dropping an
        // explicit port makes an upstream on e.g. `:8080` see the wrong authority.
        if let Some(host) = forward_url.host() {
            let host = match forward_url.port_u16() {
                Some(port) => format!("{host}:{port}"),
                None => host.to_owned(),
            };
            if let Ok(host) = HeaderValue::from_str(&host) {
                build = build.header(HeaderName::from_static("host"), host);
            }
        }
        build.body(req.take_body()).map_err(Error::other)
    }
}
#[async_trait]
impl<U, C> Handler for Proxy<U, C>
where
    U: Upstreams,
    U::Error: Into<BoxedError>,
    C: Client,
{
    /// Forwards the request to an elected upstream and mirrors its response into `res`.
    ///
    /// Any `OnUpgrade` handle stored in the request's extensions is removed and passed to
    /// the client. On success, the upstream status, headers and body are copied into
    /// `res`; every value of a multi-valued header is kept. If the forwarded request
    /// cannot be built, or the client fails, the error is logged and the status is set to
    /// `500 Internal Server Error`.
    async fn handle(
        &self,
        req: &mut Request,
        depot: &mut Depot,
        res: &mut Response,
        _ctrl: &mut FlowCtrl,
    ) {
        let proxied_request = match self.build_proxied_request(req, depot).await {
            Ok(proxied_request) => proxied_request,
            Err(e) => {
                tracing::error!(error = ?e, "build proxied request failed");
                // Report the failure explicitly instead of leaving `res` untouched.
                res.status_code(StatusCode::INTERNAL_SERVER_ERROR);
                return;
            }
        };
        match self
            .client
            .execute(proxied_request, req.extensions_mut().remove())
            .await
        {
            Ok(response) => {
                let (
                    salvo_core::http::response::Parts {
                        status, headers, ..
                    },
                    body,
                ) = response.into_parts();
                res.status_code(status);
                // `HeaderMap`'s owning iterator yields a `None` name for each value after the
                // first of a multi-valued header (e.g. several `Set-Cookie`). Those values
                // belong to the last named header and must be appended, not dropped.
                let mut current_name = None;
                for (name, value) in headers {
                    match name {
                        Some(name) => {
                            // `insert` makes the upstream's value replace any header of the
                            // same name already present on `res`.
                            res.headers.insert(&name, value);
                            current_name = Some(name);
                        }
                        None => {
                            if let Some(name) = &current_name {
                                res.headers.append(name, value);
                            }
                        }
                    }
                }
                res.body(body);
            }
            Err(e) => {
                tracing::error!( error = ?e, uri = ?req.uri(), "get response data failed: {}", e);
                res.status_code(StatusCode::INTERNAL_SERVER_ERROR);
            }
        }
    }
}
/// Returns the requested protocol (the `Upgrade` header value) when the request asks for
/// a connection upgrade.
///
/// The upgrade is considered requested when the first `Connection` header contains an
/// `upgrade` token; the comparison against the header name is case-insensitive.
/// Returns `None` in three cases:
/// - no such token is present;
/// - the `Upgrade` header is missing;
/// - the `Upgrade` value is not visible ASCII.
#[inline]
// Not called by non-test code in this file; presumably kept for client implementations
// or future use. Confirm before removing the allow.
#[allow(dead_code)]
fn get_upgrade_type(headers: &HeaderMap) -> Option<&str> {
    let requests_upgrade = headers.get(&CONNECTION).is_some_and(|connection| {
        connection
            .to_str()
            .unwrap_or_default()
            .split(',')
            .any(|token| token.trim() == UPGRADE)
    });
    if !requests_upgrade {
        return None;
    }

    let upgrade_value = headers.get(&UPGRADE)?;
    tracing::debug!(
        "Found upgrade header with value: {:?}",
        upgrade_value.to_str()
    );
    upgrade_value.to_str().ok()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A path made only of unreserved characters passes through unchanged.
    #[test]
    fn test_encode_url_path() {
        assert_eq!(encode_url_path("/test/path"), "/test/path");
    }

    /// `Connection: upgrade` plus `Upgrade: websocket` yields the requested protocol.
    #[test]
    fn test_get_upgrade_type() {
        let headers = {
            let mut map = HeaderMap::new();
            map.insert(CONNECTION, HeaderValue::from_static("upgrade"));
            map.insert(UPGRADE, HeaderValue::from_static("websocket"));
            map
        };
        assert_eq!(get_upgrade_type(&headers), Some("websocket"));
    }
}