simple_hyper_client_native_tls/
lib.rs

1/* Copyright (c) Fortanix, Inc.
2 *
3 * This Source Code Form is subject to the terms of the Mozilla Public
4 * License, v. 2.0. If a copy of the MPL was not distributed with this
5 * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
6
7use simple_hyper_client::{ConnectError, HttpConnection, NetworkConnection, NetworkConnector};
8
9use simple_hyper_client::connector_impl;
10use simple_hyper_client::hyper::client::connect::{Connected, Connection};
11use simple_hyper_client::Uri;
12use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
13use tokio::net::TcpStream;
14use tokio_native_tls::{TlsConnector, TlsStream};
15
16use std::error::Error as StdError;
17use std::future::Future;
18use std::io;
19use std::pin::Pin;
20use std::task::{Context, Poll};
21use std::time::Duration;
22
23/// An HTTPS connector using native-tls.
24///
25/// TLS use is enforced by default. To allow plain `http` URIs call
26/// [`fn allow_http_scheme()`].
27pub struct HttpsConnector {
28    force_tls: bool,
29    tls: TlsConnector,
30    connect_timeout: Option<Duration>,
31}
32
33impl HttpsConnector {
34    pub fn new(tls: TlsConnector) -> Self {
35        HttpsConnector {
36            tls,
37            force_tls: true,
38            connect_timeout: None,
39        }
40    }
41
42    /// Set the connect timeout. Default is None.
43    pub fn connect_timeout(mut self, timeout: Option<Duration>) -> Self {
44        self.connect_timeout = timeout;
45        self
46    }
47
48    /// If called, the connector will allow URIs with the `http` scheme.
49    /// Otherwise only URIs with the `https` scheme are allowed.
50    pub fn allow_http_scheme(mut self) -> Self {
51        self.force_tls = false;
52        self
53    }
54
55    async fn connect(
56        uri: Uri,
57        tls: TlsConnector,
58        force_tls: bool,
59        connect_timeout: Option<Duration>,
60    ) -> Result<HttpOrHttpsConnection, ConnectError> {
61        let is_https = uri.scheme_str() == Some("https");
62        if !is_https && force_tls {
63            return Err(ConnectError::new("invalid URI: expected `https` scheme"));
64        }
65        let host = connector_impl::get_host(&uri)?.to_owned();
66        let http = connector_impl::connect(uri, true, connect_timeout).await?;
67        if is_https {
68            let tls = tls
69                .connect(&host, http.into_tcp_stream())
70                .await
71                .map_err(|e| ConnectError::new("TLS error").cause(e))?;
72
73            Ok(HttpOrHttpsConnection::Https(tls))
74        } else {
75            Ok(HttpOrHttpsConnection::Http(http))
76        }
77    }
78}
79
80impl NetworkConnector for HttpsConnector {
81    fn connect(
82        &self,
83        uri: Uri,
84    ) -> Pin<
85        Box<dyn Future<Output = Result<NetworkConnection, Box<dyn StdError + Send + Sync>>> + Send>,
86    > {
87        let tls = self.tls.clone();
88        let force_tls = self.force_tls;
89        let connect_timeout = self.connect_timeout;
90        Box::pin(async move {
91            match HttpsConnector::connect(uri, tls, force_tls, connect_timeout).await {
92                Ok(conn) => Ok(NetworkConnection::new(conn)),
93                Err(e) => Err(Box::new(e) as _),
94            }
95        })
96    }
97}
98
99/// An HTTP or HTTPS connection
100pub enum HttpOrHttpsConnection {
101    Http(HttpConnection),
102    Https(TlsStream<TcpStream>),
103}
104
105impl Connection for HttpOrHttpsConnection {
106    fn connected(&self) -> Connected {
107        // TODO(#13): provide remote address
108        // TODO(#14): provide information about http protocol version (if
109        // negotiated through ALPN)
110        Connected::new()
111    }
112}
113
114impl AsyncRead for HttpOrHttpsConnection {
115    fn poll_read(
116        self: Pin<&mut Self>,
117        cx: &mut Context<'_>,
118        buf: &mut ReadBuf<'_>,
119    ) -> Poll<io::Result<()>> {
120        match Pin::get_mut(self) {
121            HttpOrHttpsConnection::Http(s) => Pin::new(s).poll_read(cx, buf),
122            HttpOrHttpsConnection::Https(s) => Pin::new(s).poll_read(cx, buf),
123        }
124    }
125}
126
127impl AsyncWrite for HttpOrHttpsConnection {
128    fn poll_write(
129        self: Pin<&mut Self>,
130        cx: &mut Context<'_>,
131        buf: &[u8],
132    ) -> Poll<io::Result<usize>> {
133        match Pin::get_mut(self) {
134            HttpOrHttpsConnection::Http(s) => Pin::new(s).poll_write(cx, buf),
135            HttpOrHttpsConnection::Https(s) => Pin::new(s).poll_write(cx, buf),
136        }
137    }
138
139    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
140        match Pin::get_mut(self) {
141            HttpOrHttpsConnection::Http(s) => Pin::new(s).poll_flush(cx),
142            HttpOrHttpsConnection::Https(s) => Pin::new(s).poll_flush(cx),
143        }
144    }
145
146    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
147        match Pin::get_mut(self) {
148            HttpOrHttpsConnection::Http(s) => Pin::new(s).poll_shutdown(cx),
149            HttpOrHttpsConnection::Https(s) => Pin::new(s).poll_shutdown(cx),
150        }
151    }
152}
153
#[cfg(test)]
mod tests {
    use std::{convert::TryFrom, net::SocketAddr};

    use simple_hyper_client::StatusCode;
    use simple_hyper_client::{to_bytes, Client};
    use tokio::{
        io::{AsyncReadExt, AsyncWriteExt},
        net::TcpListener,
        sync::oneshot,
        task::JoinHandle,
    };
    use tokio_native_tls::{
        native_tls::{self, Certificate, Identity},
        TlsAcceptor,
    };

    use super::*;

    /// Canned response sent by the test server. `Content-Length` matches the
    /// 13-byte body exactly, so no stray bytes are left on the connection.
    const RESPONSE_OK: &str = "HTTP/1.1 200 OK\r\nContent-Length: 13\r\n\r\nHello, world!";

    /// Size of the buffer the test server reads the incoming request into.
    const REQUEST_BUF_LEN: usize = 1024;

    /// Client-side TLS config trusting the test CA certificate.
    fn get_tls_connector() -> TlsConnector {
        let pem = include_bytes!("../../test-ca/ca.cert");
        let cert = Certificate::from_pem(pem).unwrap();

        tokio_native_tls::native_tls::TlsConnector::builder()
            .add_root_certificate(cert)
            .build()
            .unwrap()
            .into()
    }

    /// Server-side TLS config using the test end-entity chain and key.
    fn get_tls_acceptor() -> TlsAcceptor {
        let pem = include_bytes!("../../test-ca/end.fullchain");

        let key = include_bytes!("../../test-ca/end.key");

        let identity = Identity::from_pkcs8(pem, key).unwrap();
        native_tls::TlsAcceptor::builder(identity)
            .build()
            .unwrap()
            .into()
    }

    /// Spawns a TLS server on an ephemeral `localhost` port that answers each
    /// connection with `resp`, until a value arrives on `shutdown_rx`.
    ///
    /// Returns the server task handle and the bound address.
    async fn start_tls_server(
        resp: &'static str,
        mut shutdown_rx: oneshot::Receiver<()>,
    ) -> (JoinHandle<()>, SocketAddr) {
        let listener = TcpListener::bind("localhost:0")
            .await
            .expect("Failed to bind to localhost");
        let local_addr = listener.local_addr().unwrap();
        let acceptor = get_tls_acceptor();

        let server_handler = tokio::spawn(async move {
            println!("Started TLS server at {}", local_addr);
            loop {
                tokio::select! {
                    Ok((stream, peer_addr)) = listener.accept() => {
                        let acceptor = acceptor.clone();
                        tokio::spawn(async move {
                            match acceptor.accept(stream).await {
                                Ok(mut tls_stream) => {
                                    println!("TLS connection established with {}", peer_addr);
                                    // `read` fills at most `input.len()` bytes, so the buffer
                                    // needs a non-zero length, not just capacity. Consuming
                                    // the request before replying also avoids closing the
                                    // socket with unread data, which can reset the connection.
                                    let mut input = vec![0u8; REQUEST_BUF_LEN];
                                    let n = tls_stream.read(&mut input).await.expect("failed to read");
                                    println!("Received: {}", String::from_utf8_lossy(&input[..n]));
                                    tls_stream.write_all(resp.as_bytes()).await.expect("failed to write");
                                }
                                Err(err) => eprintln!("TLS handshake failed: {}", err),
                            }
                        });
                    }

                    // Stop the server when a shutdown signal is received
                    _ = &mut shutdown_rx => {
                        println!("Shutting down TLS server at {} ...", local_addr);
                        break;
                    }
                }
            }
        });
        (server_handler, local_addr)
    }

    #[tokio::test]
    async fn test_connect_invalid_scheme() {
        let tls_connector = get_tls_connector();
        let uri = Uri::try_from("http://example.com").unwrap();

        let result = HttpsConnector::connect(uri, tls_connector, true, None).await;
        match result {
            Err(err) => assert!(
                err.to_string()
                    .contains("invalid URI: expected `https` scheme"),
                "{}",
                err
            ),
            Ok(_) => panic!("Expecting error for invalid URI scheme"),
        }
    }

    #[tokio::test]
    async fn test_connect() {
        let tls_connector = get_tls_connector();
        let (tx, rx) = oneshot::channel::<()>();
        let (server_handler, addr) = start_tls_server(RESPONSE_OK, rx).await;

        // Note: localhost is important, since the cert we use have localhost in SAN
        let uri = Uri::try_from(format!("https://localhost:{}", addr.port())).unwrap();
        let connector = HttpsConnector::new(tls_connector);
        let client = Client::with_connector(connector);
        let response = client
            .post(uri)
            .unwrap()
            .body(r#"plain text"#)
            .send()
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::OK);
        let body = to_bytes(response).await.unwrap();
        assert_eq!(body, "Hello, world!".as_bytes());
        tx.send(()).unwrap();
        server_handler.await.unwrap();
    }
}