tokio_libtls/
lib.rs

1// Copyright (c) 2019, 2020 Reyk Floeter <contact@reykfloeter.com>
2//
3// Permission to use, copy, modify, and distribute this software for any
4// purpose with or without fee is hereby granted, provided that the above
5// copyright notice and this permission notice appear in all copies.
6//
7// THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
8// WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
9// MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
10// ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
11// WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
12// ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
13// OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
14
15//! Async [`Tls`] bindings for [`libtls`].
16//!
17//! See also [`libtls`] for more information.
18//!
19//! > Note, the API for this crate is neither finished nor documented yet.
20//!
21//! # Example
22//!
23//! ```rust
24//! use std::io;
25//! use tokio::io::{AsyncReadExt, AsyncWriteExt};
26//! use tokio_libtls::prelude::{connect, Builder};
27//!
28//! async fn async_https_connect(servername: String) -> io::Result<()> {
29//!     let request = format!(
30//!         "GET / HTTP/1.1\r\n\
31//!          Host: {}\r\n\
32//!          Connection: close\r\n\r\n",
33//!         servername
34//!     );
35//!
36//!     let config = Builder::new().build()?;
37//!     let mut tls = connect(&(servername + ":443"), &config, None).await?;
38//!     tls.write_all(request.as_bytes()).await?;
39//!
40//!     let mut buf = vec![0u8; 1024];
41//!     tls.read_exact(&mut buf).await?;
42//!
43//!     let ok = b"HTTP/1.1 200 OK\r\n";
44//!     assert_eq!(&buf[..ok.len()], ok);
45//!
46//!     Ok(())
47//! }
48//! # #[tokio::main]
49//! # async fn main() {
50//! #    async_https_connect("www.example.com".to_owned()).await.unwrap();
51//! # }
52//! ```
53//!
54//! [`Tls`]: https://reyk.github.io/rust-libtls/libtls/tls/struct.Tls.html
55//! [`libtls`]: https://reyk.github.io/rust-libtls/libtls
56
57#![doc(
58    html_logo_url = "https://www.libressl.org/images/libressl.jpg",
59    html_favicon_url = "https://www.libressl.org/favicon.ico"
60)]
61#![warn(missing_docs)]
62
63/// Error handling.
64pub mod error;
65
66/// A "prelude" for crates using the `tokio-libtls` crate.
67pub mod prelude;
68
69use error::Error;
70use libtls::{config::Config, error::Error as TlsError, tls::Tls};
71use mio::{event::Evented, unix::EventedFd, PollOpt, Ready, Token};
72use prelude::*;
73use std::{
74    io::{self, Read, Write},
75    net::ToSocketAddrs,
76    ops::{Deref, DerefMut},
77    os::unix::io::{AsRawFd, RawFd},
78    pin::Pin,
79    task::{Context, Poll},
80    time::Duration,
81};
82use tokio::{
83    io::{AsyncRead, AsyncWrite, PollEvented},
84    net::{TcpListener, TcpStream},
85    time::timeout,
86};
87
/// Convert the result of a non-blocking libtls call into a `Poll`.
///
/// `Ok(v)` becomes `Poll::Ready(Ok(v))`. An error is first converted into an
/// `io::Error`: `WouldBlock` maps to `Poll::Pending`, any other kind to
/// `Poll::Ready(Err(..))`.
macro_rules! try_async_tls {
    ($call: expr) => {
        match $call {
            Ok(value) => Poll::Ready(Ok(value)),
            Err(cause) => {
                let cause: io::Error = cause.into();
                match cause.kind() {
                    io::ErrorKind::WouldBlock => Poll::Pending,
                    _ => Poll::Ready(Err(cause)),
                }
            }
        }
    };
}
103
/// Wrapper for async I/O operations with `Tls`.
///
/// Pairs a libtls context with the TCP stream whose file descriptor the
/// context was set up on (see `accept_stream` / `connect_stream`). Owning the
/// `TcpStream` here keeps the socket open alongside the TLS context, and its
/// descriptor is what `AsRawFd`/`Evented` expose to the reactor.
#[derive(Debug)]
pub struct TlsStream {
    /// The libtls context that encrypts/decrypts application data.
    tls: Tls,
    /// The underlying TCP connection; in this file only its raw fd is used.
    tcp: TcpStream,
}
110
111impl TlsStream {
112    /// Create new `TlsStream` from `Tls` object and `TcpStream`.
113    pub fn new(tls: Tls, tcp: TcpStream) -> Self {
114        Self { tls, tcp }
115    }
116}
117
118impl Deref for TlsStream {
119    type Target = Tls;
120
121    fn deref(&self) -> &Self::Target {
122        &self.tls
123    }
124}
125
126impl DerefMut for TlsStream {
127    fn deref_mut(&mut self) -> &mut Self::Target {
128        &mut self.tls
129    }
130}
131
132impl AsRawFd for TlsStream {
133    fn as_raw_fd(&self) -> RawFd {
134        self.tcp.as_raw_fd()
135    }
136}
137
138impl io::Read for TlsStream {
139    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
140        self.tls.read(buf)
141    }
142}
143
144impl io::Write for TlsStream {
145    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
146        self.tls.write(buf)
147    }
148
149    fn flush(&mut self) -> io::Result<()> {
150        self.tls.flush()
151    }
152}
153
154impl AsyncRead for TlsStream {
155    fn poll_read(
156        mut self: Pin<&mut Self>,
157        _cx: &mut Context<'_>,
158        buf: &mut [u8],
159    ) -> Poll<Result<usize, io::Error>> {
160        try_async_tls!(self.tls.read(buf))
161    }
162}
163
164impl AsyncWrite for TlsStream {
165    fn poll_write(
166        mut self: Pin<&mut Self>,
167        _cx: &mut Context<'_>,
168        buf: &[u8],
169    ) -> Poll<Result<usize, io::Error>> {
170        try_async_tls!(self.tls.write(buf))
171    }
172
173    fn poll_flush(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
174        try_async_tls!(self.tls.close()).map(|_| Ok(()))
175    }
176
177    fn poll_shutdown(
178        mut self: Pin<&mut Self>,
179        _cx: &mut Context<'_>,
180    ) -> Poll<Result<(), io::Error>> {
181        try_async_tls!(self.tls.close()).map(|_| Ok(()))
182    }
183}
184
185impl Evented for TlsStream {
186    fn register(
187        &self,
188        poll: &mio::Poll,
189        token: Token,
190        interest: Ready,
191        opts: PollOpt,
192    ) -> io::Result<()> {
193        match EventedFd(&self.as_raw_fd()).register(poll, token, interest, opts) {
194            Err(ref err) if err.kind() == io::ErrorKind::AlreadyExists => {
195                self.reregister(poll, token, interest, opts)
196            }
197            Err(err) => Err(err),
198            Ok(_) => Ok(()),
199        }
200    }
201
202    fn reregister(
203        &self,
204        poll: &mio::Poll,
205        token: Token,
206        interest: Ready,
207        opts: PollOpt,
208    ) -> io::Result<()> {
209        EventedFd(&self.as_raw_fd()).reregister(poll, token, interest, opts)
210    }
211
212    fn deregister(&self, poll: &mio::Poll) -> io::Result<()> {
213        EventedFd(&self.as_raw_fd()).deregister(poll)
214    }
215}
216
// SAFETY: NOTE(review): these impls assert that the wrapped libtls context may
// be moved to (`Send`) and shared between (`Sync`) threads. Nothing in this file
// establishes that `Tls` is thread-safe. `Send` is presumably fine as long as a
// stream is used by one task at a time, but `Sync` also permits concurrent
// `&self` access through `Deref<Target = Tls>` — confirm that no `&self` method
// of `Tls` touches the underlying libtls state.
unsafe impl Send for TlsStream {}
unsafe impl Sync for TlsStream {}
219
/// Pollable wrapper for async I/O operations with `Tls`.
///
/// The `TlsStream` is registered with the tokio reactor through its `Evented`
/// impl. NOTE(review): async reads/writes presumably go through `PollEvented`'s
/// own `AsyncRead`/`AsyncWrite` impls, which call the blocking-style
/// `io::Read`/`io::Write` impls and treat `WouldBlock` as "not ready" — confirm
/// against the tokio version in use.
pub type AsyncTlsStream = PollEvented<TlsStream>;
222
/// Async `Tls` struct.
///
/// A future that drives the TLS handshake of an `AsyncTlsStream` and resolves
/// to that stream once the handshake has completed.
pub struct AsyncTls {
    // Handshake state: `Ok(stream)` once done; `Err(Error::Readable | Writeable
    // | Handshake)` while the handshake is in progress; `Err(Error::Error(..))`
    // on failure. `poll` takes the value out, so it is `None` after the future
    // has resolved with the stream or an error.
    inner: Option<Result<AsyncTlsStream, Error>>,
}
227
228impl AsyncTls {
229    /// Accept a new async `Tls` connection.
230    #[deprecated(since = "1.1.1", note = "Please use module function `accept_stream`")]
231    pub async fn accept_stream(
232        tcp: TcpStream,
233        config: &Config,
234        options: Option<Options>,
235    ) -> io::Result<AsyncTlsStream> {
236        accept_stream(tcp, config, options).await
237    }
238
239    /// Upgrade a TCP stream to a new async `Tls` connection.
240    #[deprecated(since = "1.1.1", note = "Please use module function `connect_stream`")]
241    pub async fn connect_stream(
242        tcp: TcpStream,
243        config: &Config,
244        options: Option<Options>,
245    ) -> io::Result<AsyncTlsStream> {
246        connect_stream(tcp, config, options).await
247    }
248
249    /// Connect a new async `Tls` connection.
250    #[deprecated(since = "1.1.1", note = "Please use module function `connect`")]
251    pub async fn connect(
252        host: &str,
253        config: &Config,
254        options: Option<Options>,
255    ) -> io::Result<AsyncTlsStream> {
256        connect(host, config, options).await
257    }
258}
259
260impl Future for AsyncTls {
261    type Output = Result<AsyncTlsStream, io::Error>;
262
263    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
264        let inner = self
265            .inner
266            .take()
267            .ok_or_else(|| io::Error::new(io::ErrorKind::Other, "cannot take inner"))?;
268        match inner {
269            Ok(tls) => {
270                cx.waker().wake_by_ref();
271                Poll::Ready(Ok(tls))
272            }
273            Err(Error::Readable(stream)) => {
274                self.inner = match stream.poll_read_ready(cx, Ready::readable()) {
275                    Poll::Ready(_) => Some(Err(Error::Handshake(stream))),
276                    _ => Some(Err(Error::Handshake(stream))),
277                };
278                cx.waker().wake_by_ref();
279                Poll::Pending
280            }
281            Err(Error::Writeable(stream)) => {
282                self.inner = match stream.poll_write_ready(cx) {
283                    Poll::Ready(_) => Some(Err(Error::Handshake(stream))),
284                    _ => Some(Err(Error::Writeable(stream))),
285                };
286                cx.waker().wake_by_ref();
287                Poll::Pending
288            }
289            Err(Error::Handshake(mut stream)) => {
290                let tls = &mut *stream.get_mut();
291                let res = match tls.tls_handshake() {
292                    Ok(res) => {
293                        if res == libtls::TLS_WANT_POLLIN as isize {
294                            Err(Error::Readable(stream))
295                        } else if res == libtls::TLS_WANT_POLLOUT as isize {
296                            Err(Error::Writeable(stream))
297                        } else {
298                            Ok(stream)
299                        }
300                    }
301                    Err(err) => Err(err.into()),
302                };
303                self.inner = Some(res);
304                cx.waker().wake_by_ref();
305                Poll::Pending
306            }
307            Err(Error::Error(TlsError::IoError(err))) => Poll::Ready(Err(err)),
308            Err(Error::Error(err)) => {
309                Poll::Ready(Err(io::Error::new(io::ErrorKind::Other, err.to_string())))
310            }
311        }
312    }
313}
314
// SAFETY: NOTE(review): `AsyncTls` only holds an `AsyncTlsStream` (whose
// `TlsStream` is separately asserted `Send + Sync` in this file) or an `Error`.
// These manual impls additionally vouch for whatever `Error` contains, including
// the wrapped `TlsError`; confirm those are safe to send/share, or drop these
// impls if the auto traits already apply.
unsafe impl Send for AsyncTls {}
unsafe impl Sync for AsyncTls {}
317
318/// Accept a new async `Tls` connection.
319pub async fn accept(
320    listener: &mut TcpListener,
321    config: &Config,
322    options: Option<Options>,
323) -> io::Result<AsyncTlsStream> {
324    let options = options.unwrap_or_else(Options::new);
325
326    let (tcp, _) = listener.accept().await?;
327    let mut server = Tls::server()?;
328    server.configure(config)?;
329    let client = server.accept_raw_fd(&tcp)?;
330
331    let async_tls = TlsStream::new(client, tcp);
332    let stream = PollEvented::new(async_tls)?;
333    let fut = AsyncTls {
334        inner: Some(Err(Error::Readable(stream))),
335    };
336
337    // Accept with an optional timeout for the TLS handshake.
338    let tls = match options.timeout {
339        Some(tm) => match timeout(tm, fut).await {
340            Ok(res) => res,
341            Err(err) => Err(err.into()),
342        },
343        None => fut.await,
344    }?;
345
346    Ok(tls)
347}
348
349/// Accept a new async `Tls` connection on an established client connection.
350pub async fn accept_stream(
351    tcp: TcpStream,
352    config: &Config,
353    options: Option<Options>,
354) -> io::Result<AsyncTlsStream> {
355    let options = options.unwrap_or_else(Options::new);
356
357    let mut server = Tls::server()?;
358    server.configure(config)?;
359    let client = server.accept_raw_fd(&tcp)?;
360
361    let async_tls = TlsStream::new(client, tcp);
362    let stream = PollEvented::new(async_tls)?;
363    let fut = AsyncTls {
364        inner: Some(Err(Error::Readable(stream))),
365    };
366
367    // Accept with an optional timeout for the TLS handshake.
368    let tls = match options.timeout {
369        Some(tm) => match timeout(tm, fut).await {
370            Ok(res) => res,
371            Err(err) => Err(err.into()),
372        },
373        None => fut.await,
374    }?;
375
376    Ok(tls)
377}
378
379/// Upgrade a TCP stream to a new async `Tls` connection.
380pub async fn connect_stream(
381    tcp: TcpStream,
382    config: &Config,
383    options: Option<Options>,
384) -> io::Result<AsyncTlsStream> {
385    let options = options.unwrap_or_else(Options::new);
386    let servername = match options.servername {
387        Some(name) => name,
388        None => tcp.peer_addr()?.to_string(),
389    };
390
391    let mut tls = Tls::client()?;
392
393    tls.configure(config)?;
394    tls.connect_raw_fd(&tcp, &servername)?;
395
396    let async_tls = TlsStream::new(tls, tcp);
397    let stream = PollEvented::new(async_tls)?;
398    let fut = AsyncTls {
399        inner: Some(Err(Error::Readable(stream))),
400    };
401
402    // Connect with an optional timeout for the TLS handshake.
403    let tls = match options.timeout {
404        Some(tm) => match timeout(tm, fut).await {
405            Ok(res) => res,
406            Err(err) => Err(err.into()),
407        },
408        None => fut.await,
409    }?;
410
411    Ok(tls)
412}
413
414/// Connect a new async `Tls` connection.
415pub async fn connect(
416    host: &str,
417    config: &Config,
418    options: Option<Options>,
419) -> io::Result<AsyncTlsStream> {
420    let mut options = options.unwrap_or_else(Options::new);
421
422    // Remove _last_ colon (to satisfy the IPv6 form, e.g. [::1]::443).
423    if options.servername.is_none() {
424        match host.rfind(':') {
425            None => return Err(io::ErrorKind::InvalidInput.into()),
426            Some(index) => options.servername(&host[0..index]),
427        };
428    };
429
430    let mut last_error = io::ErrorKind::ConnectionRefused.into();
431
432    for addr in host.to_socket_addrs()? {
433        // Connect with an optional timeout.
434        let res = match options.timeout {
435            Some(tm) => match timeout(tm, TcpStream::connect(&addr)).await {
436                Ok(res) => res,
437                Err(err) => Err(err.into()),
438            },
439            None => TcpStream::connect(&addr).await,
440        };
441
442        // Return the first TCP successful connection, store the last error.
443        match res {
444            Ok(tcp) => {
445                return connect_stream(tcp, config, Some(options)).await;
446            }
447            Err(err) => last_error = err,
448        }
449    }
450
451    Err(last_error)
452}
453
/// Configuration options for `AsyncTls`.
///
/// # See also
///
/// [`AsyncTls`]
///
/// [`AsyncTls`]: ./struct.AsyncTls.html
#[derive(Clone, Default, Debug, PartialEq)]
pub struct Options {
    /// Optional timeout for the TCP connection and the TLS handshake.
    timeout: Option<Duration>,
    /// Optional TLS servername; derived from the host or peer address if unset.
    servername: Option<String>,
}

