salvo_proxy/
hyper_client.rs

1use hyper::upgrade::OnUpgrade;
2use hyper_rustls::{HttpsConnector, HttpsConnectorBuilder};
3use hyper_util::client::legacy::Client as HyperUtilClient;
4use hyper_util::client::legacy::connect::{Connect, HttpConnector};
5use hyper_util::rt::TokioExecutor;
6use salvo_core::Error;
7use salvo_core::http::{ReqBody, ResBody, StatusCode};
8use salvo_core::rt::tokio::TokioIo;
9use tokio::io::copy_bidirectional;
10
11use crate::{BoxedError, Client, HyperRequest, HyperResponse, Proxy, Upstreams};
12
13/// A [`Client`] implementation based on [`hyper_util::client::legacy::Client`].
14///
15/// This client provides proxy capabilities using the Hyper HTTP client library.
16/// It's lightweight and tightly integrated with the Tokio runtime.
17#[derive(Clone, Debug)]
18pub struct HyperClient<C> {
19    inner: HyperUtilClient<C, ReqBody>,
20}
21
22impl Default for HyperClient<HttpsConnector<HttpConnector>> {
23    fn default() -> Self {
24        #[cfg(feature = "ring")]
25        let _ = rustls::crypto::ring::default_provider().install_default();
26        let https = HttpsConnectorBuilder::new()
27            .with_native_roots()
28            .expect("no native root CA certificates found")
29            .https_or_http()
30            .enable_all_versions()
31            .build();
32        Self {
33            inner: HyperUtilClient::builder(TokioExecutor::new()).build(https),
34        }
35    }
36}
37
38impl<U> Proxy<U, HyperClient<HttpsConnector<HttpConnector>>>
39where
40    U: Upstreams,
41    U::Error: Into<BoxedError>,
42{
43    /// Create a new `Proxy` using the default Hyper client.
44    ///
45    /// This is a convenient way to create a proxy with standard configuration.
46    pub fn use_hyper_client(upstreams: U) -> Self {
47        Self::new(upstreams, Default::default())
48    }
49}
50
51impl<C> HyperClient<C> {
52    /// Create a new `HyperClient` with the given `HyperClient`.
53    #[must_use]
54    pub fn new(inner: HyperUtilClient<C, ReqBody>) -> Self {
55        Self { inner }
56    }
57}
58
59impl<C> Client for HyperClient<C>
60where
61    C: Connect + Clone + Send + Sync + 'static,
62{
63    type Error = salvo_core::Error;
64
65    async fn execute(
66        &self,
67        proxied_request: HyperRequest,
68        request_upgraded: Option<OnUpgrade>,
69    ) -> Result<HyperResponse, Self::Error> {
70        let request_upgrade_type =
71            crate::get_upgrade_type(proxied_request.headers()).map(|s| s.to_owned());
72
73        let mut response = self
74            .inner
75            .request(proxied_request)
76            .await
77            .map_err(Error::other)?;
78
79        if response.status() == StatusCode::SWITCHING_PROTOCOLS {
80            let response_upgrade_type = crate::get_upgrade_type(response.headers());
81            if request_upgrade_type == response_upgrade_type.map(|s| s.to_lowercase()) {
82                let response_upgraded = hyper::upgrade::on(&mut response).await?;
83                if let Some(request_upgraded) = request_upgraded {
84                    tokio::spawn(async move {
85                        match request_upgraded.await {
86                            Ok(request_upgraded) => {
87                                let mut request_upgraded = TokioIo::new(request_upgraded);
88                                let mut response_upgraded = TokioIo::new(response_upgraded);
89                                if let Err(e) = copy_bidirectional(
90                                    &mut response_upgraded,
91                                    &mut request_upgraded,
92                                )
93                                .await
94                                {
95                                    tracing::error!(error = ?e, "coping between upgraded connections failed.");
96                                }
97                            }
98                            Err(e) => {
99                                tracing::error!(error = ?e, "upgrade request failed.");
100                            }
101                        }
102                    });
103                } else {
104                    return Err(Error::other("request does not have an upgrade extension."));
105                }
106            } else {
107                return Err(Error::other("upgrade type mismatch"));
108            }
109        }
110        Ok(response.map(ResBody::Hyper))
111    }
112}
113
// Unit tests for `HyperClient` used as the client of a `Proxy`.
#[cfg(test)]
mod tests {
    use salvo_core::prelude::*;
    use salvo_core::test::*;

    use super::*;
    use crate::{Proxy, Upstreams};

    /// Offers the aws-lc-rs provider as the process-wide rustls default.
    /// The result is ignored on purpose, so calling this from every test is
    /// harmless.
    fn install_crypto_provider() {
        let _ = rustls::crypto::aws_lc_rs::default_provider().install_default();
    }

    /// The elected upstream must be one of the configured candidates.
    #[tokio::test]
    async fn test_upstreams_elect() {
        install_crypto_provider();
        let candidates = vec!["https://www.example.com", "https://www.example2.com"];
        let proxy = Proxy::new(candidates.clone(), HyperClient::default());
        let req = Request::new();
        let depot = Depot::new();
        let chosen = proxy.upstreams().elect(&req, &depot).await.unwrap();
        assert!(candidates.contains(&chosen));
    }

    /// Proxies a page from https://salvo.rs and checks the returned body.
    #[tokio::test]
    async fn test_hyper_client() {
        install_crypto_provider();
        let proxy = Proxy::new(vec!["https://salvo.rs"], HyperClient::default());
        let router = Router::new().push(Router::with_path("rust/{**rest}").goal(proxy));

        let body = TestClient::get("http://127.0.0.1:5801/rust/guide/index.html")
            .send(router)
            .await
            .take_string()
            .await
            .unwrap();
        assert!(body.contains("Salvo"));
    }

    /// Both upstream accessors expose the single configured upstream.
    #[test]
    fn test_others() {
        install_crypto_provider();
        let mut proxy = Proxy::new(["https://www.bing.com"], HyperClient::default());
        assert_eq!(proxy.upstreams().len(), 1);
        assert_eq!(proxy.upstreams_mut().len(), 1);
    }
}