Skip to main content

tokio_btls/
lib.rs

//! Async TLS streams backed by BoringSSL.
//!
//! This crate provides a wrapper around the [`btls`] crate's [`SslStream`](ssl::SslStream) type
//! that works with [`tokio`]'s [`AsyncRead`] and [`AsyncWrite`] traits rather than std's
//! blocking [`Read`] and [`Write`] traits.
//!
//! This file reimplements tokio-btls with the [overhauled](https://github.com/sfackler/tokio-openssl/commit/56f6618ab619f3e431fa8feec2d20913bf1473aa)
//! tokio-openssl interface, because the tokio APIs of the official [`btls`] crate have not yet
//! caught up with it.
10
11use std::{
12    fmt, future,
13    io::{self, Read, Write},
14    pin::Pin,
15    task::{Context, Poll},
16};
17
18use btls::{
19    error::ErrorStack,
20    ssl::{self, ErrorCode, ShutdownResult, Ssl, SslRef, SslStream as SslStreamCore},
21};
22use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
23
/// Adapter that lets the synchronous BoringSSL stream drive an async stream `S` through std's
/// blocking `Read`/`Write` traits.
struct StreamWrapper<S> {
    /// Address of the task `Context` of the poll currently in progress, or `0` when none is
    /// stashed. Kept as a plain `usize` so the struct carries no lifetime parameter.
    context: usize,
    /// The wrapped async stream.
    stream: S,
}
28
29impl<S> fmt::Debug for StreamWrapper<S>
30where
31    S: fmt::Debug,
32{
33    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
34        fmt::Debug::fmt(&self.stream, fmt)
35    }
36}
37
38impl<S> StreamWrapper<S> {
39    /// # Safety
40    ///
41    /// Must be called with `context` set to a valid pointer to a live `Context` object, and the
42    /// wrapper must be pinned in memory.
43    unsafe fn parts(&mut self) -> (Pin<&mut S>, &mut Context<'_>) {
44        debug_assert_ne!(self.context, 0);
45        let stream = Pin::new_unchecked(&mut self.stream);
46        let context = &mut *(self.context as *mut _);
47        (stream, context)
48    }
49}
50
51impl<S> Read for StreamWrapper<S>
52where
53    S: AsyncRead,
54{
55    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
56        let (stream, cx) = unsafe { self.parts() };
57        let mut buf = ReadBuf::new(buf);
58        match stream.poll_read(cx, &mut buf)? {
59            Poll::Ready(()) => Ok(buf.filled().len()),
60            Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)),
61        }
62    }
63}
64
65impl<S> Write for StreamWrapper<S>
66where
67    S: AsyncWrite,
68{
69    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
70        let (stream, cx) = unsafe { self.parts() };
71        match stream.poll_write(cx, buf) {
72            Poll::Ready(r) => r,
73            Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)),
74        }
75    }
76
77    fn flush(&mut self) -> io::Result<()> {
78        let (stream, cx) = unsafe { self.parts() };
79        match stream.poll_flush(cx) {
80            Poll::Ready(r) => r,
81            Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)),
82        }
83    }
84}
85
/// Maps a blocking-style I/O result back into a poll: `WouldBlock` becomes `Pending`, and
/// every other outcome (success or error) is returned as `Ready`.
fn cvt<T>(r: io::Result<T>) -> Poll<io::Result<T>> {
    let would_block = matches!(&r, Err(e) if e.kind() == io::ErrorKind::WouldBlock);
    if would_block {
        Poll::Pending
    } else {
        Poll::Ready(r)
    }
}
93
94fn cvt_ossl<T>(r: Result<T, ssl::Error>) -> Poll<Result<T, ssl::Error>> {
95    match r {
96        Ok(v) => Poll::Ready(Ok(v)),
97        Err(e) => match e.code() {
98            ErrorCode::WANT_READ | ErrorCode::WANT_WRITE => Poll::Pending,
99            _ => Poll::Ready(Err(e)),
100        },
101    }
102}
103
/// An asynchronous version of [`btls::ssl::SslStream`].
///
/// Create one with [`SslStream::new`], then drive the TLS handshake with `poll_connect` /
/// `connect`, `poll_accept` / `accept`, or `poll_do_handshake` / `do_handshake`. After that it
/// is used through tokio's `AsyncRead` and `AsyncWrite` traits. All I/O methods take
/// `Pin<&mut Self>`, so the stream must be pinned before use.
#[derive(Debug)]
pub struct SslStream<S>(SslStreamCore<StreamWrapper<S>>);
107
108impl<S: AsyncRead + AsyncWrite> SslStream<S> {
109    #[inline]
110    /// Like [`SslStream::new`](ssl::SslStream::new).
111    pub fn new(ssl: Ssl, stream: S) -> Result<Self, ErrorStack> {
112        SslStreamCore::new(ssl, StreamWrapper { stream, context: 0 }).map(SslStream)
113    }
114
115    #[inline]
116    /// Like [`SslStream::connect`](ssl::SslStream::connect).
117    pub fn poll_connect(
118        self: Pin<&mut Self>,
119        cx: &mut Context<'_>,
120    ) -> Poll<Result<(), ssl::Error>> {
121        self.with_context(cx, |s| cvt_ossl(s.connect()))
122    }
123
124    #[inline]
125    /// A convenience method wrapping [`poll_connect`](Self::poll_connect).
126    pub async fn connect(mut self: Pin<&mut Self>) -> Result<(), ssl::Error> {
127        future::poll_fn(|cx| self.as_mut().poll_connect(cx)).await
128    }
129
130    #[inline]
131    /// Like [`SslStream::accept`](ssl::SslStream::accept).
132    pub fn poll_accept(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), ssl::Error>> {
133        self.with_context(cx, |s| cvt_ossl(s.accept()))
134    }
135
136    #[inline]
137    /// A convenience method wrapping [`poll_accept`](Self::poll_accept).
138    pub async fn accept(mut self: Pin<&mut Self>) -> Result<(), ssl::Error> {
139        future::poll_fn(|cx| self.as_mut().poll_accept(cx)).await
140    }
141
142    #[inline]
143    /// Like [`SslStream::do_handshake`](ssl::SslStream::do_handshake).
144    pub fn poll_do_handshake(
145        self: Pin<&mut Self>,
146        cx: &mut Context<'_>,
147    ) -> Poll<Result<(), ssl::Error>> {
148        self.with_context(cx, |s| cvt_ossl(s.do_handshake()))
149    }
150
151    #[inline]
152    /// A convenience method wrapping [`poll_do_handshake`](Self::poll_do_handshake).
153    pub async fn do_handshake(mut self: Pin<&mut Self>) -> Result<(), ssl::Error> {
154        future::poll_fn(|cx| self.as_mut().poll_do_handshake(cx)).await
155    }
156}
157
158impl<S> SslStream<S> {
159    #[inline]
160    /// Returns a shared reference to the `Ssl` object associated with this stream.
161    pub fn ssl(&self) -> &SslRef {
162        self.0.ssl()
163    }
164
165    #[inline]
166    /// Returns a shared reference to the underlying stream.
167    pub fn get_ref(&self) -> &S {
168        &self.0.get_ref().stream
169    }
170
171    #[inline]
172    /// Returns a mutable reference to the underlying stream.
173    pub fn get_mut(&mut self) -> &mut S {
174        &mut self.0.get_mut().stream
175    }
176
177    #[inline]
178    /// Returns a pinned mutable reference to the underlying stream.
179    pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut S> {
180        unsafe { Pin::new_unchecked(&mut self.get_unchecked_mut().0.get_mut().stream) }
181    }
182
183    fn with_context<F, R>(self: Pin<&mut Self>, ctx: &mut Context<'_>, f: F) -> R
184    where
185        F: FnOnce(&mut SslStreamCore<StreamWrapper<S>>) -> R,
186    {
187        let this = unsafe { self.get_unchecked_mut() };
188        this.0.get_mut().context = ctx as *mut _ as usize;
189        let r = f(&mut this.0);
190        this.0.get_mut().context = 0;
191        r
192    }
193}
194
195impl<S> AsyncRead for SslStream<S>
196where
197    S: AsyncRead + AsyncWrite,
198{
199    fn poll_read(
200        self: Pin<&mut Self>,
201        ctx: &mut Context<'_>,
202        buf: &mut ReadBuf<'_>,
203    ) -> Poll<io::Result<()>> {
204        self.with_context(ctx, |s| {
205            // SAFETY: read_uninit does not de-initialize the buffer.
206            match cvt(s.read_uninit(unsafe { buf.unfilled_mut() }))? {
207                Poll::Ready(nread) => {
208                    // SAFETY: read_uninit guarantees that nread bytes have been initialized.
209                    unsafe { buf.assume_init(nread) };
210                    buf.advance(nread);
211                    Poll::Ready(Ok(()))
212                }
213                Poll::Pending => Poll::Pending,
214            }
215        })
216    }
217}
218
219impl<S> AsyncWrite for SslStream<S>
220where
221    S: AsyncRead + AsyncWrite,
222{
223    #[inline]
224    fn poll_write(self: Pin<&mut Self>, ctx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
225        self.with_context(ctx, |s| cvt(s.write(buf)))
226    }
227
228    #[inline]
229    fn poll_flush(self: Pin<&mut Self>, ctx: &mut Context) -> Poll<io::Result<()>> {
230        self.with_context(ctx, |s| cvt(s.flush()))
231    }
232
233    fn poll_shutdown(mut self: Pin<&mut Self>, ctx: &mut Context) -> Poll<io::Result<()>> {
234        match self.as_mut().with_context(ctx, |s| s.shutdown()) {
235            Ok(ShutdownResult::Sent) | Ok(ShutdownResult::Received) => {}
236            Err(ref e) if e.code() == ErrorCode::ZERO_RETURN => {}
237            Err(ref e) if e.code() == ErrorCode::WANT_READ || e.code() == ErrorCode::WANT_WRITE => {
238                return Poll::Pending;
239            }
240            Err(e) => {
241                return Poll::Ready(Err(e.into_io_error().unwrap_or_else(io::Error::other)));
242            }
243        }
244
245        self.get_pin_mut().poll_shutdown(ctx)
246    }
247}