/// Configuration options for `AsyncTls`.
#[deprecated(
    since = "1.1.1",
    note = "Please use `Options` instead of `AsyncTlsOptions`"
)]
pub type AsyncTlsOptions = Options;

impl Options {
    /// Return a new, empty `Options` value (no timeout, no servername).
    pub fn new() -> Self {
        Self {
            timeout: None,
            servername: None,
        }
    }

    /// Set the optional TCP connection and TLS handshake timeout.
    pub fn timeout(&mut self, timeout: Duration) -> &mut Self {
        self.timeout = Some(timeout);
        self
    }

    /// Set the optional TLS servername.
    ///
    /// If not specified, the address is derived from the host or address.
    pub fn servername(&mut self, servername: &str) -> &mut Self {
        self.servername = Some(String::from(servername));
        self
    }

    /// Return a copy as `Some(Options)`, or `None` if no option has been set.
    pub fn build(&mut self) -> Option<Self> {
        let is_empty = *self == Self::new();
        if is_empty {
            None
        } else {
            Some(self.clone())
        }
    }
}
503
#[cfg(test)]
mod test {
    use crate::prelude::*;
    use std::{io, time::Duration};
    use tokio::io::{AsyncReadExt, AsyncWriteExt};

    /// Fetch `/` from `servername` over HTTPS and check the `200 OK` status line.
    ///
    /// Requires network access to `servername` on port 443.
    async fn async_https_connect(servername: String) -> io::Result<()> {
        let request = format!(
            "GET / HTTP/1.1\r\n\
             Host: {}\r\n\
             Connection: close\r\n\r\n",
            servername
        );

        let config = Builder::new().build()?;
        let mut options = Options::new();
        options
            .servername(&servername)
            .timeout(Duration::from_secs(60));
        let address = format!("{}:443", servername);
        let mut tls = connect(&address, &config, options.build()).await?;
        tls.write_all(request.as_bytes()).await?;

        let mut response = vec![0u8; 1024];
        tls.read_exact(&mut response).await?;

        let status = b"HTTP/1.1 200 OK\r\n";
        assert_eq!(&response[..status.len()], status);

        Ok(())
    }

    #[tokio::test]
    async fn test_async_https_connect() {
        async_https_connect(String::from("www.example.com"))
            .await
            .unwrap();
    }
}