simple_request/
lib.rs

1#![cfg_attr(docsrs, feature(doc_auto_cfg))]
2#![doc = include_str!("../README.md")]
3
4use std::sync::Arc;
5
6use tokio::sync::Mutex;
7
8use tower_service::Service as TowerService;
9#[cfg(feature = "tls")]
10use hyper_rustls::{HttpsConnectorBuilder, HttpsConnector};
11use hyper::{Uri, header::HeaderValue, body::Bytes, client::conn::http1::SendRequest};
12use hyper_util::{
13  rt::tokio::TokioExecutor,
14  client::legacy::{Client as HyperClient, connect::HttpConnector},
15};
16pub use hyper;
17
18mod request;
19pub use request::*;
20
21mod response;
22pub use response::*;
23
24#[derive(Debug)]
25pub enum Error {
26  InvalidUri,
27  MissingHost,
28  InconsistentHost,
29  ConnectionError(Box<dyn Send + Sync + std::error::Error>),
30  Hyper(hyper::Error),
31  HyperUtil(hyper_util::client::legacy::Error),
32}
33
34#[cfg(not(feature = "tls"))]
35type Connector = HttpConnector;
36#[cfg(feature = "tls")]
37type Connector = HttpsConnector<HttpConnector>;
38
39#[derive(Clone, Debug)]
40enum Connection {
41  ConnectionPool(HyperClient<Connector, Full<Bytes>>),
42  Connection {
43    connector: Connector,
44    host: Uri,
45    connection: Arc<Mutex<Option<SendRequest<Full<Bytes>>>>>,
46  },
47}
48
49#[derive(Clone, Debug)]
50pub struct Client {
51  connection: Connection,
52}
53
54impl Client {
55  #[allow(clippy::unnecessary_wraps)]
56  fn connector() -> Result<Connector, Error> {
57    let mut res = HttpConnector::new();
58    res.set_keepalive(Some(core::time::Duration::from_secs(60)));
59    res.set_nodelay(true);
60    res.set_reuse_address(true);
61
62    #[cfg(feature = "tls")]
63    res.enforce_http(false);
64    #[cfg(feature = "tls")]
65    let https = HttpsConnectorBuilder::new().with_native_roots();
66    #[cfg(all(feature = "tls", not(feature = "webpki-roots")))]
67    let https = https.map_err(|e| {
68      Error::ConnectionError(
69        format!("couldn't load system's SSL root certificates and webpki-roots unavilable: {e:?}")
70          .into(),
71      )
72    })?;
73    // Fallback to `webpki-roots` if present
74    #[cfg(all(feature = "tls", feature = "webpki-roots"))]
75    let https = https.unwrap_or(HttpsConnectorBuilder::new().with_webpki_roots());
76    #[cfg(feature = "tls")]
77    let res = https.https_or_http().enable_http1().wrap_connector(res);
78
79    Ok(res)
80  }
81
82  pub fn with_connection_pool() -> Result<Client, Error> {
83    Ok(Client {
84      connection: Connection::ConnectionPool(
85        HyperClient::builder(TokioExecutor::new())
86          .pool_idle_timeout(core::time::Duration::from_secs(60))
87          .build(Self::connector()?),
88      ),
89    })
90  }
91
92  pub fn without_connection_pool(host: &str) -> Result<Client, Error> {
93    Ok(Client {
94      connection: Connection::Connection {
95        connector: Self::connector()?,
96        host: {
97          let uri: Uri = host.parse().map_err(|_| Error::InvalidUri)?;
98          if uri.host().is_none() {
99            Err(Error::MissingHost)?;
100          };
101          uri
102        },
103        connection: Arc::new(Mutex::new(None)),
104      },
105    })
106  }
107
108  pub async fn request<R: Into<Request>>(&self, request: R) -> Result<Response<'_>, Error> {
109    let request: Request = request.into();
110    let Request { mut request, response_size_limit } = request;
111    if let Some(header_host) = request.headers().get(hyper::header::HOST) {
112      match &self.connection {
113        Connection::ConnectionPool(_) => {}
114        Connection::Connection { host, .. } => {
115          if header_host.to_str().map_err(|_| Error::InvalidUri)? != host.host().unwrap() {
116            Err(Error::InconsistentHost)?;
117          }
118        }
119      }
120    } else {
121      let host = match &self.connection {
122        Connection::ConnectionPool(_) => {
123          request.uri().host().ok_or(Error::MissingHost)?.to_string()
124        }
125        Connection::Connection { host, .. } => {
126          let host_str = host.host().unwrap();
127          if let Some(uri_host) = request.uri().host() {
128            if host_str != uri_host {
129              Err(Error::InconsistentHost)?;
130            }
131          }
132          host_str.to_string()
133        }
134      };
135      request
136        .headers_mut()
137        .insert(hyper::header::HOST, HeaderValue::from_str(&host).map_err(|_| Error::InvalidUri)?);
138    }
139
140    let response = match &self.connection {
141      Connection::ConnectionPool(client) => {
142        client.request(request).await.map_err(Error::HyperUtil)?
143      }
144      Connection::Connection { connector, host, connection } => {
145        let mut connection_lock = connection.lock().await;
146
147        // If there's not a connection...
148        if connection_lock.is_none() {
149          let call_res = connector.clone().call(host.clone()).await;
150          #[cfg(not(feature = "tls"))]
151          let call_res = call_res.map_err(|e| Error::ConnectionError(format!("{e:?}").into()));
152          #[cfg(feature = "tls")]
153          let call_res = call_res.map_err(Error::ConnectionError);
154          let (requester, connection) =
155            hyper::client::conn::http1::handshake(call_res?).await.map_err(Error::Hyper)?;
156          // This will die when we drop the requester, so we don't need to track an AbortHandle
157          // for it
158          tokio::spawn(connection);
159          *connection_lock = Some(requester);
160        }
161
162        let connection = connection_lock.as_mut().expect("lock over the connection was poisoned");
163        let mut err = connection.ready().await.err();
164        if err.is_none() {
165          // Send the request
166          let response = connection.send_request(request).await;
167          if let Ok(response) = response {
168            return Ok(Response { response, size_limit: response_size_limit, client: self });
169          }
170          err = response.err();
171        }
172        // Since this connection has been put into an error state, drop it
173        *connection_lock = None;
174        Err(Error::Hyper(err.expect("only here if `err` is some yet no error")))?
175      }
176    };
177
178    Ok(Response { response, size_limit: response_size_limit, client: self })
179  }
180}