simple_hyper_client_native_tls/
lib.rs1use simple_hyper_client::{ConnectError, HttpConnection, NetworkConnection, NetworkConnector};
8
9use simple_hyper_client::connector_impl;
10use simple_hyper_client::hyper::client::connect::{Connected, Connection};
11use simple_hyper_client::Uri;
12use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
13use tokio::net::TcpStream;
14use tokio_native_tls::{TlsConnector, TlsStream};
15
16use std::error::Error as StdError;
17use std::future::Future;
18use std::io;
19use std::pin::Pin;
20use std::task::{Context, Poll};
21use std::time::Duration;
22
23pub struct HttpsConnector {
28 force_tls: bool,
29 tls: TlsConnector,
30 connect_timeout: Option<Duration>,
31}
32
33impl HttpsConnector {
34 pub fn new(tls: TlsConnector) -> Self {
35 HttpsConnector {
36 tls,
37 force_tls: true,
38 connect_timeout: None,
39 }
40 }
41
42 pub fn connect_timeout(mut self, timeout: Option<Duration>) -> Self {
44 self.connect_timeout = timeout;
45 self
46 }
47
48 pub fn allow_http_scheme(mut self) -> Self {
51 self.force_tls = false;
52 self
53 }
54
55 async fn connect(
56 uri: Uri,
57 tls: TlsConnector,
58 force_tls: bool,
59 connect_timeout: Option<Duration>,
60 ) -> Result<HttpOrHttpsConnection, ConnectError> {
61 let is_https = uri.scheme_str() == Some("https");
62 if !is_https && force_tls {
63 return Err(ConnectError::new("invalid URI: expected `https` scheme"));
64 }
65 let host = connector_impl::get_host(&uri)?.to_owned();
66 let http = connector_impl::connect(uri, true, connect_timeout).await?;
67 if is_https {
68 let tls = tls
69 .connect(&host, http.into_tcp_stream())
70 .await
71 .map_err(|e| ConnectError::new("TLS error").cause(e))?;
72
73 Ok(HttpOrHttpsConnection::Https(tls))
74 } else {
75 Ok(HttpOrHttpsConnection::Http(http))
76 }
77 }
78}
79
80impl NetworkConnector for HttpsConnector {
81 fn connect(
82 &self,
83 uri: Uri,
84 ) -> Pin<
85 Box<dyn Future<Output = Result<NetworkConnection, Box<dyn StdError + Send + Sync>>> + Send>,
86 > {
87 let tls = self.tls.clone();
88 let force_tls = self.force_tls;
89 let connect_timeout = self.connect_timeout;
90 Box::pin(async move {
91 match HttpsConnector::connect(uri, tls, force_tls, connect_timeout).await {
92 Ok(conn) => Ok(NetworkConnection::new(conn)),
93 Err(e) => Err(Box::new(e) as _),
94 }
95 })
96 }
97}
98
99pub enum HttpOrHttpsConnection {
101 Http(HttpConnection),
102 Https(TlsStream<TcpStream>),
103}
104
105impl Connection for HttpOrHttpsConnection {
106 fn connected(&self) -> Connected {
107 Connected::new()
111 }
112}
113
114impl AsyncRead for HttpOrHttpsConnection {
115 fn poll_read(
116 self: Pin<&mut Self>,
117 cx: &mut Context<'_>,
118 buf: &mut ReadBuf<'_>,
119 ) -> Poll<io::Result<()>> {
120 match Pin::get_mut(self) {
121 HttpOrHttpsConnection::Http(s) => Pin::new(s).poll_read(cx, buf),
122 HttpOrHttpsConnection::Https(s) => Pin::new(s).poll_read(cx, buf),
123 }
124 }
125}
126
127impl AsyncWrite for HttpOrHttpsConnection {
128 fn poll_write(
129 self: Pin<&mut Self>,
130 cx: &mut Context<'_>,
131 buf: &[u8],
132 ) -> Poll<io::Result<usize>> {
133 match Pin::get_mut(self) {
134 HttpOrHttpsConnection::Http(s) => Pin::new(s).poll_write(cx, buf),
135 HttpOrHttpsConnection::Https(s) => Pin::new(s).poll_write(cx, buf),
136 }
137 }
138
139 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
140 match Pin::get_mut(self) {
141 HttpOrHttpsConnection::Http(s) => Pin::new(s).poll_flush(cx),
142 HttpOrHttpsConnection::Https(s) => Pin::new(s).poll_flush(cx),
143 }
144 }
145
146 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
147 match Pin::get_mut(self) {
148 HttpOrHttpsConnection::Http(s) => Pin::new(s).poll_shutdown(cx),
149 HttpOrHttpsConnection::Https(s) => Pin::new(s).poll_shutdown(cx),
150 }
151 }
152}
153
#[cfg(test)]
mod tests {
    use std::{convert::TryFrom, net::SocketAddr};

