1#![cfg_attr(docsrs, feature(doc_auto_cfg))]
2#![doc = include_str!("../README.md")]
3
4use std::sync::Arc;
5
6use tokio::sync::Mutex;
7
8use tower_service::Service as TowerService;
9#[cfg(feature = "tls")]
10use hyper_rustls::{HttpsConnectorBuilder, HttpsConnector};
11use hyper::{Uri, header::HeaderValue, body::Bytes, client::conn::http1::SendRequest};
12use hyper_util::{
13 rt::tokio::TokioExecutor,
14 client::legacy::{Client as HyperClient, connect::HttpConnector},
15};
16pub use hyper;
17
18mod request;
19pub use request::*;
20
21mod response;
22pub use response::*;
23
24#[derive(Debug)]
25pub enum Error {
26 InvalidUri,
27 MissingHost,
28 InconsistentHost,
29 ConnectionError(Box<dyn Send + Sync + std::error::Error>),
30 Hyper(hyper::Error),
31 HyperUtil(hyper_util::client::legacy::Error),
32}
33
34#[cfg(not(feature = "tls"))]
35type Connector = HttpConnector;
36#[cfg(feature = "tls")]
37type Connector = HttpsConnector<HttpConnector>;
38
39#[derive(Clone, Debug)]
40enum Connection {
41 ConnectionPool(HyperClient<Connector, Full<Bytes>>),
42 Connection {
43 connector: Connector,
44 host: Uri,
45 connection: Arc<Mutex<Option<SendRequest<Full<Bytes>>>>>,
46 },
47}
48
49#[derive(Clone, Debug)]
50pub struct Client {
51 connection: Connection,
52}
53
54impl Client {
55 #[allow(clippy::unnecessary_wraps)]
56 fn connector() -> Result<Connector, Error> {
57 let mut res = HttpConnector::new();
58 res.set_keepalive(Some(core::time::Duration::from_secs(60)));
59 res.set_nodelay(true);
60 res.set_reuse_address(true);
61
62 #[cfg(feature = "tls")]
63 res.enforce_http(false);
64 #[cfg(feature = "tls")]
65 let https = HttpsConnectorBuilder::new().with_native_roots();
66 #[cfg(all(feature = "tls", not(feature = "webpki-roots")))]
67 let https = https.map_err(|e| {
68 Error::ConnectionError(
69 format!("couldn't load system's SSL root certificates and webpki-roots unavilable: {e:?}")
70 .into(),
71 )
72 })?;
73 #[cfg(all(feature = "tls", feature = "webpki-roots"))]
75 let https = https.unwrap_or(HttpsConnectorBuilder::new().with_webpki_roots());
76 #[cfg(feature = "tls")]
77 let res = https.https_or_http().enable_http1().wrap_connector(res);
78
79 Ok(res)
80 }
81
82 pub fn with_connection_pool() -> Result<Client, Error> {
83 Ok(Client {
84 connection: Connection::ConnectionPool(
85 HyperClient::builder(TokioExecutor::new())
86 .pool_idle_timeout(core::time::Duration::from_secs(60))
87 .build(Self::connector()?),
88 ),
89 })
90 }
91
92 pub fn without_connection_pool(host: &str) -> Result<Client, Error> {
93 Ok(Client {
94 connection: Connection::Connection {
95 connector: Self::connector()?,
96 host: {
97 let uri: Uri = host.parse().map_err(|_| Error::InvalidUri)?;
98 if uri.host().is_none() {
99 Err(Error::MissingHost)?;
100 };
101 uri
102 },
103 connection: Arc::new(Mutex::new(None)),
104 },
105 })
106 }
107
108 pub async fn request<R: Into<Request>>(&self, request: R) -> Result<Response<'_>, Error> {
109 let request: Request = request.into();
110 let Request { mut request, response_size_limit } = request;
111 if let Some(header_host) = request.headers().get(hyper::header::HOST) {
112 match &self.connection {
113 Connection::ConnectionPool(_) => {}
114 Connection::Connection { host, .. } => {
115 if header_host.to_str().map_err(|_| Error::InvalidUri)? != host.host().unwrap() {
116 Err(Error::InconsistentHost)?;
117 }
118 }
119 }
120 } else {
121 let host = match &self.connection {
122 Connection::ConnectionPool(_) => {
123 request.uri().host().ok_or(Error::MissingHost)?.to_string()
124 }
125 Connection::Connection { host, .. } => {
126 let host_str = host.host().unwrap();
127 if let Some(uri_host) = request.uri().host() {
128 if host_str != uri_host {
129 Err(Error::InconsistentHost)?;
130 }
131 }
132 host_str.to_string()
133 }
134 };
135 request
136 .headers_mut()
137 .insert(hyper::header::HOST, HeaderValue::from_str(&host).map_err(|_| Error::InvalidUri)?);
138 }
139
140 let response = match &self.connection {
141 Connection::ConnectionPool(client) => {
142 client.request(request).await.map_err(Error::HyperUtil)?
143 }
144 Connection::Connection { connector, host, connection } => {
145 let mut connection_lock = connection.lock().await;
146
147 if connection_lock.is_none() {
149 let call_res = connector.clone().call(host.clone()).await;
150 #[cfg(not(feature = "tls"))]
151 let call_res = call_res.map_err(|e| Error::ConnectionError(format!("{e:?}").into()));
152 #[cfg(feature = "tls")]
153 let call_res = call_res.map_err(Error::ConnectionError);
154 let (requester, connection) =
155 hyper::client::conn::http1::handshake(call_res?).await.map_err(Error::Hyper)?;
156 tokio::spawn(connection);
159 *connection_lock = Some(requester);
160 }
161
162 let connection = connection_lock.as_mut().expect("lock over the connection was poisoned");
163 let mut err = connection.ready().await.err();
164 if err.is_none() {
165 let response = connection.send_request(request).await;
167 if let Ok(response) = response {
168 return Ok(Response { response, size_limit: response_size_limit, client: self });
169 }
170 err = response.err();
171 }
172 *connection_lock = None;
174 Err(Error::Hyper(err.expect("only here if `err` is some yet no error")))?
175 }
176 };
177
178 Ok(Response { response, size_limit: response_size_limit, client: self })
179 }
180}