// simple_hyper_client_rustls/lib.rs
use simple_hyper_client::connector_impl;
use simple_hyper_client::hyper::client::connect::{Connected, Connection};
use simple_hyper_client::Uri;
use simple_hyper_client::{ConnectError, HttpConnection, NetworkConnection, NetworkConnector};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio::net::TcpStream;
use tokio_rustls::{client::TlsStream, TlsConnector};

use std::convert::TryFrom;
use std::error::Error as StdError;
use std::future::Future;
use std::io;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::Duration;

24pub struct HttpsConnector {
29 force_tls: bool,
30 tls: TlsConnector,
31 connect_timeout: Option<Duration>,
32}
33
34impl HttpsConnector {
35 pub fn new(tls: TlsConnector) -> Self {
36 HttpsConnector {
37 tls,
38 force_tls: true,
39 connect_timeout: None,
40 }
41 }
42
43 pub fn connect_timeout(mut self, timeout: Option<Duration>) -> Self {
45 self.connect_timeout = timeout;
46 self
47 }
48
49 pub fn allow_http_scheme(mut self) -> Self {
52 self.force_tls = false;
53 self
54 }
55
56 async fn connect(
57 uri: Uri,
58 tls: TlsConnector,
59 force_tls: bool,
60 connect_timeout: Option<Duration>,
61 ) -> Result<HttpOrHttpsConnection, ConnectError> {
62 let is_https = uri.scheme_str() == Some("https");
63 if !is_https && force_tls {
64 return Err(ConnectError::new("invalid URI: expected `https` scheme"));
65 }
66 let host = connector_impl::get_host(&uri)?.to_owned();
67 let http = connector_impl::connect(uri, true, connect_timeout).await?;
68 if is_https {
69 let server_name = rustls_pki_types::ServerName::try_from(host)
70 .map_err(|err| ConnectError::new("invalid host name").cause(err))?;
71 let tls = tls
72 .connect(server_name, http.into_tcp_stream())
73 .await
74 .map_err(|e| ConnectError::new("TLS error").cause(e))?;
75
76 Ok(HttpOrHttpsConnection::Https(tls))
77 } else {
78 Ok(HttpOrHttpsConnection::Http(http))
79 }
80 }
81}
82
83impl NetworkConnector for HttpsConnector {
84 fn connect(
85 &self,
86 uri: Uri,
87 ) -> Pin<
88 Box<dyn Future<Output = Result<NetworkConnection, Box<dyn StdError + Send + Sync>>> + Send>,
89 > {
90 let tls = self.tls.clone();
91 let force_tls = self.force_tls;
92 let connect_timeout = self.connect_timeout;
93 Box::pin(async move {
94 match HttpsConnector::connect(uri, tls, force_tls, connect_timeout).await {
95 Ok(conn) => Ok(NetworkConnection::new(conn)),
96 Err(e) => Err(Box::new(e) as _),
97 }
98 })
99 }
100}
101
102pub enum HttpOrHttpsConnection {
104 Http(HttpConnection),
105 Https(TlsStream<TcpStream>),
106}
107
108impl Connection for HttpOrHttpsConnection {
109 fn connected(&self) -> Connected {
110 Connected::new()
114 }
115}
116
117impl AsyncRead for HttpOrHttpsConnection {
118 fn poll_read(
119 self: Pin<&mut Self>,
120 cx: &mut Context<'_>,
121 buf: &mut ReadBuf<'_>,
122 ) -> Poll<io::Result<()>> {
123 match Pin::get_mut(self) {
124 HttpOrHttpsConnection::Http(s) => Pin::new(s).poll_read(cx, buf),
125 HttpOrHttpsConnection::Https(s) => Pin::new(s).poll_read(cx, buf),
126 }
127 }
128}
129
130impl AsyncWrite for HttpOrHttpsConnection {
131 fn poll_write(
132 self: Pin<&mut Self>,
133 cx: &mut Context<'_>,
134 buf: &[u8],
135 ) -> Poll<io::Result<usize>> {
136 match Pin::get_mut(self) {
137 HttpOrHttpsConnection::Http(s) => Pin::new(s).poll_write(cx, buf),
138 HttpOrHttpsConnection::Https(s) => Pin::new(s).poll_write(cx, buf),
139 }
140 }
141
142 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
143 match Pin::get_mut(self) {
144 HttpOrHttpsConnection::Http(s) => Pin::new(s).poll_flush(cx),
145 HttpOrHttpsConnection::Https(s) => Pin::new(s).poll_flush(cx),
146 }
147 }
148
149 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
150 match Pin::get_mut(self) {
151 HttpOrHttpsConnection::Http(s) => Pin::new(s).poll_shutdown(cx),
152 HttpOrHttpsConnection::Https(s) => Pin::new(s).poll_shutdown(cx),
153 }
154 }
155}
156
#[cfg(test)]
mod tests {
    use super::*;
    use rustls_pki_types::{pem::PemObject, CertificateDer, PrivateKeyDer};
    use simple_hyper_client::{to_bytes, Client};
    use simple_hyper_client::{StatusCode, Uri};
    use std::{convert::TryFrom, net::SocketAddr, sync::Arc};
    use tokio::{
        io::{AsyncReadExt, AsyncWriteExt},
        net::TcpListener,
        sync::oneshot,
        task::JoinHandle,
    };
    use tokio_rustls::{
        rustls::{ClientConfig, RootCertStore, ServerConfig},
        TlsAcceptor, TlsConnector,
    };

    /// Canned response. The body is exactly `Content-Length` (13) bytes, so no stray
    /// bytes are left on the connection after the client has read the response.
    const RESPONSE_OK: &str = "HTTP/1.1 200 OK\r\nContent-Length: 13\r\n\r\nHello, world!";