    use simple_hyper_client::StatusCode;
    use simple_hyper_client::{to_bytes, Client};
    use tokio::{
        io::{AsyncReadExt, AsyncWriteExt},
        net::TcpListener,
        sync::oneshot,
        task::JoinHandle,
    };
    use tokio_native_tls::{
        native_tls::{self, Certificate, Identity},
        TlsAcceptor,
    };

    use super::*;

    /// Canned HTTP/1.1 response. The body is exactly the 13 bytes announced by
    /// `Content-Length`; anything after it would be stray data on the connection.
    const RESPONSE_OK: &str = "HTTP/1.1 200 OK\r\nContent-Length: 13\r\n\r\nHello, world!";

    /// Builds a client-side TLS connector that trusts the test CA certificate.
    fn get_tls_connector() -> TlsConnector {
        let pem = include_bytes!("../../test-ca/ca.cert");
        let cert = Certificate::from_pem(pem).unwrap();

        tokio_native_tls::native_tls::TlsConnector::builder()
            .add_root_certificate(cert)
            .build()
            .unwrap()
            .into()
    }

    /// Builds a server-side TLS acceptor from the test certificate chain and PKCS#8 key.
    fn get_tls_acceptor() -> TlsAcceptor {
        let pem = include_bytes!("../../test-ca/end.fullchain");

        let key = include_bytes!("../../test-ca/end.key");

        let identity = Identity::from_pkcs8(pem, key).unwrap();
        native_tls::TlsAcceptor::builder(identity)
            .build()
            .unwrap()
            .into()
    }

    /// Spawns a TLS server on an ephemeral `localhost` port.
    ///
    /// For every accepted connection, the server reads once from the client and
    /// then writes `resp`. The accept loop stops when `shutdown_rx` resolves.
    async fn start_tls_server(
        resp: &'static str,
        mut shutdown_rx: oneshot::Receiver<()>,
    ) -> (JoinHandle<()>, SocketAddr) {
        let listener = TcpListener::bind("localhost:0")
            .await
            .expect("Failed to bind to localhost");
        let local_addr = listener.local_addr().unwrap();
        let acceptor = get_tls_acceptor();

        let server_handler = tokio::spawn(async move {
            println!("Started TLS server at {}", local_addr);
            loop {
                tokio::select! {
                    Ok((stream, peer_addr)) = listener.accept() => {
                        let acceptor = acceptor.clone();
                        tokio::spawn(async move {
                            match acceptor.accept(stream).await {
                                Ok(mut tls_stream) => {
                                    println!("TLS connection established with {}", peer_addr);
                                    // `read` fills at most `input.len()` bytes, so the buffer
                                    // must already have a non-zero length; an empty `Vec`
                                    // (whatever its capacity) would always read 0 bytes.
                                    let mut input = vec![0u8; 1024];
                                    let n = tls_stream.read(&mut input).await.expect("failed to read");
                                    println!("Received: {}", String::from_utf8_lossy(&input[..n]));
                                    tls_stream.write_all(resp.as_bytes()).await.expect("failed to write");
                                }
                                Err(err) => eprintln!("TLS handshake failed: {}", err),
                            }
                        });
                    }

                    _ = &mut shutdown_rx => {
                        println!("Shutting down TLS server at {} ...", local_addr);
                        break;
                    }
                }
            }
        });
        (server_handler, local_addr)
    }

    #[tokio::test]
    async fn test_connect_invalid_scheme() {
        let tls_connector = get_tls_connector();
        let uri = Uri::try_from("http://example.com").unwrap();

        let result = HttpsConnector::connect(uri, tls_connector, true, None).await;
        match result {
            Err(err) => assert!(
                err.to_string()
                    .contains("invalid URI: expected `https` scheme"),
                "{}",
                err
            ),
            Ok(_) => panic!("Expecting error for invalid URI scheme"),
        }
    }

    #[tokio::test]
    async fn test_connect() {
        let tls_connector = get_tls_connector();
        let (tx, rx) = oneshot::channel::<()>();
        let (server_handler, addr) = start_tls_server(RESPONSE_OK, rx).await;

        let uri = Uri::try_from(format!("https://localhost:{}", addr.port())).unwrap();
        let connector = HttpsConnector::new(tls_connector);
        let client = Client::with_connector(connector);
        let response = client
            .post(uri)
            .unwrap()
            .body(r#"plain text"#)
            .send()
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::OK);
        let body = to_bytes(response).await.unwrap();
        assert_eq!(body, "Hello, world!".as_bytes());
        tx.send(()).unwrap();
        server_handler.await.unwrap();
    }
}