Skip to main content

tokio_boring/
lib.rs

//! Async TLS streams backed by BoringSSL.
//!
//! This library is an implementation of TLS streams using BoringSSL for
//! negotiating the connection. Each TLS stream implements the `AsyncRead` and
//! `AsyncWrite` traits to interact and interoperate with the rest of the futures I/O
//! ecosystem. Client connections initiated from this crate verify hostnames
//! automatically and by default.
//!
//! `tokio-boring` exports this ability through [`accept`] and [`connect`]. `accept` should
//! be used by servers, and `connect` by clients. These augment the functionality provided by the
//! [`boring`] crate, on which this crate is built. Configuration of TLS parameters is still
//! primarily done through the [`boring`] crate.
13#![warn(missing_docs)]
14
15use boring::ssl::{
16    self, ConnectConfiguration, ErrorCode, MidHandshakeSslStream, ShutdownResult, SslAcceptor,
17    SslRef,
18};
19use boring_sys as ffi;
20use std::error::Error;
21use std::fmt;
22use std::future::Future;
23use std::io::{self, Write};
24use std::pin::Pin;
25use std::task::{Context, Poll};
26use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
27
28mod async_callbacks;
29mod bridge;
30
31use self::bridge::AsyncStreamBridge;
32
33pub use crate::async_callbacks::SslContextBuilderExt;
34pub use boring::ssl::{
35    AsyncPrivateKeyMethod, AsyncPrivateKeyMethodError, AsyncSelectCertError, BoxGetSessionFinish,
36    BoxGetSessionFuture, BoxPrivateKeyMethodFinish, BoxPrivateKeyMethodFuture, BoxSelectCertFinish,
37    BoxSelectCertFuture, ExDataFuture,
38};
39
40/// Asynchronously performs a client-side TLS handshake over the provided stream.
41///
42/// This function automatically sets the task waker on the `Ssl` from `config` to
43/// allow to make use of async callbacks provided by the boring crate.
44pub async fn connect<S>(
45    config: ConnectConfiguration,
46    domain: &str,
47    stream: S,
48) -> Result<SslStream<S>, HandshakeError<S>>
49where
50    S: AsyncRead + AsyncWrite + Unpin,
51{
52    let mid_handshake = config
53        .setup_connect(domain, AsyncStreamBridge::new(stream))
54        .map_err(|err| HandshakeError(ssl::HandshakeError::SetupFailure(err)))?;
55
56    HandshakeFuture(Some(mid_handshake)).await
57}
58
59/// Asynchronously performs a server-side TLS handshake over the provided stream.
60///
61/// This function automatically sets the task waker on the `Ssl` from `config` to
62/// allow to make use of async callbacks provided by the boring crate.
63pub async fn accept<S>(acceptor: &SslAcceptor, stream: S) -> Result<SslStream<S>, HandshakeError<S>>
64where
65    S: AsyncRead + AsyncWrite + Unpin,
66{
67    let mid_handshake = acceptor
68        .setup_accept(AsyncStreamBridge::new(stream))
69        .map_err(|err| HandshakeError(ssl::HandshakeError::SetupFailure(err)))?;
70
71    HandshakeFuture(Some(mid_handshake)).await
72}
73
/// Translates a blocking-style I/O result into a [`Poll`].
///
/// An error whose kind is [`io::ErrorKind::WouldBlock`] becomes
/// [`Poll::Pending`]; every other outcome, success or failure, is passed
/// through unchanged inside [`Poll::Ready`].
fn cvt<T>(r: io::Result<T>) -> Poll<io::Result<T>> {
    let would_block = matches!(&r, Err(err) if err.kind() == io::ErrorKind::WouldBlock);

    if would_block {
        Poll::Pending
    } else {
        Poll::Ready(r)
    }
}
81
82/// A partially constructed `SslStream`, useful for unusual handshakes.
83pub struct SslStreamBuilder<S> {
84    inner: ssl::SslStreamBuilder<AsyncStreamBridge<S>>,
85}
86
87impl<S> SslStreamBuilder<S>
88where
89    S: AsyncRead + AsyncWrite + Unpin,
90{
91    /// Begins creating an `SslStream` atop `stream`.
92    pub fn new(ssl: ssl::Ssl, stream: S) -> Self {
93        Self {
94            inner: ssl::SslStreamBuilder::new(ssl, AsyncStreamBridge::new(stream)),
95        }
96    }
97
98    /// Initiates a client-side TLS handshake.
99    pub async fn accept(self) -> Result<SslStream<S>, HandshakeError<S>> {
100        let mid_handshake = self.inner.setup_accept();
101
102        HandshakeFuture(Some(mid_handshake)).await
103    }
104
105    /// Initiates a server-side TLS handshake.
106    pub async fn connect(self) -> Result<SslStream<S>, HandshakeError<S>> {
107        let mid_handshake = self.inner.setup_connect();
108
109        HandshakeFuture(Some(mid_handshake)).await
110    }
111}
112
113impl<S> SslStreamBuilder<S> {
114    /// Returns a shared reference to the `Ssl` object associated with this builder.
115    #[must_use]
116    pub fn ssl(&self) -> &SslRef {
117        self.inner.ssl()
118    }
119
120    /// Returns a mutable reference to the `Ssl` object associated with this builder.
121    pub fn ssl_mut(&mut self) -> &mut SslRef {
122        self.inner.ssl_mut()
123    }
124}
125
126/// A wrapper around an underlying raw stream which implements the SSL
127/// protocol.
128///
129/// A `SslStream<S>` represents a handshake that has been completed successfully
130/// and both the server and the client are ready for receiving and sending
131/// data. Bytes read from a `SslStream` are decrypted from `S` and bytes written
132/// to a `SslStream` are encrypted when passing through to `S`.
133#[derive(Debug)]
134pub struct SslStream<S>(ssl::SslStream<AsyncStreamBridge<S>>);
135
136impl<S> SslStream<S> {
137    /// Returns a shared reference to the `Ssl` object associated with this stream.
138    #[must_use]
139    pub fn ssl(&self) -> &SslRef {
140        self.0.ssl()
141    }
142
143    /// Returns a mutable reference to the `Ssl` object associated with this stream.
144    pub fn ssl_mut(&mut self) -> &mut SslRef {
145        self.0.ssl_mut()
146    }
147
148    /// Returns a shared reference to the underlying stream.
149    #[must_use]
150    pub fn get_ref(&self) -> &S {
151        &self.0.get_ref().stream
152    }
153
154    /// Returns a mutable reference to the underlying stream.
155    pub fn get_mut(&mut self) -> &mut S {
156        &mut self.0.get_mut().stream
157    }
158
159    fn run_in_context<F, R>(&mut self, ctx: &mut Context<'_>, f: F) -> R
160    where
161        F: FnOnce(&mut ssl::SslStream<AsyncStreamBridge<S>>) -> R,
162    {
163        self.0.get_mut().set_waker(Some(ctx));
164
165        let result = f(&mut self.0);
166
167        // NOTE(nox): This should also be executed when `f` panics,
168        // but it's not that important as boring segfaults on panics
169        // and we always set the context prior to doing anything with
170        // the inner async stream.
171        self.0.get_mut().set_waker(None);
172
173        result
174    }
175}
176
177impl<S> SslStream<S>
178where
179    S: AsyncRead + AsyncWrite + Unpin,
180{
181    /// Constructs an `SslStream` from a pointer to the underlying OpenSSL `SSL` struct.
182    ///
183    /// This is useful if the handshake has already been completed elsewhere.
184    ///
185    /// # Safety
186    ///
187    /// The caller must ensure the pointer is valid.
188    pub unsafe fn from_raw_parts(ssl: *mut ffi::SSL, stream: S) -> Self {
189        Self(ssl::SslStream::from_raw_parts(
190            ssl,
191            AsyncStreamBridge::new(stream),
192        ))
193    }
194}
195
196impl<S> AsyncRead for SslStream<S>
197where
198    S: AsyncRead + AsyncWrite + Unpin,
199{
200    fn poll_read(
201        mut self: Pin<&mut Self>,
202        ctx: &mut Context<'_>,
203        buf: &mut ReadBuf,
204    ) -> Poll<io::Result<()>> {
205        self.run_in_context(ctx, |s| {
206            // SAFETY: read_uninit does not de-initialize the buffer.
207            match cvt(s.read_uninit(unsafe { buf.unfilled_mut() }))? {
208                Poll::Ready(nread) => {
209                    unsafe {
210                        buf.assume_init(nread);
211                    }
212                    buf.advance(nread);
213                    Poll::Ready(Ok(()))
214                }
215                Poll::Pending => Poll::Pending,
216            }
217        })
218    }
219}
220
221impl<S> AsyncWrite for SslStream<S>
222where
223    S: AsyncRead + AsyncWrite + Unpin,
224{
225    fn poll_write(
226        mut self: Pin<&mut Self>,
227        ctx: &mut Context,
228        buf: &[u8],
229    ) -> Poll<io::Result<usize>> {
230        self.run_in_context(ctx, |s| cvt(s.write(buf)))
231    }
232
233    fn poll_flush(mut self: Pin<&mut Self>, ctx: &mut Context) -> Poll<io::Result<()>> {
234        self.run_in_context(ctx, |s| cvt(s.flush()))
235    }
236
237    fn poll_shutdown(mut self: Pin<&mut Self>, ctx: &mut Context) -> Poll<io::Result<()>> {
238        match self.run_in_context(ctx, |s| s.shutdown()) {
239            Ok(ShutdownResult::Sent | ShutdownResult::Received) => {}
240            Err(ref e) if e.code() == ErrorCode::ZERO_RETURN => {}
241            Err(ref e) if e.code() == ErrorCode::WANT_READ || e.code() == ErrorCode::WANT_WRITE => {
242                return Poll::Pending;
243            }
244            Err(e) => {
245                return Poll::Ready(Err(e.into_io_error().unwrap_or_else(io::Error::other)));
246            }
247        }
248
249        Pin::new(&mut self.0.get_mut().stream).poll_shutdown(ctx)
250    }
251}
252
253/// The error type returned after a failed handshake.
254pub struct HandshakeError<S>(ssl::HandshakeError<AsyncStreamBridge<S>>);
255
256impl<S> HandshakeError<S> {
257    /// Returns a shared reference to the `Ssl` object associated with this error.
258    #[must_use]
259    pub fn ssl(&self) -> Option<&SslRef> {
260        match &self.0 {
261            ssl::HandshakeError::Failure(s) => Some(s.ssl()),
262            _ => None,
263        }
264    }
265
266    /// Converts error to the source data stream that was used for the handshake.
267    #[must_use]
268    pub fn into_source_stream(self) -> Option<S> {
269        match self.0 {
270            ssl::HandshakeError::Failure(s) => Some(s.into_source_stream().stream),
271            _ => None,
272        }
273    }
274
275    /// Returns a reference to the source data stream.
276    #[must_use]
277    pub fn as_source_stream(&self) -> Option<&S> {
278        match &self.0 {
279            ssl::HandshakeError::Failure(s) => Some(&s.get_ref().stream),
280            _ => None,
281        }
282    }
283
284    /// Returns the error code, if any.
285    #[must_use]
286    pub fn code(&self) -> Option<ErrorCode> {
287        match &self.0 {
288            ssl::HandshakeError::Failure(s) => Some(s.error().code()),
289            _ => None,
290        }
291    }
292
293    /// Returns a reference to the inner I/O error, if any.
294    #[must_use]
295    pub fn as_io_error(&self) -> Option<&io::Error> {
296        match &self.0 {
297            ssl::HandshakeError::Failure(s) => s.error().io_error(),
298            _ => None,
299        }
300    }
301}
302
303impl<S> fmt::Debug for HandshakeError<S>
304where
305    S: fmt::Debug,
306{
307    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
308        fmt::Debug::fmt(&self.0, fmt)
309    }
310}
311
312impl<S> fmt::Display for HandshakeError<S> {
313    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
314        fmt::Display::fmt(&self.0, fmt)
315    }
316}
317
318impl<S> Error for HandshakeError<S>
319where
320    S: fmt::Debug,
321{
322    fn source(&self) -> Option<&(dyn Error + 'static)> {
323        self.0.source()
324    }
325}
326
327/// Future for an ongoing TLS handshake.
328///
329/// See [`connect`] and [`accept`].
330pub struct HandshakeFuture<S>(Option<MidHandshakeSslStream<AsyncStreamBridge<S>>>);
331
332impl<S> Future for HandshakeFuture<S>
333where
334    S: AsyncRead + AsyncWrite + Unpin,
335{
336    type Output = Result<SslStream<S>, HandshakeError<S>>;
337
338    fn poll(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
339        let mut mid_handshake = self.0.take().expect("future polled after completion");
340
341        mid_handshake.get_mut().set_waker(Some(ctx));
342        mid_handshake
343            .ssl_mut()
344            .set_task_waker(Some(ctx.waker().clone()));
345
346        match mid_handshake.handshake() {
347            Ok(mut stream) => {
348                stream.get_mut().set_waker(None);
349                stream.ssl_mut().set_task_waker(None);
350
351                Poll::Ready(Ok(SslStream(stream)))
352            }
353            Err(ssl::HandshakeError::WouldBlock(mut mid_handshake)) => {
354                mid_handshake.get_mut().set_waker(None);
355                mid_handshake.ssl_mut().set_task_waker(None);
356
357                self.0 = Some(mid_handshake);
358
359                Poll::Pending
360            }
361            Err(ssl::HandshakeError::Failure(mut mid_handshake)) => {
362                mid_handshake.get_mut().set_waker(None);
363
364                Poll::Ready(Err(HandshakeError(ssl::HandshakeError::Failure(
365                    mid_handshake,
366                ))))
367            }
368            Err(err @ ssl::HandshakeError::SetupFailure(_)) => {
369                Poll::Ready(Err(HandshakeError(err)))
370            }
371        }
372    }
373}