salvo_proxy/
reqwest_client.rs1use futures_util::TryStreamExt;
2use hyper::upgrade::OnUpgrade;
3use reqwest::Client as InnerClient;
4use salvo_core::http::{ResBody, StatusCode};
5use salvo_core::rt::tokio::TokioIo;
6use salvo_core::Error;
7use tokio::io::copy_bidirectional;
8
9use crate::{Client, HyperRequest, BoxedError, Proxy, Upstreams, HyperResponse};
10
11#[derive(Default, Clone, Debug)]
17pub struct ReqwestClient {
18 inner: InnerClient,
19}
20
21impl<U> Proxy<U, ReqwestClient>
22where
23 U: Upstreams,
24 U::Error: Into<BoxedError>,
25{
26 pub fn use_reqwest_client(upstreams: U) -> Self {
30 Self::new(upstreams, ReqwestClient::default())
31 }
32}
33
34impl ReqwestClient {
35 #[must_use]
37 pub fn new(inner: InnerClient) -> Self {
38 Self { inner }
39 }
40}
41
42impl Client for ReqwestClient {
43 type Error = salvo_core::Error;
44
45 async fn execute(
46 &self,
47 proxied_request: HyperRequest,
48 request_upgraded: Option<OnUpgrade>,
49 ) -> Result<HyperResponse, Self::Error> {
50 let request_upgrade_type = crate::get_upgrade_type(proxied_request.headers()).map(|s| s.to_owned());
51
52 let proxied_request =
53 proxied_request.map(|s| reqwest::Body::wrap_stream(s.map_ok(|s| s.into_data().unwrap_or_default())));
54 let response = self
55 .inner
56 .execute(proxied_request.try_into().map_err(Error::other)?)
57 .await
58 .map_err(Error::other)?;
59
60 let res_headers = response.headers().clone();
61 let hyper_response = hyper::Response::builder()
62 .status(response.status())
63 .version(response.version());
64
65 let mut hyper_response = if response.status() == StatusCode::SWITCHING_PROTOCOLS {
66 let response_upgrade_type = crate::get_upgrade_type(response.headers());
67
68 if request_upgrade_type == response_upgrade_type.map(|s| s.to_lowercase()) {
69 let mut response_upgraded = response
70 .upgrade()
71 .await
72 .map_err(|e| Error::other(format!("response does not have an upgrade extension. {e}")))?;
73 if let Some(request_upgraded) = request_upgraded {
74 tokio::spawn(async move {
75 match request_upgraded.await {
76 Ok(request_upgraded) => {
77 let mut request_upgraded = TokioIo::new(request_upgraded);
78 if let Err(e) = copy_bidirectional(&mut response_upgraded, &mut request_upgraded).await
79 {
80 tracing::error!(error = ?e, "coping between upgraded connections failed");
81 }
82 }
83 Err(e) => {
84 tracing::error!(error = ?e, "upgrade request failed");
85 }
86 }
87 });
88 } else {
89 return Err(Error::other("request does not have an upgrade extension"));
90 }
91 } else {
92 return Err(Error::other("upgrade type mismatch"));
93 }
94 hyper_response.body(ResBody::None).map_err(Error::other)?
95 } else {
96 hyper_response
97 .body(ResBody::stream(response.bytes_stream()))
98 .map_err(Error::other)?
99 };
100 *hyper_response.headers_mut() = res_headers;
101 Ok(hyper_response)
102 }
103}
104
#[cfg(test)]
mod tests {
    use salvo_core::prelude::*;
    use salvo_core::test::*;

    use super::*;
    use crate::{Proxy, Upstreams};

    /// Electing an upstream must yield one of the configured addresses.
    #[tokio::test]
    async fn test_upstreams_elect() {
        let candidates = vec!["https://www.example.com", "https://www.example2.com"];
        let proxy = Proxy::new(candidates.clone(), ReqwestClient::default());
        let chosen = proxy.upstreams().elect().await.unwrap();
        assert!(candidates.contains(&chosen));
    }

    /// End-to-end proxying through reqwest.
    ///
    /// NOTE: this forwards to https://salvo.rs, so it needs network access.
    #[tokio::test]
    async fn test_reqwest_client() {
        let proxy = Proxy::new(vec!["https://salvo.rs"], ReqwestClient::default());
        let router = Router::new().push(Router::with_path("rust/{**rest}").goal(proxy));

        let body = TestClient::get("http://127.0.0.1:5801/rust/guide/index.html")
            .send(router)
            .await
            .take_string()
            .await
            .unwrap();
        assert!(body.contains("Salvo"));
    }

    /// Both the shared and the mutable accessor see the single configured upstream.
    #[test]
    fn test_others() {
        let mut proxy = Proxy::new(["https://www.bing.com"], ReqwestClient::default());
        assert_eq!(proxy.upstreams().len(), 1);
        assert_eq!(proxy.upstreams_mut().len(), 1);
    }
}