simple_hyper_client_rustls/
lib.rs

1/* Copyright (c) Fortanix, Inc.
2 *
3 * This Source Code Form is subject to the terms of the Mozilla Public
4 * License, v. 2.0. If a copy of the MPL was not distributed with this
5 * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
6
7use simple_hyper_client::{ConnectError, HttpConnection, NetworkConnection, NetworkConnector};
8
9use simple_hyper_client::connector_impl;
10use simple_hyper_client::hyper::client::connect::{Connected, Connection};
11use simple_hyper_client::Uri;
12use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
13use tokio::net::TcpStream;
14use tokio_rustls::{client::TlsStream, TlsConnector};
15
16use std::convert::TryFrom;
17use std::error::Error as StdError;
18use std::future::Future;
19use std::io;
20use std::pin::Pin;
21use std::task::{Context, Poll};
22use std::time::Duration;
23
24/// An HTTPS connector using tokio-rustls.
25///
26/// TLS use is enforced by default. To allow plain `http` URIs call
27/// [`fn allow_http_scheme()`].
28pub struct HttpsConnector {
29    force_tls: bool,
30    tls: TlsConnector,
31    connect_timeout: Option<Duration>,
32}
33
34impl HttpsConnector {
35    pub fn new(tls: TlsConnector) -> Self {
36        HttpsConnector {
37            tls,
38            force_tls: true,
39            connect_timeout: None,
40        }
41    }
42
43    /// Set the connect timeout. Default is None.
44    pub fn connect_timeout(mut self, timeout: Option<Duration>) -> Self {
45        self.connect_timeout = timeout;
46        self
47    }
48
49    /// If called, the connector will allow URIs with the `http` scheme.
50    /// Otherwise only URIs with the `https` scheme are allowed.
51    pub fn allow_http_scheme(mut self) -> Self {
52        self.force_tls = false;
53        self
54    }
55
56    async fn connect(
57        uri: Uri,
58        tls: TlsConnector,
59        force_tls: bool,
60        connect_timeout: Option<Duration>,
61    ) -> Result<HttpOrHttpsConnection, ConnectError> {
62        let is_https = uri.scheme_str() == Some("https");
63        if !is_https && force_tls {
64            return Err(ConnectError::new("invalid URI: expected `https` scheme"));
65        }
66        let host = connector_impl::get_host(&uri)?.to_owned();
67        let http = connector_impl::connect(uri, true, connect_timeout).await?;
68        if is_https {
69            let server_name = rustls_pki_types::ServerName::try_from(host)
70                .map_err(|err| ConnectError::new("invalid host name").cause(err))?;
71            let tls = tls
72                .connect(server_name, http.into_tcp_stream())
73                .await
74                .map_err(|e| ConnectError::new("TLS error").cause(e))?;
75
76            Ok(HttpOrHttpsConnection::Https(tls))
77        } else {
78            Ok(HttpOrHttpsConnection::Http(http))
79        }
80    }
81}
82
83impl NetworkConnector for HttpsConnector {
84    fn connect(
85        &self,
86        uri: Uri,
87    ) -> Pin<
88        Box<dyn Future<Output = Result<NetworkConnection, Box<dyn StdError + Send + Sync>>> + Send>,
89    > {
90        let tls = self.tls.clone();
91        let force_tls = self.force_tls;
92        let connect_timeout = self.connect_timeout;
93        Box::pin(async move {
94            match HttpsConnector::connect(uri, tls, force_tls, connect_timeout).await {
95                Ok(conn) => Ok(NetworkConnection::new(conn)),
96                Err(e) => Err(Box::new(e) as _),
97            }
98        })
99    }
100}
101
102/// An HTTP or HTTPS connection
103pub enum HttpOrHttpsConnection {
104    Http(HttpConnection),
105    Https(TlsStream<TcpStream>),
106}
107
108impl Connection for HttpOrHttpsConnection {
109    fn connected(&self) -> Connected {
110        // TODO(#13): provide remote address
111        // TODO(#14): provide information about http protocol version (if
112        // negotiated through ALPN)
113        Connected::new()
114    }
115}
116
117impl AsyncRead for HttpOrHttpsConnection {
118    fn poll_read(
119        self: Pin<&mut Self>,
120        cx: &mut Context<'_>,
121        buf: &mut ReadBuf<'_>,
122    ) -> Poll<io::Result<()>> {
123        match Pin::get_mut(self) {
124            HttpOrHttpsConnection::Http(s) => Pin::new(s).poll_read(cx, buf),
125            HttpOrHttpsConnection::Https(s) => Pin::new(s).poll_read(cx, buf),
126        }
127    }
128}
129
130impl AsyncWrite for HttpOrHttpsConnection {
131    fn poll_write(
132        self: Pin<&mut Self>,
133        cx: &mut Context<'_>,
134        buf: &[u8],
135    ) -> Poll<io::Result<usize>> {
136        match Pin::get_mut(self) {
137            HttpOrHttpsConnection::Http(s) => Pin::new(s).poll_write(cx, buf),
138            HttpOrHttpsConnection::Https(s) => Pin::new(s).poll_write(cx, buf),
139        }
140    }
141
142    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
143        match Pin::get_mut(self) {
144            HttpOrHttpsConnection::Http(s) => Pin::new(s).poll_flush(cx),
145            HttpOrHttpsConnection::Https(s) => Pin::new(s).poll_flush(cx),
146        }
147    }
148
149    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
150        match Pin::get_mut(self) {
151            HttpOrHttpsConnection::Http(s) => Pin::new(s).poll_shutdown(cx),
152            HttpOrHttpsConnection::Https(s) => Pin::new(s).poll_shutdown(cx),
153        }
154    }
155}
156
#[cfg(test)]
mod tests {
    use super::*;
    use rustls_pki_types::{pem::PemObject, CertificateDer, PrivateKeyDer};
    use simple_hyper_client::{to_bytes, Client};
    use simple_hyper_client::{StatusCode, Uri};
    use std::{convert::TryFrom, net::SocketAddr, sync::Arc};
    use tokio::{
        io::{AsyncReadExt, AsyncWriteExt},
        net::TcpListener,
        sync::oneshot,
        task::JoinHandle,
    };
    use tokio_rustls::{
        rustls::{ClientConfig, RootCertStore, ServerConfig},
        TlsAcceptor, TlsConnector,
    };

    /// Canned response whose `Content-Length` matches the 13-byte body
    /// exactly. No stray bytes follow the body for the client to misparse.
    const RESPONSE_OK: &str = "HTTP/1.1 200 OK\r\nContent-Length: 13\r\n\r\nHello, world!";

    /// Client-side TLS connector that trusts only the test CA.
    fn get_tls_connector() -> TlsConnector {
        let test_ca_bytes = include_bytes!("../../test-ca/ca.cert");
        let test_ca_der = CertificateDer::from_pem_slice(test_ca_bytes).unwrap();
        let mut root_cert_store = RootCertStore::empty();
        root_cert_store.add(test_ca_der).unwrap();
        let config = ClientConfig::builder()
            .with_root_certificates(root_cert_store)
            .with_no_client_auth();
        TlsConnector::from(Arc::new(config))
    }

    /// Server-side TLS acceptor presenting the test end-entity chain and key.
    fn get_tls_acceptor() -> TlsAcceptor {
        let test_end_cert_bytes = include_bytes!("../../test-ca/end.fullchain");
        let test_end_cert = CertificateDer::pem_slice_iter(test_end_cert_bytes)
            .collect::<Result<Vec<_>, _>>()
            .unwrap();
        let test_end_key_bytes = include_bytes!("../../test-ca/end.key");
        let test_end_key = PrivateKeyDer::from_pem_slice(test_end_key_bytes).unwrap();
        let config = ServerConfig::builder()
            .with_no_client_auth()
            .with_single_cert(test_end_cert, test_end_key)
            .unwrap();
        TlsAcceptor::from(Arc::new(config))
    }

    /// Spawns a TLS server on an ephemeral localhost port. The server answers
    /// every accepted connection with `resp` and runs until `shutdown_rx`
    /// fires.
    async fn start_tls_server(
        resp: &'static str,
        mut shutdown_rx: oneshot::Receiver<()>,
    ) -> (JoinHandle<()>, SocketAddr) {
        let listener = TcpListener::bind("localhost:0")
            .await
            .expect("Failed to bind to localhost");
        let local_addr = listener.local_addr().unwrap();
        let acceptor = get_tls_acceptor();

        let server_handler = tokio::spawn(async move {
            println!("Started TLS server at {}", local_addr);
            loop {
                tokio::select! {
                    Ok((stream, peer_addr)) = listener.accept() => {
                        let acceptor = acceptor.clone();
                        tokio::spawn(async move {
                            match acceptor.accept(stream).await {
                                Ok(mut tls_stream) => {
                                    println!("TLS connection established with {}", peer_addr);
                                    // `read` fills at most `input.len()` bytes, so the buffer
                                    // needs a non-zero length. An empty `Vec` always reads 0.
                                    let mut input = vec![0u8; 1024];
                                    let n = tls_stream.read(&mut input).await.expect("failed to read");
                                    println!("Received: {}", String::from_utf8_lossy(&input[..n]));
                                    tls_stream.write_all(resp.as_bytes()).await.expect("failed to write");
                                    // Send a TLS close_notify instead of dropping the stream
                                    // abruptly. This is best effort: the peer may already be gone.
                                    let _ = tls_stream.shutdown().await;
                                }
                                Err(err) => eprintln!("TLS handshake failed: {}", err),
                            }
                        });
                    }

                    // Stop the server when a shutdown signal is received
                    _ = &mut shutdown_rx => {
                        println!("Shutting down TLS server at {} ...", local_addr);
                        break;
                    }
                }
            }
        });
        (server_handler, local_addr)
    }

    #[tokio::test]
    async fn test_connect_invalid_scheme() {
        let tls_connector = get_tls_connector();
        let uri = Uri::try_from("http://example.com").unwrap();

        let result = HttpsConnector::connect(uri, tls_connector, true, None).await;
        match result {
            Err(err) => assert!(
                err.to_string()
                    .contains("invalid URI: expected `https` scheme"),
                "{}",
                err
            ),
            Ok(_) => panic!("Expecting error for invalid URI scheme"),
        }
    }

    #[tokio::test]
    async fn test_connect() {
        let _ = env_logger::try_init();
        let tls_connector = get_tls_connector();

        let (tx, rx) = oneshot::channel::<()>();
        let (server_handler, addr) = start_tls_server(RESPONSE_OK, rx).await;

        // Note: localhost is important, since the cert we use have localhost in SAN
        let uri = Uri::try_from(format!("https://localhost:{}", addr.port())).unwrap();
        let connector = HttpsConnector::new(tls_connector);
        let client = Client::with_connector(connector);
        let response = client
            .post(uri)
            .unwrap()
            .body(r#"plain text"#)
            .send()
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::OK);
        let body = to_bytes(response).await.unwrap();
        assert_eq!(body, "Hello, world!".as_bytes());
        tx.send(()).unwrap();
        server_handler.await.unwrap();
    }
}