tls_listener/
spawning_handshake.rs

1use super::AsyncTls;
2use std::future::Future;
3use std::pin::Pin;
4use std::task::{Context, Poll};
5use tokio::io::{AsyncRead, AsyncWrite};
6use tokio::task::JoinHandle;
7
/// Wraps an [`AsyncTls`] so that every handshake runs on its own spawned task.
///
/// Each call to [`accept`](AsyncTls::accept) hands the inner acceptor's handshake
/// future to [`tokio::spawn`]. This pays off mostly on a multi-threaded runtime,
/// where the TLS handshakes can then be spread across the runtime's worker threads.
#[cfg_attr(docsrs, doc(cfg(feature = "rt")))]
#[derive(Debug, Clone)]
pub struct SpawningHandshakes<T>(pub T);
16
17impl<C, T> AsyncTls<C> for SpawningHandshakes<T>
18where
19    T: AsyncTls<C>,
20    C: AsyncRead + AsyncWrite,
21    T::AcceptFuture: Send + 'static,
22    T::Stream: Send + 'static,
23    T::Error: Send + 'static,
24{
25    type Stream = T::Stream;
26    type Error = T::Error;
27    type AcceptFuture = HandshakeJoin<T::Stream, T::Error>;
28
29    fn accept(&self, stream: C) -> Self::AcceptFuture {
30        HandshakeJoin(tokio::spawn(self.0.accept(stream)))
31    }
32}
33
34/// Future type returned by [`SpawningHandshakeTls::accept`];
35#[cfg_attr(docsrs, doc(cfg(feature = "rt")))]
36pub struct HandshakeJoin<Stream, Error>(JoinHandle<Result<Stream, Error>>);
37
38impl<Stream, Error> Future for HandshakeJoin<Stream, Error> {
39    type Output = Result<Stream, Error>;
40    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
41        match Pin::new(&mut self.as_mut().0).poll(cx) {
42            Poll::Ready(Ok(v)) => Poll::Ready(v),
43            Poll::Pending => Poll::Pending,
44            Poll::Ready(Err(e)) => {
45                if e.is_panic() {
46                    std::panic::resume_unwind(e.into_panic());
47                } else {
48                    unreachable!("Tls handshake was aborted: {:?}", e);
49                }
50            }
51        }
52    }
53}
54
55impl<Stream, Error> Drop for HandshakeJoin<Stream, Error> {
56    fn drop(&mut self) {
57        self.0.abort();
58    }
59}