salvo_proxy/
hyper_client.rs1use hyper::upgrade::OnUpgrade;
2use hyper_rustls::{HttpsConnector, HttpsConnectorBuilder};
3use hyper_util::client::legacy::Client as HyperUtilClient;
4use hyper_util::client::legacy::connect::{Connect, HttpConnector};
5use hyper_util::rt::TokioExecutor;
6use salvo_core::Error;
7use salvo_core::http::{ReqBody, ResBody, StatusCode};
8use salvo_core::rt::tokio::TokioIo;
9use tokio::io::copy_bidirectional;
10
11use crate::{BoxedError, Client, HyperRequest, HyperResponse, Proxy, Upstreams};
12
13#[derive(Clone, Debug)]
18pub struct HyperClient<C> {
19 inner: HyperUtilClient<C, ReqBody>,
20}
21
22impl Default for HyperClient<HttpsConnector<HttpConnector>> {
23 fn default() -> Self {
24 #[cfg(feature = "ring")]
25 let _ = rustls::crypto::ring::default_provider().install_default();
26 let https = HttpsConnectorBuilder::new()
27 .with_native_roots()
28 .expect("no native root CA certificates found")
29 .https_or_http()
30 .enable_all_versions()
31 .build();
32 Self {
33 inner: HyperUtilClient::builder(TokioExecutor::new()).build(https),
34 }
35 }
36}
37
38impl<U> Proxy<U, HyperClient<HttpsConnector<HttpConnector>>>
39where
40 U: Upstreams,
41 U::Error: Into<BoxedError>,
42{
43 pub fn use_hyper_client(upstreams: U) -> Self {
47 Self::new(upstreams, Default::default())
48 }
49}
50
51impl<C> HyperClient<C> {
52 #[must_use]
54 pub fn new(inner: HyperUtilClient<C, ReqBody>) -> Self {
55 Self { inner }
56 }
57}
58
59impl<C> Client for HyperClient<C>
60where
61 C: Connect + Clone + Send + Sync + 'static,
62{
63 type Error = salvo_core::Error;
64
65 async fn execute(
66 &self,
67 proxied_request: HyperRequest,
68 request_upgraded: Option<OnUpgrade>,
69 ) -> Result<HyperResponse, Self::Error> {
70 let request_upgrade_type =
71 crate::get_upgrade_type(proxied_request.headers()).map(|s| s.to_owned());
72
73 let mut response = self
74 .inner
75 .request(proxied_request)
76 .await
77 .map_err(Error::other)?;
78
79 if response.status() == StatusCode::SWITCHING_PROTOCOLS {
80 let response_upgrade_type = crate::get_upgrade_type(response.headers());
81 if request_upgrade_type == response_upgrade_type.map(|s| s.to_lowercase()) {
82 let response_upgraded = hyper::upgrade::on(&mut response).await?;
83 if let Some(request_upgraded) = request_upgraded {
84 tokio::spawn(async move {
85 match request_upgraded.await {
86 Ok(request_upgraded) => {
87 let mut request_upgraded = TokioIo::new(request_upgraded);
88 let mut response_upgraded = TokioIo::new(response_upgraded);
89 if let Err(e) = copy_bidirectional(
90 &mut response_upgraded,
91 &mut request_upgraded,
92 )
93 .await
94 {
95 tracing::error!(error = ?e, "coping between upgraded connections failed.");
96 }
97 }
98 Err(e) => {
99 tracing::error!(error = ?e, "upgrade request failed.");
100 }
101 }
102 });
103 } else {
104 return Err(Error::other("request does not have an upgrade extension."));
105 }
106 } else {
107 return Err(Error::other("upgrade type mismatch"));
108 }
109 }
110 Ok(response.map(ResBody::Hyper))
111 }
112}
113
#[cfg(test)]
mod tests {
    use salvo_core::prelude::*;
    use salvo_core::test::*;

    use super::*;
    use crate::{Proxy, Upstreams};

    /// Installs the aws-lc-rs rustls provider as the process default.
    ///
    /// The result is discarded: installation is expected to fail harmlessly when an
    /// earlier test already installed a provider.
    fn install_crypto_provider() {
        let _ = rustls::crypto::aws_lc_rs::default_provider().install_default();
    }

    #[tokio::test]
    async fn test_upstreams_elect() {
        install_crypto_provider();
        let candidates = vec!["https://www.example.com", "https://www.example2.com"];
        let proxy = Proxy::new(candidates.clone(), HyperClient::default());
        let (req, depot) = (Request::new(), Depot::new());
        let elected = proxy.upstreams().elect(&req, &depot).await.unwrap();
        assert!(candidates.contains(&elected));
    }

    #[tokio::test]
    async fn test_hyper_client() {
        install_crypto_provider();
        let proxy = Proxy::new(vec!["https://salvo.rs"], HyperClient::default());
        let router = Router::new().push(Router::with_path("rust/{**rest}").goal(proxy));

        let body = TestClient::get("http://127.0.0.1:5801/rust/guide/index.html")
            .send(router)
            .await
            .take_string()
            .await
            .unwrap();
        assert!(body.contains("Salvo"));
    }

    #[test]
    fn test_others() {
        install_crypto_provider();
        let mut proxy = Proxy::new(["https://www.bing.com"], HyperClient::default());
        assert_eq!(proxy.upstreams().len(), 1);
        assert_eq!(proxy.upstreams_mut().len(), 1);
    }
}
161}