salvo_proxy/
hyper_client.rs1use hyper::upgrade::OnUpgrade;
2use hyper_rustls::{HttpsConnector, HttpsConnectorBuilder};
3use hyper_util::client::legacy::Client as HyperUtilClient;
4use hyper_util::client::legacy::connect::{Connect, HttpConnector};
5use hyper_util::rt::TokioExecutor;
6use salvo_core::Error;
7use salvo_core::http::{ReqBody, ResBody, StatusCode};
8use salvo_core::rt::tokio::TokioIo;
9use tokio::io::copy_bidirectional;
10
11use crate::{BoxedError, Client, HyperRequest, HyperResponse, Proxy, Upstreams};
12
13#[derive(Clone, Debug)]
18pub struct HyperClient<C> {
19 inner: HyperUtilClient<C, ReqBody>,
20}
21
22impl Default for HyperClient<HttpsConnector<HttpConnector>> {
23 fn default() -> Self {
24 let https = HttpsConnectorBuilder::new()
25 .with_native_roots()
26 .expect("no native root CA certificates found")
27 .https_or_http()
28 .enable_all_versions()
29 .build();
30 Self {
31 inner: HyperUtilClient::builder(TokioExecutor::new()).build(https),
32 }
33 }
34}
35
36impl<U> Proxy<U, HyperClient<HttpsConnector<HttpConnector>>>
37where
38 U: Upstreams,
39 U::Error: Into<BoxedError>,
40{
41 pub fn use_hyper_client(upstreams: U) -> Self {
45 Self::new(upstreams, Default::default())
46 }
47}
48
49impl<C> HyperClient<C> {
50 #[must_use]
52 pub fn new(inner: HyperUtilClient<C, ReqBody>) -> Self {
53 Self { inner }
54 }
55}
56
57impl<C> Client for HyperClient<C>
58where
59 C: Connect + Clone + Send + Sync + 'static,
60{
61 type Error = salvo_core::Error;
62
63 async fn execute(
64 &self,
65 proxied_request: HyperRequest,
66 request_upgraded: Option<OnUpgrade>,
67 ) -> Result<HyperResponse, Self::Error> {
68 let request_upgrade_type =
69 crate::get_upgrade_type(proxied_request.headers()).map(|s| s.to_owned());
70
71 let mut response = self
72 .inner
73 .request(proxied_request)
74 .await
75 .map_err(Error::other)?;
76
77 if response.status() == StatusCode::SWITCHING_PROTOCOLS {
78 let response_upgrade_type = crate::get_upgrade_type(response.headers());
79 if request_upgrade_type == response_upgrade_type.map(|s| s.to_lowercase()) {
80 let response_upgraded = hyper::upgrade::on(&mut response).await?;
81 if let Some(request_upgraded) = request_upgraded {
82 tokio::spawn(async move {
83 match request_upgraded.await {
84 Ok(request_upgraded) => {
85 let mut request_upgraded = TokioIo::new(request_upgraded);
86 let mut response_upgraded = TokioIo::new(response_upgraded);
87 if let Err(e) = copy_bidirectional(
88 &mut response_upgraded,
89 &mut request_upgraded,
90 )
91 .await
92 {
93 tracing::error!(error = ?e, "coping between upgraded connections failed.");
94 }
95 }
96 Err(e) => {
97 tracing::error!(error = ?e, "upgrade request failed.");
98 }
99 }
100 });
101 } else {
102 return Err(Error::other("request does not have an upgrade extension."));
103 }
104 } else {
105 return Err(Error::other("upgrade type mismatch"));
106 }
107 }
108 Ok(response.map(ResBody::Hyper))
109 }
110}
111
#[cfg(test)]
mod tests {
    use salvo_core::prelude::*;
    use salvo_core::test::*;

    use super::*;
    use crate::{Proxy, Upstreams};

    /// Installs the aws-lc-rs rustls crypto provider as the process default.
    ///
    /// `install_default` returns an error if a provider is already installed
    /// (e.g. by another test in the same process), so the result is ignored.
    fn install_crypto_provider() {
        let _ = rustls::crypto::aws_lc_rs::default_provider().install_default();
    }

    #[tokio::test]
    async fn test_upstreams_elect() {
        install_crypto_provider();
        let candidates = vec!["https://www.example.com", "https://www.example2.com"];
        let proxy = Proxy::new(candidates.clone(), HyperClient::default());
        let (request, depot) = (Request::new(), Depot::new());
        let elected = proxy.upstreams().elect(&request, &depot).await.unwrap();
        assert!(candidates.contains(&elected));
    }

    // NOTE(review): this test performs a real network request to https://salvo.rs.
    #[tokio::test]
    async fn test_hyper_client() {
        install_crypto_provider();
        let proxy = Proxy::new(vec!["https://salvo.rs"], HyperClient::default());
        let router = Router::new().push(Router::with_path("rust/{**rest}").goal(proxy));

        let body = TestClient::get("http://127.0.0.1:5801/rust/guide/index.html")
            .send(router)
            .await
            .take_string()
            .await
            .unwrap();
        assert!(body.contains("Salvo"));
    }

    #[test]
    fn test_others() {
        install_crypto_provider();
        let mut proxy = Proxy::new(["https://www.bing.com"], HyperClient::default());
        assert_eq!(proxy.upstreams().len(), 1);
        assert_eq!(proxy.upstreams_mut().len(), 1);
    }
}