tokio_boring/
lib.rs

1//! Async TLS streams backed by BoringSSL
2//!
//! This library is an implementation of TLS streams using BoringSSL for
//! negotiating the connection. Each TLS stream implements Tokio's `AsyncRead` and
//! `AsyncWrite` traits to interact and interoperate with the rest of the Tokio I/O
//! ecosystem. Client connections initiated from this crate verify hostnames
//! automatically and by default.
8//!
9//! `tokio-boring` exports this ability through [`accept`] and [`connect`]. `accept` should
10//! be used by servers, and `connect` by clients. These augment the functionality provided by the
11//! [`boring`] crate, on which this crate is built. Configuration of TLS parameters is still
12//! primarily done through the [`boring`] crate.
13#![warn(missing_docs)]
14#![cfg_attr(docsrs, feature(doc_auto_cfg))]
15
16use boring::ssl::{
17    self, ConnectConfiguration, ErrorCode, MidHandshakeSslStream, ShutdownResult, SslAcceptor,
18    SslRef,
19};
20use boring_sys as ffi;
21use std::error::Error;
22use std::fmt;
23use std::future::Future;
24use std::io::{self, Write};
25use std::pin::Pin;
26use std::task::{Context, Poll};
27use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
28
29mod async_callbacks;
30mod bridge;
31
32use self::bridge::AsyncStreamBridge;
33
34pub use crate::async_callbacks::SslContextBuilderExt;
35pub use boring::ssl::{
36    AsyncPrivateKeyMethod, AsyncPrivateKeyMethodError, AsyncSelectCertError, BoxGetSessionFinish,
37    BoxGetSessionFuture, BoxPrivateKeyMethodFinish, BoxPrivateKeyMethodFuture, BoxSelectCertFinish,
38    BoxSelectCertFuture, ExDataFuture,
39};
40
41/// Asynchronously performs a client-side TLS handshake over the provided stream.
42///
43/// This function automatically sets the task waker on the `Ssl` from `config` to
44/// allow to make use of async callbacks provided by the boring crate.
45pub async fn connect<S>(
46    config: ConnectConfiguration,
47    domain: &str,
48    stream: S,
49) -> Result<SslStream<S>, HandshakeError<S>>
50where
51    S: AsyncRead + AsyncWrite + Unpin,
52{
53    let mid_handshake = config
54        .setup_connect(domain, AsyncStreamBridge::new(stream))
55        .map_err(|err| HandshakeError(ssl::HandshakeError::SetupFailure(err)))?;
56
57    HandshakeFuture(Some(mid_handshake)).await
58}
59
60/// Asynchronously performs a server-side TLS handshake over the provided stream.
61///
62/// This function automatically sets the task waker on the `Ssl` from `config` to
63/// allow to make use of async callbacks provided by the boring crate.
64pub async fn accept<S>(acceptor: &SslAcceptor, stream: S) -> Result<SslStream<S>, HandshakeError<S>>
65where
66    S: AsyncRead + AsyncWrite + Unpin,
67{
68    let mid_handshake = acceptor
69        .setup_accept(AsyncStreamBridge::new(stream))
70        .map_err(|err| HandshakeError(ssl::HandshakeError::SetupFailure(err)))?;
71
72    HandshakeFuture(Some(mid_handshake)).await
73}
74
/// Translates a blocking-style I/O result into a `Poll`.
///
/// An error of kind `io::ErrorKind::WouldBlock` becomes `Poll::Pending`. Every
/// other result, success or failure, is returned unchanged inside `Poll::Ready`.
fn cvt<T>(r: io::Result<T>) -> Poll<io::Result<T>> {
    let would_block = matches!(&r, Err(err) if err.kind() == io::ErrorKind::WouldBlock);

    if would_block {
        Poll::Pending
    } else {
        Poll::Ready(r)
    }
}
82
83/// A partially constructed `SslStream`, useful for unusual handshakes.
84pub struct SslStreamBuilder<S> {
85    inner: ssl::SslStreamBuilder<AsyncStreamBridge<S>>,
86}
87
88impl<S> SslStreamBuilder<S>
89where
90    S: AsyncRead + AsyncWrite + Unpin,
91{
92    /// Begins creating an `SslStream` atop `stream`.
93    pub fn new(ssl: ssl::Ssl, stream: S) -> Self {
94        Self {
95            inner: ssl::SslStreamBuilder::new(ssl, AsyncStreamBridge::new(stream)),
96        }
97    }
98
99    /// Initiates a client-side TLS handshake.
100    pub async fn accept(self) -> Result<SslStream<S>, HandshakeError<S>> {
101        let mid_handshake = self.inner.setup_accept();
102
103        HandshakeFuture(Some(mid_handshake)).await
104    }
105
106    /// Initiates a server-side TLS handshake.
107    pub async fn connect(self) -> Result<SslStream<S>, HandshakeError<S>> {
108        let mid_handshake = self.inner.setup_connect();
109
110        HandshakeFuture(Some(mid_handshake)).await
111    }
112}
113
114impl<S> SslStreamBuilder<S> {
115    /// Returns a shared reference to the `Ssl` object associated with this builder.
116    #[must_use]
117    pub fn ssl(&self) -> &SslRef {
118        self.inner.ssl()
119    }
120
121    /// Returns a mutable reference to the `Ssl` object associated with this builder.
122    pub fn ssl_mut(&mut self) -> &mut SslRef {
123        self.inner.ssl_mut()
124    }
125}
126
127/// A wrapper around an underlying raw stream which implements the SSL
128/// protocol.
129///
130/// A `SslStream<S>` represents a handshake that has been completed successfully
131/// and both the server and the client are ready for receiving and sending
132/// data. Bytes read from a `SslStream` are decrypted from `S` and bytes written
133/// to a `SslStream` are encrypted when passing through to `S`.
134#[derive(Debug)]
135pub struct SslStream<S>(ssl::SslStream<AsyncStreamBridge<S>>);
136
137impl<S> SslStream<S> {
138    /// Returns a shared reference to the `Ssl` object associated with this stream.
139    #[must_use]
140    pub fn ssl(&self) -> &SslRef {
141        self.0.ssl()
142    }
143
144    /// Returns a mutable reference to the `Ssl` object associated with this stream.
145    pub fn ssl_mut(&mut self) -> &mut SslRef {
146        self.0.ssl_mut()
147    }
148
149    /// Returns a shared reference to the underlying stream.
150    #[must_use]
151    pub fn get_ref(&self) -> &S {
152        &self.0.get_ref().stream
153    }
154
155    /// Returns a mutable reference to the underlying stream.
156    pub fn get_mut(&mut self) -> &mut S {
157        &mut self.0.get_mut().stream
158    }
159
160    fn run_in_context<F, R>(&mut self, ctx: &mut Context<'_>, f: F) -> R
161    where
162        F: FnOnce(&mut ssl::SslStream<AsyncStreamBridge<S>>) -> R,
163    {
164        self.0.get_mut().set_waker(Some(ctx));
165
166        let result = f(&mut self.0);
167
168        // NOTE(nox): This should also be executed when `f` panics,
169        // but it's not that important as boring segfaults on panics
170        // and we always set the context prior to doing anything with
171        // the inner async stream.
172        self.0.get_mut().set_waker(None);
173
174        result
175    }
176}
177
178impl<S> SslStream<S>
179where
180    S: AsyncRead + AsyncWrite + Unpin,
181{
182    /// Constructs an `SslStream` from a pointer to the underlying OpenSSL `SSL` struct.
183    ///
184    /// This is useful if the handshake has already been completed elsewhere.
185    ///
186    /// # Safety
187    ///
188    /// The caller must ensure the pointer is valid.
189    pub unsafe fn from_raw_parts(ssl: *mut ffi::SSL, stream: S) -> Self {
190        Self(ssl::SslStream::from_raw_parts(
191            ssl,
192            AsyncStreamBridge::new(stream),
193        ))
194    }
195}
196
197impl<S> AsyncRead for SslStream<S>
198where
199    S: AsyncRead + AsyncWrite + Unpin,
200{
201    fn poll_read(
202        mut self: Pin<&mut Self>,
203        ctx: &mut Context<'_>,
204        buf: &mut ReadBuf,
205    ) -> Poll<io::Result<()>> {
206        self.run_in_context(ctx, |s| {
207            // SAFETY: read_uninit does not de-initialize the buffer.
208            match cvt(s.read_uninit(unsafe { buf.unfilled_mut() }))? {
209                Poll::Ready(nread) => {
210                    unsafe {
211                        buf.assume_init(nread);
212                    }
213                    buf.advance(nread);
214                    Poll::Ready(Ok(()))
215                }
216                Poll::Pending => Poll::Pending,
217            }
218        })
219    }
220}
221
222impl<S> AsyncWrite for SslStream<S>
223where
224    S: AsyncRead + AsyncWrite + Unpin,
225{
226    fn poll_write(
227        mut self: Pin<&mut Self>,
228        ctx: &mut Context,
229        buf: &[u8],
230    ) -> Poll<io::Result<usize>> {
231        self.run_in_context(ctx, |s| cvt(s.write(buf)))
232    }
233
234    fn poll_flush(mut self: Pin<&mut Self>, ctx: &mut Context) -> Poll<io::Result<()>> {
235        self.run_in_context(ctx, |s| cvt(s.flush()))
236    }
237
238    fn poll_shutdown(mut self: Pin<&mut Self>, ctx: &mut Context) -> Poll<io::Result<()>> {
239        match self.run_in_context(ctx, |s| s.shutdown()) {
240            Ok(ShutdownResult::Sent | ShutdownResult::Received) => {}
241            Err(ref e) if e.code() == ErrorCode::ZERO_RETURN => {}
242            Err(ref e) if e.code() == ErrorCode::WANT_READ || e.code() == ErrorCode::WANT_WRITE => {
243                return Poll::Pending;
244            }
245            Err(e) => {
246                return Poll::Ready(Err(e.into_io_error().unwrap_or_else(io::Error::other)));
247            }
248        }
249
250        Pin::new(&mut self.0.get_mut().stream).poll_shutdown(ctx)
251    }
252}
253
254/// The error type returned after a failed handshake.
255pub struct HandshakeError<S>(ssl::HandshakeError<AsyncStreamBridge<S>>);
256
257impl<S> HandshakeError<S> {
258    /// Returns a shared reference to the `Ssl` object associated with this error.
259    #[must_use]
260    pub fn ssl(&self) -> Option<&SslRef> {
261        match &self.0 {
262            ssl::HandshakeError::Failure(s) => Some(s.ssl()),
263            _ => None,
264        }
265    }
266
267    /// Converts error to the source data stream that was used for the handshake.
268    #[must_use]
269    pub fn into_source_stream(self) -> Option<S> {
270        match self.0 {
271            ssl::HandshakeError::Failure(s) => Some(s.into_source_stream().stream),
272            _ => None,
273        }
274    }
275
276    /// Returns a reference to the source data stream.
277    #[must_use]
278    pub fn as_source_stream(&self) -> Option<&S> {
279        match &self.0 {
280            ssl::HandshakeError::Failure(s) => Some(&s.get_ref().stream),
281            _ => None,
282        }
283    }
284
285    /// Returns the error code, if any.
286    #[must_use]
287    pub fn code(&self) -> Option<ErrorCode> {
288        match &self.0 {
289            ssl::HandshakeError::Failure(s) => Some(s.error().code()),
290            _ => None,
291        }
292    }
293
294    /// Returns a reference to the inner I/O error, if any.
295    #[must_use]
296    pub fn as_io_error(&self) -> Option<&io::Error> {
297        match &self.0 {
298            ssl::HandshakeError::Failure(s) => s.error().io_error(),
299            _ => None,
300        }
301    }
302}
303
304impl<S> fmt::Debug for HandshakeError<S>
305where
306    S: fmt::Debug,
307{
308    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
309        fmt::Debug::fmt(&self.0, fmt)
310    }
311}
312
313impl<S> fmt::Display for HandshakeError<S> {
314    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
315        fmt::Display::fmt(&self.0, fmt)
316    }
317}
318
319impl<S> Error for HandshakeError<S>
320where
321    S: fmt::Debug,
322{
323    fn source(&self) -> Option<&(dyn Error + 'static)> {
324        self.0.source()
325    }
326}
327
328/// Future for an ongoing TLS handshake.
329///
330/// See [`connect`] and [`accept`].
331pub struct HandshakeFuture<S>(Option<MidHandshakeSslStream<AsyncStreamBridge<S>>>);
332
333impl<S> Future for HandshakeFuture<S>
334where
335    S: AsyncRead + AsyncWrite + Unpin,
336{
337    type Output = Result<SslStream<S>, HandshakeError<S>>;
338
339    fn poll(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
340        let mut mid_handshake = self.0.take().expect("future polled after completion");
341
342        mid_handshake.get_mut().set_waker(Some(ctx));
343        mid_handshake
344            .ssl_mut()
345            .set_task_waker(Some(ctx.waker().clone()));
346
347        match mid_handshake.handshake() {
348            Ok(mut stream) => {
349                stream.get_mut().set_waker(None);
350                stream.ssl_mut().set_task_waker(None);
351
352                Poll::Ready(Ok(SslStream(stream)))
353            }
354            Err(ssl::HandshakeError::WouldBlock(mut mid_handshake)) => {
355                mid_handshake.get_mut().set_waker(None);
356                mid_handshake.ssl_mut().set_task_waker(None);
357
358                self.0 = Some(mid_handshake);
359
360                Poll::Pending
361            }
362            Err(ssl::HandshakeError::Failure(mut mid_handshake)) => {
363                mid_handshake.get_mut().set_waker(None);
364
365                Poll::Ready(Err(HandshakeError(ssl::HandshakeError::Failure(
366                    mid_handshake,
367                ))))
368            }
369            Err(err @ ssl::HandshakeError::SetupFailure(_)) => {
370                Poll::Ready(Err(HandshakeError(err)))
371            }
372        }
373    }
374}