salvo_proxy/
reqwest_client.rs

1use futures_util::TryStreamExt;
2use hyper::upgrade::OnUpgrade;
3use reqwest::Client as InnerClient;
4use salvo_core::http::{ResBody, StatusCode};
5use salvo_core::rt::tokio::TokioIo;
6use salvo_core::Error;
7use tokio::io::copy_bidirectional;
8
9use crate::{Client, HyperRequest, BoxedError, Proxy, Upstreams, HyperResponse};
10
11/// A [`Client`] implementation based on [`reqwest::Client`].
12/// 
13/// This client provides proxy capabilities using the Reqwest HTTP client.
14/// It supports all features of Reqwest including automatic redirect handling,
15/// connection pooling, and other HTTP client features.
16#[derive(Default, Clone, Debug)]
17pub struct ReqwestClient {
18    inner: InnerClient,
19}
20
21impl<U> Proxy<U, ReqwestClient>
22where
23    U: Upstreams,
24    U::Error: Into<BoxedError>,
25{
26    /// Create a new `Proxy` using the default Reqwest client.
27    /// 
28    /// This is a convenient way to create a proxy with standard configuration.
29    pub fn use_reqwest_client(upstreams: U) -> Self {
30        Self::new(upstreams, ReqwestClient::default())
31    }
32}
33
34impl ReqwestClient {
35    /// Create a new `ReqwestClient` with the given [`reqwest::Client`].
36    #[must_use]
37    pub fn new(inner: InnerClient) -> Self {
38        Self { inner }
39    }
40}
41
42impl Client for ReqwestClient {
43    type Error = salvo_core::Error;
44
45    async fn execute(
46        &self,
47        proxied_request: HyperRequest,
48        request_upgraded: Option<OnUpgrade>,
49    ) -> Result<HyperResponse, Self::Error> {
50        let request_upgrade_type = crate::get_upgrade_type(proxied_request.headers()).map(|s| s.to_owned());
51
52        let proxied_request =
53            proxied_request.map(|s| reqwest::Body::wrap_stream(s.map_ok(|s| s.into_data().unwrap_or_default())));
54        let response = self
55            .inner
56            .execute(proxied_request.try_into().map_err(Error::other)?)
57            .await
58            .map_err(Error::other)?;
59
60        let res_headers = response.headers().clone();
61        let hyper_response = hyper::Response::builder()
62            .status(response.status())
63            .version(response.version());
64
65        let mut hyper_response = if response.status() == StatusCode::SWITCHING_PROTOCOLS {
66            let response_upgrade_type = crate::get_upgrade_type(response.headers());
67
68            if request_upgrade_type == response_upgrade_type.map(|s| s.to_lowercase()) {
69                let mut response_upgraded = response
70                    .upgrade()
71                    .await
72                    .map_err(|e| Error::other(format!("response does not have an upgrade extension. {e}")))?;
73                if let Some(request_upgraded) = request_upgraded {
74                    tokio::spawn(async move {
75                        match request_upgraded.await {
76                            Ok(request_upgraded) => {
77                                let mut request_upgraded = TokioIo::new(request_upgraded);
78                                if let Err(e) = copy_bidirectional(&mut response_upgraded, &mut request_upgraded).await
79                                {
80                                    tracing::error!(error = ?e, "coping between upgraded connections failed");
81                                }
82                            }
83                            Err(e) => {
84                                tracing::error!(error = ?e, "upgrade request failed");
85                            }
86                        }
87                    });
88                } else {
89                    return Err(Error::other("request does not have an upgrade extension"));
90                }
91            } else {
92                return Err(Error::other("upgrade type mismatch"));
93            }
94            hyper_response.body(ResBody::None).map_err(Error::other)?
95        } else {
96            hyper_response
97                .body(ResBody::stream(response.bytes_stream()))
98                .map_err(Error::other)?
99        };
100        *hyper_response.headers_mut() = res_headers;
101        Ok(hyper_response)
102    }
103}
104
// Unit tests for `Proxy` when it is backed by `ReqwestClient`.
#[cfg(test)]
mod tests {
    use salvo_core::prelude::*;
    use salvo_core::test::*;

    use super::*;
    use crate::{Upstreams, Proxy};

    /// The elected upstream must be one of the configured candidates.
    #[tokio::test]
    async fn test_upstreams_elect() {
        let candidates = vec!["https://www.example.com", "https://www.example2.com"];
        let proxy = Proxy::new(candidates.clone(), ReqwestClient::default());
        let req = Request::new();
        let depot = Depot::new();
        let chosen = proxy.upstreams().elect(&req, &depot).await.unwrap();
        assert!(candidates.contains(&chosen));
    }

    /// Routes a request through the proxy and checks the returned page content.
    ///
    /// NOTE(review): the proxy forwards to `https://salvo.rs`, so this test
    /// presumably needs outbound network access.
    #[tokio::test]
    async fn test_reqwest_client() {
        let proxy = Proxy::new(vec!["https://salvo.rs"], ReqwestClient::default());
        let router = Router::new().push(Router::with_path("rust/{**rest}").goal(proxy));

        let page = TestClient::get("http://127.0.0.1:5801/rust/guide/index.html")
            .send(router)
            .await
            .take_string()
            .await
            .unwrap();
        assert!(page.contains("Salvo"));
    }

    /// Both the shared and the mutable upstream accessors expose the single configured upstream.
    #[test]
    fn test_others() {
        let mut proxy = Proxy::new(["https://www.bing.com"], ReqwestClient::default());
        assert_eq!(proxy.upstreams().len(), 1);
        assert_eq!(proxy.upstreams_mut().len(), 1);
    }
}