salvo_proxy/
hyper_client.rs

1use hyper::upgrade::OnUpgrade;
2use hyper_rustls::{HttpsConnector, HttpsConnectorBuilder};
3use hyper_util::client::legacy::Client as HyperUtilClient;
4use hyper_util::client::legacy::connect::{Connect, HttpConnector};
5use hyper_util::rt::TokioExecutor;
6use salvo_core::Error;
7use salvo_core::http::{ReqBody, ResBody, StatusCode};
8use salvo_core::rt::tokio::TokioIo;
9use tokio::io::copy_bidirectional;
10
11use crate::{BoxedError, Client, HyperRequest, HyperResponse, Proxy, Upstreams};
12
13/// A [`Client`] implementation based on [`hyper_util::client::legacy::Client`].
14///
15/// This client provides proxy capabilities using the Hyper HTTP client library.
16/// It's lightweight and tightly integrated with the Tokio runtime.
17#[derive(Clone, Debug)]
18pub struct HyperClient<C> {
19    inner: HyperUtilClient<C, ReqBody>,
20}
21
22impl Default for HyperClient<HttpsConnector<HttpConnector>> {
23    fn default() -> Self {
24        let https = HttpsConnectorBuilder::new()
25            .with_native_roots()
26            .expect("no native root CA certificates found")
27            .https_or_http()
28            .enable_all_versions()
29            .build();
30        Self {
31            inner: HyperUtilClient::builder(TokioExecutor::new()).build(https),
32        }
33    }
34}
35
36impl<U> Proxy<U, HyperClient<HttpsConnector<HttpConnector>>>
37where
38    U: Upstreams,
39    U::Error: Into<BoxedError>,
40{
41    /// Create a new `Proxy` using the default Hyper client.
42    ///
43    /// This is a convenient way to create a proxy with standard configuration.
44    pub fn use_hyper_client(upstreams: U) -> Self {
45        Self::new(upstreams, Default::default())
46    }
47}
48
49impl<C> HyperClient<C> {
50    /// Create a new `HyperClient` with the given `HyperClient`.
51    #[must_use]
52    pub fn new(inner: HyperUtilClient<C, ReqBody>) -> Self {
53        Self { inner }
54    }
55}
56
57impl<C> Client for HyperClient<C>
58where
59    C: Connect + Clone + Send + Sync + 'static,
60{
61    type Error = salvo_core::Error;
62
63    async fn execute(
64        &self,
65        proxied_request: HyperRequest,
66        request_upgraded: Option<OnUpgrade>,
67    ) -> Result<HyperResponse, Self::Error> {
68        let request_upgrade_type =
69            crate::get_upgrade_type(proxied_request.headers()).map(|s| s.to_owned());
70
71        let mut response = self
72            .inner
73            .request(proxied_request)
74            .await
75            .map_err(Error::other)?;
76
77        if response.status() == StatusCode::SWITCHING_PROTOCOLS {
78            let response_upgrade_type = crate::get_upgrade_type(response.headers());
79            if request_upgrade_type == response_upgrade_type.map(|s| s.to_lowercase()) {
80                let response_upgraded = hyper::upgrade::on(&mut response).await?;
81                if let Some(request_upgraded) = request_upgraded {
82                    tokio::spawn(async move {
83                        match request_upgraded.await {
84                            Ok(request_upgraded) => {
85                                let mut request_upgraded = TokioIo::new(request_upgraded);
86                                let mut response_upgraded = TokioIo::new(response_upgraded);
87                                if let Err(e) = copy_bidirectional(
88                                    &mut response_upgraded,
89                                    &mut request_upgraded,
90                                )
91                                .await
92                                {
93                                    tracing::error!(error = ?e, "coping between upgraded connections failed.");
94                                }
95                            }
96                            Err(e) => {
97                                tracing::error!(error = ?e, "upgrade request failed.");
98                            }
99                        }
100                    });
101                } else {
102                    return Err(Error::other("request does not have an upgrade extension."));
103                }
104            } else {
105                return Err(Error::other("upgrade type mismatch"));
106            }
107        }
108        Ok(response.map(ResBody::Hyper))
109    }
110}
111
// Unit tests for `Proxy` backed by `HyperClient`.
#[cfg(test)]
mod tests {
    use salvo_core::prelude::*;
    use salvo_core::test::*;

    use super::*;
    use crate::{Proxy, Upstreams};

    /// Installs the aws-lc-rs rustls crypto provider as the process default.
    ///
    /// The result is deliberately ignored: another test in the same process
    /// may already have installed a provider.
    fn install_crypto_provider() {
        let _ = rustls::crypto::aws_lc_rs::default_provider().install_default();
    }

    #[tokio::test]
    async fn test_upstreams_elect() {
        install_crypto_provider();
        let candidates = vec!["https://www.example.com", "https://www.example2.com"];
        let proxy = Proxy::new(candidates.clone(), HyperClient::default());
        let (request, depot) = (Request::new(), Depot::new());
        let elected = proxy.upstreams().elect(&request, &depot).await.unwrap();
        assert!(candidates.contains(&elected));
    }

    #[tokio::test]
    async fn test_hyper_client() {
        install_crypto_provider();
        let proxy = Proxy::new(vec!["https://salvo.rs"], HyperClient::default());
        let router = Router::new().push(Router::with_path("rust/{**rest}").goal(proxy));

        let body = TestClient::get("http://127.0.0.1:5801/rust/guide/index.html")
            .send(router)
            .await
            .take_string()
            .await
            .unwrap();
        assert!(body.contains("Salvo"));
    }

    #[test]
    fn test_others() {
        install_crypto_provider();
        let mut proxy = Proxy::new(["https://www.bing.com"], HyperClient::default());
        assert_eq!(proxy.upstreams().len(), 1);
        assert_eq!(proxy.upstreams_mut().len(), 1);
    }
}