    /// Client-side TLS config that trusts only the test CA.
    fn get_tls_connector() -> TlsConnector {
        let test_ca_bytes = include_bytes!("../../test-ca/ca.cert");
        let test_ca_der = CertificateDer::from_pem_slice(test_ca_bytes).unwrap();
        let mut root_cert_store = RootCertStore::empty();
        root_cert_store.add(test_ca_der).unwrap();
        let config = ClientConfig::builder()
            .with_root_certificates(root_cert_store)
            .with_no_client_auth();
        TlsConnector::from(Arc::new(config))
    }

    /// Server-side TLS config presenting the test end-entity certificate chain.
    fn get_tls_acceptor() -> TlsAcceptor {
        let test_end_cert_bytes = include_bytes!("../../test-ca/end.fullchain");
        let test_end_cert = CertificateDer::pem_slice_iter(test_end_cert_bytes)
            .collect::<Result<Vec<_>, _>>()
            .unwrap();
        let test_end_key_bytes = include_bytes!("../../test-ca/end.key");
        let test_end_key = PrivateKeyDer::from_pem_slice(test_end_key_bytes).unwrap();
        let config = ServerConfig::builder()
            .with_no_client_auth()
            .with_single_cert(test_end_cert, test_end_key)
            .unwrap();
        TlsAcceptor::from(Arc::new(config))
    }

    /// Starts a TLS server on an ephemeral localhost port.
    ///
    /// For each connection the server reads once and then answers with `resp`. It runs
    /// until `shutdown_rx` fires.
    async fn start_tls_server(
        resp: &'static str,
        mut shutdown_rx: oneshot::Receiver<()>,
    ) -> (JoinHandle<()>, SocketAddr) {
        let listener = TcpListener::bind("localhost:0")
            .await
            .expect("Failed to bind to localhost");
        let local_addr = listener.local_addr().unwrap();
        let acceptor = get_tls_acceptor();

        let server_handler = tokio::spawn(async move {
            println!("Started TLS server at {}", local_addr);
            loop {
                tokio::select! {
                    Ok((stream, peer_addr)) = listener.accept() => {
                        let acceptor = acceptor.clone();
                        tokio::spawn(async move {
                            match acceptor.accept(stream).await {
                                Ok(mut tls_stream) => {
                                    println!("TLS connection established with {}", peer_addr);
                                    // `read` fills the slice it is handed, so the buffer
                                    // needs a non-zero *length*. A `Vec::with_capacity`
                                    // has length 0 and `read` would return 0 immediately.
                                    let mut input = vec![0u8; 1024];
                                    let n = tls_stream.read(&mut input).await.expect("failed to read");
                                    println!("Received: {}", String::from_utf8_lossy(&input[..n]));
                                    tls_stream.write_all(resp.as_bytes()).await.expect("failed to write");
                                }
                                Err(err) => eprintln!("TLS handshake failed: {}", err),
                            }
                        });
                    }

                    _ = &mut shutdown_rx => {
                        // The test is done with the server; stop accepting connections.
                        println!("Shutting down TLS server at {} ...", local_addr);
                        break;
                    }
                }
            }
        });
        (server_handler, local_addr)
    }

    #[tokio::test]
    async fn test_connect_invalid_scheme() {
        let tls_connector = get_tls_connector();
        let uri = Uri::try_from("http://example.com").unwrap();

        let result = HttpsConnector::connect(uri, tls_connector, true, None).await;
        match result {
            Err(err) => assert!(
                err.to_string()
                    .contains("invalid URI: expected `https` scheme"),
                "{}",
                err
            ),
            Ok(_) => panic!("Expecting error for invalid URI scheme"),
        }
    }

    #[tokio::test]
    async fn test_connect() {
        let _ = env_logger::try_init();
        let tls_connector = get_tls_connector();

        let (tx, rx) = oneshot::channel::<()>();
        let (server_handler, addr) = start_tls_server(RESPONSE_OK, rx).await;

        let uri = Uri::try_from(format!("https://localhost:{}", addr.port())).unwrap();
        let connector = HttpsConnector::new(tls_connector);
        let client = Client::with_connector(connector);
        let response = client
            .post(uri)
            .unwrap()
            .body(r#"plain text"#)
            .send()
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::OK);
        let body = to_bytes(response).await.unwrap();
        assert_eq!(body, "Hello, world!".as_bytes());
        tx.send(()).unwrap();
        server_handler.await.unwrap();
    }

    /// With `allow_http_scheme`, an `http` URI is served over plain TCP.
    #[tokio::test]
    async fn test_connect_http_when_allowed() {
        let listener = TcpListener::bind("localhost:0")
            .await
            .expect("Failed to bind to localhost");
        let addr = listener.local_addr().unwrap();
        let server = tokio::spawn(async move {
            let (mut stream, _) = listener.accept().await.expect("failed to accept");
            let mut input = vec![0u8; 1024];
            let _n = stream.read(&mut input).await.expect("failed to read");
            stream
                .write_all(RESPONSE_OK.as_bytes())
                .await
                .expect("failed to write");
        });

        let uri = Uri::try_from(format!("http://localhost:{}", addr.port())).unwrap();
        let connector = HttpsConnector::new(get_tls_connector()).allow_http_scheme();
        let client = Client::with_connector(connector);
        let response = client
            .post(uri)
            .unwrap()
            .body(r#"plain text"#)
            .send()
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::OK);
        let body = to_bytes(response).await.unwrap();
        assert_eq!(body, "Hello, world!".as_bytes());
        server.await.unwrap();
    }
}