s2n_tls_tokio/
lib.rs

1// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use errno::{set_errno, Errno};
5use s2n_tls::{
6    config::Config,
7    connection::{Builder, Connection},
8    enums::{Blinding, CallbackResult, Mode},
9    error::Error,
10};
11use std::{
12    fmt,
13    future::Future,
14    io,
15    os::raw::{c_int, c_void},
16    pin::Pin,
17    task::{
18        Context, Poll,
19        Poll::{Pending, Ready},
20    },
21};
22use tokio::{
23    io::{AsyncRead, AsyncWrite, ReadBuf},
24    time::{sleep, Duration, Sleep},
25};
26
27// TODO use the version from s2n_quic_core
28mod task;
29use task::waker::debug_assert_contract as debug_assert_waker_contract;
30
// Unwraps the value of a `Poll::Ready`, or returns `Poll::Pending` from the
// enclosing function/closure when the inner poll is not ready yet.
//
// The paths are fully qualified so the expansion does not depend on which
// `Poll` names happen to be imported where the macro is used.
macro_rules! ready {
    ($poll:expr) => {
        match $poll {
            ::std::task::Poll::Pending => return ::std::task::Poll::Pending,
            ::std::task::Poll::Ready(value) => value,
        }
    };
}
39
40#[derive(Clone)]
41pub struct TlsAcceptor<B: Builder = Config>
42where
43    <B as Builder>::Output: Unpin,
44{
45    builder: B,
46}
47
48impl<B: Builder> TlsAcceptor<B>
49where
50    <B as Builder>::Output: Unpin,
51{
52    pub fn new(builder: B) -> Self {
53        TlsAcceptor { builder }
54    }
55
56    pub async fn accept<S>(&self, stream: S) -> Result<TlsStream<S, B::Output>, Error>
57    where
58        S: AsyncRead + AsyncWrite + Unpin,
59    {
60        let conn = self.builder.build_connection(Mode::Server)?;
61        TlsStream::open(conn, stream).await
62    }
63}
64
65#[derive(Clone)]
66pub struct TlsConnector<B: Builder = Config>
67where
68    <B as Builder>::Output: Unpin,
69{
70    builder: B,
71}
72
73impl<B: Builder> TlsConnector<B>
74where
75    <B as Builder>::Output: Unpin,
76{
77    pub fn new(builder: B) -> Self {
78        TlsConnector { builder }
79    }
80
81    pub async fn connect<S>(
82        &self,
83        domain: &str,
84        stream: S,
85    ) -> Result<TlsStream<S, B::Output>, Error>
86    where
87        S: AsyncRead + AsyncWrite + Unpin,
88    {
89        let mut conn = self.builder.build_connection(Mode::Client)?;
90        conn.as_mut().set_server_name(domain)?;
91        TlsStream::open(conn, stream).await
92    }
93}
94
95struct TlsHandshake<'a, S, C>
96where
97    C: AsRef<Connection> + AsMut<Connection> + Unpin,
98    S: AsyncRead + AsyncWrite + Unpin,
99{
100    tls: &'a mut TlsStream<S, C>,
101    error: Option<Error>,
102}
103
104impl<S, C> Future for TlsHandshake<'_, S, C>
105where
106    C: AsRef<Connection> + AsMut<Connection> + Unpin,
107    S: AsyncRead + AsyncWrite + Unpin,
108{
109    type Output = Result<(), Error>;
110
111    fn poll(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
112        debug_assert_waker_contract(ctx, |ctx| {
113            // Retrieve a result, either from the stored error
114            // or by polling Connection::poll_negotiate().
115            // Connection::poll_negotiate() only completes once,
116            // regardless of how often this method is polled.
117            let result = match self.error.take() {
118                Some(err) => Err(err),
119                None => {
120                    let handshake_poll = self.tls.with_io(ctx, |context| {
121                        let conn = context.get_mut().as_mut();
122                        conn.poll_negotiate().map(|r| r.map(|_| ()))
123                    });
124                    ready!(handshake_poll)
125                }
126            };
127            // If the result isn't a fatal error, return it immediately.
128            // Otherwise, poll Connection::poll_shutdown().
129            //
130            // Shutdown is only best-effort.
131            // When Connection::poll_shutdown() completes, even with an error,
132            // we return the original Connection::poll_negotiate() error.
133            match result {
134                Ok(r) => Ok(r).into(),
135                Err(e) if e.is_retryable() => Err(e).into(),
136                Err(e) => match Pin::new(&mut self.tls).poll_shutdown(ctx) {
137                    Pending => {
138                        self.error = Some(e);
139                        Pending
140                    }
141                    Ready(_) => Err(e).into(),
142                },
143            }
144        })
145    }
146}
147
148pub struct TlsStream<S, C = Connection>
149where
150    C: AsRef<Connection> + AsMut<Connection> + Unpin,
151    S: AsyncRead + AsyncWrite + Unpin,
152{
153    conn: C,
154    stream: S,
155    blinding: Option<Pin<Box<Sleep>>>,
156    shutdown_error: Option<Error>,
157}
158
159impl<S, C> TlsStream<S, C>
160where
161    C: AsRef<Connection> + AsMut<Connection> + Unpin,
162    S: AsyncRead + AsyncWrite + Unpin,
163{
164    ///Access a shared reference to the underlaying io stream
165    pub fn get_ref(&self) -> &S {
166        &self.stream
167    }
168
169    ///Access the mutable reference to the underlaying io stream
170    pub fn get_mut(&mut self) -> &mut S {
171        &mut self.stream
172    }
173
174    async fn open(conn: C, stream: S) -> Result<Self, Error> {
175        let mut tls = TlsStream {
176            conn,
177            stream,
178            blinding: None,
179            shutdown_error: None,
180        };
181        TlsHandshake {
182            tls: &mut tls,
183            error: None,
184        }
185        .await?;
186        Ok(tls)
187    }
188
189    fn with_io<F, R>(&mut self, ctx: &mut Context, action: F) -> Poll<Result<R, Error>>
190    where
191        F: FnOnce(Pin<&mut Self>) -> Poll<Result<R, Error>>,
192    {
193        // Setting contexts on a connection is considered unsafe
194        // because the raw pointers provide no lifetime or memory guarantees.
195        // We protect against this by pinning the stream during the action
196        // and clearing the context afterwards.
197        unsafe {
198            let context = self as *mut Self as *mut c_void;
199
200            self.as_mut().set_receive_callback(Some(Self::recv_io_cb))?;
201            self.as_mut().set_send_callback(Some(Self::send_io_cb))?;
202            self.as_mut().set_receive_context(context)?;
203            self.as_mut().set_send_context(context)?;
204            self.as_mut().set_waker(Some(ctx.waker()))?;
205            self.as_mut().set_blinding(Blinding::SelfService)?;
206
207            let result = action(Pin::new(self));
208
209            self.as_mut().set_receive_callback(None)?;
210            self.as_mut().set_send_callback(None)?;
211            self.as_mut().set_receive_context(std::ptr::null_mut())?;
212            self.as_mut().set_send_context(std::ptr::null_mut())?;
213            self.as_mut().set_waker(None)?;
214            result
215        }
216    }
217
218    fn poll_io<F>(ctx: *mut c_void, action: F) -> c_int
219    where
220        F: FnOnce(Pin<&mut S>, &mut Context) -> Poll<Result<usize, std::io::Error>>,
221    {
222        debug_assert_ne!(ctx, std::ptr::null_mut());
223        let tls = unsafe { &mut *(ctx as *mut Self) };
224
225        let mut async_context = Context::from_waker(tls.conn.as_ref().waker().unwrap());
226        let stream = Pin::new(&mut tls.stream);
227
228        let res = debug_assert_waker_contract(&mut async_context, |async_context| {
229            action(stream, async_context)
230        });
231
232        match res {
233            Poll::Ready(Ok(len)) => len as c_int,
234            Poll::Pending => {
235                set_errno(Errno(libc::EWOULDBLOCK));
236                CallbackResult::Failure.into()
237            }
238            _ => CallbackResult::Failure.into(),
239        }
240    }
241
242    unsafe extern "C" fn recv_io_cb(ctx: *mut c_void, buf: *mut u8, len: u32) -> c_int {
243        Self::poll_io(ctx, |stream, async_context| {
244            let mut dest = ReadBuf::new(std::slice::from_raw_parts_mut(buf, len as usize));
245            stream
246                .poll_read(async_context, &mut dest)
247                .map_ok(|_| dest.filled().len())
248        })
249    }
250
251    unsafe extern "C" fn send_io_cb(ctx: *mut c_void, buf: *const u8, len: u32) -> c_int {
252        Self::poll_io(ctx, |stream, async_context| {
253            let src = std::slice::from_raw_parts(buf, len as usize);
254            stream.poll_write(async_context, src)
255        })
256    }
257
258    /// Polls the blinding timer, if there is any.
259    ///
260    /// s2n has a "blinding" functionality - when a bad behavior from the peer
261    /// is detected, sleeps for 10-30 seconds before answering the client
262    /// and closing the connection. This mitigates some timing side channels
263    /// that could leak information about encrypted data. See the
264    /// `s2n_connection_set_blinding` docs for more details.
265    ///
266    /// For security reasons, to allow for blinding to correctly function,
267    /// before dropping an s2n connection, you should wait until either
268    /// `poll_blinding` or `poll_shutdown` (which calls `poll_blinding`
269    /// internally) returns ready.
270    pub fn poll_blinding(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Result<(), Error>> {
271        debug_assert_waker_contract(ctx, |ctx| {
272            let tls = self.get_mut();
273
274            if tls.blinding.is_none() {
275                let delay = tls.as_ref().remaining_blinding_delay()?;
276                if !delay.is_zero() {
277                    // Sleep operates at the milisecond resolution, so add an extra
278                    // millisecond to account for any stray nanoseconds.
279                    let safety = Duration::from_millis(1);
280                    tls.blinding = Some(Box::pin(sleep(delay.saturating_add(safety))));
281                }
282            };
283
284            if let Some(timer) = tls.blinding.as_mut() {
285                ready!(timer.as_mut().poll(ctx));
286                tls.blinding = None;
287            }
288
289            Poll::Ready(Ok(()))
290        })
291    }
292
293    pub async fn apply_blinding(&mut self) -> Result<(), Error> {
294        ApplyBlinding { stream: self }.await
295    }
296}
297
298impl<S, C> AsRef<Connection> for TlsStream<S, C>
299where
300    C: AsRef<Connection> + AsMut<Connection> + Unpin,
301    S: AsyncRead + AsyncWrite + Unpin,
302{
303    fn as_ref(&self) -> &Connection {
304        self.conn.as_ref()
305    }
306}
307
308impl<S, C> AsMut<Connection> for TlsStream<S, C>
309where
310    C: AsRef<Connection> + AsMut<Connection> + Unpin,
311    S: AsyncRead + AsyncWrite + Unpin,
312{
313    fn as_mut(&mut self) -> &mut Connection {
314        self.conn.as_mut()
315    }
316}
317
318impl<S, C> AsyncRead for TlsStream<S, C>
319where
320    C: AsRef<Connection> + AsMut<Connection> + Unpin,
321    S: AsyncRead + AsyncWrite + Unpin,
322{
323    fn poll_read(
324        self: Pin<&mut Self>,
325        ctx: &mut Context<'_>,
326        buf: &mut ReadBuf<'_>,
327    ) -> Poll<io::Result<()>> {
328        let tls = self.get_mut();
329        tls.with_io(ctx, |mut context| {
330            context
331                .conn
332                .as_mut()
333                // Safe since poll_recv_uninitialized does not
334                // deinitialize any bytes.
335                .poll_recv_uninitialized(unsafe { buf.unfilled_mut() })
336                .map_ok(|size| {
337                    unsafe {
338                        // Safe since poll_recv_uninitialized guaranteed
339                        // us that the first `size` bytes have been
340                        // initialized.
341                        buf.assume_init(size);
342                    }
343                    buf.advance(size);
344                })
345        })
346        .map_err(io::Error::from)
347    }
348}
349
350impl<S, C> AsyncWrite for TlsStream<S, C>
351where
352    C: AsRef<Connection> + AsMut<Connection> + Unpin,
353    S: AsyncRead + AsyncWrite + Unpin,
354{
355    fn poll_write(
356        self: Pin<&mut Self>,
357        ctx: &mut Context<'_>,
358        buf: &[u8],
359    ) -> Poll<io::Result<usize>> {
360        let tls = self.get_mut();
361        tls.with_io(ctx, |mut context| context.conn.as_mut().poll_send(buf))
362            .map_err(io::Error::from)
363    }
364
365    fn poll_flush(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<io::Result<()>> {
366        let tls = self.get_mut();
367
368        ready!(tls.with_io(ctx, |mut context| {
369            context.conn.as_mut().poll_flush().map(|r| r.map(|_| ()))
370        }))
371        .map_err(io::Error::from)?;
372
373        Pin::new(&mut tls.stream).poll_flush(ctx)
374    }
375
376    fn poll_shutdown(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<io::Result<()>> {
377        debug_assert_waker_contract(ctx, |ctx| {
378            ready!(self.as_mut().poll_blinding(ctx))?;
379
380            // s2n_shutdown_send must not be called again if it errors
381            if self.shutdown_error.is_none() {
382                let result = ready!(self.as_mut().with_io(ctx, |mut context| {
383                    context
384                        .conn
385                        .as_mut()
386                        .poll_shutdown_send()
387                        .map(|r| r.map(|_| ()))
388                }));
389                if let Err(error) = result {
390                    self.shutdown_error = Some(error);
391                    // s2n_shutdown_send only writes, so will never trigger blinding again.
392                    // So we do not need to poll_blinding again after this error.
393                }
394            };
395
396            let tcp_result = ready!(Pin::new(&mut self.as_mut().stream).poll_shutdown(ctx));
397
398            if let Some(err) = self.shutdown_error.take() {
399                // poll methods shouldn't be called again after returning Ready, but
400                // nothing actually prevents it so poll_shutdown should handle it.
401                // s2n_shutdown can be polled indefinitely after succeeding, but not after failing.
402                // s2n_tls::error::Error isn't cloneable, so we can't just return the same error
403                // if poll_shutdown is called again. Instead, save a different error.
404                let next_error = Error::application("Shutdown called again after error".into());
405                self.shutdown_error = Some(next_error);
406
407                Ready(Err(io::Error::from(err)))
408            } else {
409                Ready(tcp_result)
410            }
411        })
412    }
413}
414
415impl<S, C> fmt::Debug for TlsStream<S, C>
416where
417    C: AsRef<Connection> + AsMut<Connection> + Unpin,
418    S: AsyncRead + AsyncWrite + Unpin,
419{
420    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
421        f.debug_struct("TlsStream")
422            .field("connection", self.as_ref())
423            .finish()
424    }
425}
426
427struct ApplyBlinding<'a, S, C>
428where
429    C: AsRef<Connection> + AsMut<Connection> + Unpin,
430    S: AsyncRead + AsyncWrite + Unpin,
431{
432    stream: &'a mut TlsStream<S, C>,
433}
434
435impl<S, C> Future for ApplyBlinding<'_, S, C>
436where
437    C: AsRef<Connection> + AsMut<Connection> + Unpin,
438    S: AsyncRead + AsyncWrite + Unpin,
439{
440    type Output = Result<(), Error>;
441
442    fn poll(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
443        Pin::new(&mut *self.as_mut().stream).poll_blinding(ctx)
444    }
445}