tls_listener/
lib.rs

1#![deny(missing_docs)]
2#![cfg_attr(docsrs, feature(doc_cfg))]
3
4//! Async TLS listener
5//!
6//! This library is intended to automatically initiate a TLS connection
7//! for each new connection in a source of new streams (such as a listening
8//! TCP or unix domain socket).
9//!
//! # Features:
//! - `tokio-net`: Implementations for tokio socket types (default)
//! - `rt`: Features that depend on the tokio runtime, such as [`SpawningHandshakes`]
//! - `rustls-core`: Support the tokio-rustls backend for TLS
//! - `rustls-aws-lc`: Include the aws-lc provider for rustls
//! - `rustls-ring`: Include the ring provider for rustls
//! - `rustls-fips`: Enable the "fips" feature of rustls
//! - `native-tls`: Support the tokio-native-tls backend for TLS
//! - `openssl`: Support the tokio-openssl backend for TLS
18
19use futures_util::stream::{FuturesUnordered, Stream, StreamExt, TryStreamExt};
20use pin_project_lite::pin_project;
21#[cfg(feature = "rt")]
22pub use spawning_handshake::SpawningHandshakes;
23use std::fmt::Debug;
24use std::future::{poll_fn, Future};
25use std::num::NonZeroUsize;
26use std::pin::Pin;
27use std::task::{ready, Context, Poll};
28use std::time::Duration;
29use thiserror::Error;
30use tokio::io::{AsyncRead, AsyncWrite};
31use tokio::time::{timeout, Timeout};
32#[cfg(feature = "native-tls")]
33pub use tokio_native_tls as native_tls;
34#[cfg(feature = "openssl")]
35pub use tokio_openssl as openssl;
36#[cfg(feature = "rustls-core")]
37pub use tokio_rustls as rustls;
38
39mod accept;
40#[cfg(feature = "tokio-net")]
41mod net;
42#[cfg(feature = "rt")]
43mod spawning_handshake;
44
45pub use accept::*;
46
47#[cfg(feature = "axum")]
48mod axum;
49
/// Default number of connections to accept in a batch before checking whether any
/// pending TLS handshakes have completed.
///
/// See [`Builder::accept_batch_size`].
// Built with a safe const `match` instead of `new_unchecked`, so no `unsafe` is needed;
// the `None` arm can never be reached because 64 is non-zero.
pub const DEFAULT_ACCEPT_BATCH_SIZE: NonZeroUsize = match NonZeroUsize::new(64) {
    Some(size) => size,
    None => panic!("DEFAULT_ACCEPT_BATCH_SIZE must be non-zero"),
};
/// Default timeout for the TLS handshake.
///
/// See [`Builder::handshake_timeout`].
pub const DEFAULT_HANDSHAKE_TIMEOUT: Duration = Duration::from_secs(10);
54
55/// Trait for TLS implementation.
56///
57/// Implementations are provided by the rustls and native-tls features.
58pub trait AsyncTls<C: AsyncRead + AsyncWrite>: Clone {
59    /// The type of the TLS stream created from the underlying stream.
60    type Stream;
61    /// Error type for completing the TLS handshake
62    type Error: std::error::Error;
63    /// Type of the Future for the TLS stream that is accepted.
64    type AcceptFuture: Future<Output = Result<Self::Stream, Self::Error>>;
65
66    /// Accept a TLS connection on an underlying stream
67    fn accept(&self, stream: C) -> Self::AcceptFuture;
68}
69
70pin_project! {
71    ///
72    /// Wraps a `Stream` of connections (such as a TCP listener) so that each connection is itself
73    /// encrypted using TLS.
74    ///
75    /// It is similar to:
76    ///
77    /// ```ignore
78    /// tcpListener.and_then(|s| tlsAcceptor.accept(s))
79    /// ```
80    ///
81    /// except that it has the ability to accept multiple transport-level connections
82    /// simultaneously while the TLS handshake is pending for other connections.
83    ///
84    /// By default, if a client fails the TLS handshake, that is treated as an error, and the
85    /// `TlsListener` will return an `Err`. If the error is not handled, then an invalid handshake can
86    /// cause the server to stop accepting connections.
87    /// See [`http-stream.rs`][2] or [`http-low-level`][3] examples, for examples of how to avoid this.
88    ///
89    /// Note that if the maximum number of pending connections is greater than 1, the resulting
90    /// [`T::Stream`][4] connections may come in a different order than the connections produced by the
91    /// underlying listener.
92    ///
93    /// [2]: https://github.com/tmccombs/tls-listener/blob/main/examples/http-stream.rs
94    /// [3]: https://github.com/tmccombs/tls-listener/blob/main/examples/http-low-level.rs
95    /// [4]: AsyncTls::Stream
96    ///
97    pub struct TlsListener<A: AsyncAccept, T: AsyncTls<A::Connection>> {
98        #[pin]
99        listener: A,
100        tls: T,
101        waiting: FuturesUnordered<Waiting<A, T>>,
102        accept_batch_size: NonZeroUsize,
103        timeout: Duration,
104    }
105}
106
/// Builder for [`TlsListener`].
///
/// Holds the TLS acceptor together with the accept batch size and handshake timeout
/// that are applied to every listener created with [`Builder::listen`].
/// Create one with [`builder`].
// `Debug` is derived so the public builder can be inspected and embedded in other
// `Debug` types; the derive only requires `T: Debug` where it is actually used.
#[derive(Clone, Debug)]
pub struct Builder<T> {
    tls: T,
    accept_batch_size: NonZeroUsize,
    handshake_timeout: Duration,
}
114
115/// Wraps errors from either the listener or the TLS Acceptor
116#[derive(Debug, Error)]
117#[non_exhaustive]
118pub enum Error<LE: std::error::Error, TE: std::error::Error, Addr> {
119    /// An error that arose from the listener ([AsyncAccept::Error])
120    #[error("{0}")]
121    ListenerError(#[source] LE),
122    /// An error that occurred during the TLS accept handshake
123    #[error("{error}")]
124    #[non_exhaustive]
125    TlsAcceptError {
126        /// The original error that occurred
127        #[source]
128        error: TE,
129
130        /// Address of the other side of the connection
131        peer_addr: Addr,
132    },
133    /// The TLS handshake timed out
134    #[error("Timeout during TLS handshake")]
135    #[non_exhaustive]
136    HandshakeTimeout {
137        /// Address of the other side of the connection
138        peer_addr: Addr,
139    },
140}
141
142impl<A: AsyncAccept, T> TlsListener<A, T>
143where
144    T: AsyncTls<A::Connection>,
145{
146    /// Create a `TlsListener` with default options.
147    pub fn new(tls: T, listener: A) -> Self {
148        builder(tls).listen(listener)
149    }
150}
151
152/// Convenience type alias to get the proper error type from the type of the [`AsyncAccept`] and
153/// [`AsyncTls`] used.
154type TlsListenerError<A, T> = Error<
155    <A as AsyncAccept>::Error,
156    <T as AsyncTls<<A as AsyncAccept>::Connection>>::Error,
157    <A as AsyncAccept>::Address,
158>;
159
160impl<A, T> TlsListener<A, T>
161where
162    A: AsyncAccept,
163    T: AsyncTls<A::Connection>,
164{
165    /// Poll accepting a connection.
166    ///
167    /// This will return ready once the TLS handshake has completed on an incoming
168    /// connection and return the connection and the source address.
169    pub fn poll_accept(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<<Self as Stream>::Item> {
170        let mut this = self.project();
171
172        loop {
173            let mut empty_listener = false;
174            for _ in 0..this.accept_batch_size.get() {
175                match this.listener.as_mut().poll_accept(cx) {
176                    Poll::Pending => {
177                        empty_listener = true;
178                        break;
179                    }
180                    Poll::Ready(Ok((conn, addr))) => {
181                        this.waiting.push(Waiting {
182                            inner: timeout(*this.timeout, this.tls.accept(conn)),
183                            peer_addr: Some(addr),
184                        });
185                    }
186                    Poll::Ready(Err(e)) => {
187                        return Poll::Ready(Err(Error::ListenerError(e)));
188                    }
189                }
190            }
191
192            match this.waiting.poll_next_unpin(cx) {
193                Poll::Ready(Some(result)) => return Poll::Ready(result),
194                // If we don't have anything waiting yet,
195                // then we are still pending,
196                Poll::Ready(None) | Poll::Pending => {
197                    if empty_listener {
198                        return Poll::Pending;
199                    }
200                }
201            }
202        }
203    }
204
205    /// Accept the next connection
206    ///
207    /// This is similar to `self.next()`, but doesn't return an `Option` because
208    /// there isn't an end condition on accepting connections,
209    /// and has a more domain-appropriate name.
210    ///
211    /// The future returned is "cancellation safe".
212    pub fn accept(&mut self) -> impl Future<Output = <Self as Stream>::Item> + '_
213    where
214        Self: Unpin,
215    {
216        let mut pinned = Pin::new(self);
217        poll_fn(move |cx| pinned.as_mut().poll_accept(cx))
218    }
219
220    /// Replaces the Tls Acceptor configuration, which will be used for new connections.
221    ///
222    /// This can be used to change the certificate used at runtime.
223    pub fn replace_acceptor(&mut self, acceptor: T) {
224        self.tls = acceptor;
225    }
226
227    /// Replaces the Tls Acceptor configuration from a pinned reference to `Self`.
228    ///
229    /// This is useful if your listener is `!Unpin`.
230    ///
231    /// This can be used to change the certificate used at runtime.
232    pub fn replace_acceptor_pin(self: Pin<&mut Self>, acceptor: T) {
233        *self.project().tls = acceptor;
234    }
235
236    /// Convert into a Stream of connections.
237    ///
238    /// This drops the address of the connection, but provides a more convenient API
239    /// if the address isn't needed.
240    ///
241    /// The address will still be included in errors.
242    pub fn connections(self) -> impl Stream<Item = Result<T::Stream, TlsListenerError<A, T>>> {
243        self.map_ok(|(conn, _addr)| conn)
244    }
245
246    /// Get a reference to the underlying connection listener
247    ///
248    /// Can be useful to get metadata about the listener, such as the
249    /// local address.
250    pub fn listener(&self) -> &A {
251        &self.listener
252    }
253
254    /// Get the local address of the underlying listener
255    pub fn local_addr(&self) -> Result<A::Address, A::Error>
256    where
257        A: AsyncListener,
258    {
259        self.listener.local_addr()
260    }
261}
262
263impl<A, T> Stream for TlsListener<A, T>
264where
265    A: AsyncAccept,
266    T: AsyncTls<A::Connection>,
267{
268    type Item = Result<(T::Stream, A::Address), TlsListenerError<A, T>>;
269
270    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
271        self.poll_accept(cx).map(Some)
272    }
273}
274
#[cfg(feature = "rustls-core")]
#[cfg_attr(docsrs, doc(cfg(feature = "rustls-core")))]
impl<C> AsyncTls<C> for tokio_rustls::TlsAcceptor
where
    C: AsyncRead + AsyncWrite + Unpin,
{
    type Stream = tokio_rustls::server::TlsStream<C>;
    type Error = std::io::Error;
    type AcceptFuture = tokio_rustls::Accept<C>;

    fn accept(&self, conn: C) -> Self::AcceptFuture {
        // Spelled out with the full path to make it obvious this is the acceptor's own
        // inherent `accept`, not a recursive call into this trait method.
        tokio_rustls::TlsAcceptor::accept(self, conn)
    }
}
286
#[cfg(feature = "native-tls")]
#[cfg_attr(docsrs, doc(cfg(feature = "native-tls")))]
impl<C> AsyncTls<C> for tokio_native_tls::TlsAcceptor
where
    C: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
    type Stream = tokio_native_tls::TlsStream<C>;
    type Error = tokio_native_tls::native_tls::Error;
    type AcceptFuture = Pin<Box<dyn Future<Output = Result<Self::Stream, Self::Error>> + Send>>;

    /// The boxed future is `'static`, so it owns a clone of the acceptor instead of
    /// borrowing `self`.
    fn accept(&self, conn: C) -> Self::AcceptFuture {
        let acceptor = Clone::clone(self);
        let handshake = async move { tokio_native_tls::TlsAcceptor::accept(&acceptor, conn).await };
        Box::pin(handshake)
    }
}
302
#[cfg(feature = "openssl")]
#[cfg_attr(docsrs, doc(cfg(feature = "openssl")))]
impl<C> AsyncTls<C> for openssl_impl::ssl::SslContext
where
    C: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
    type Stream = tokio_openssl::SslStream<C>;
    type Error = openssl_impl::ssl::Error;
    type AcceptFuture = Pin<Box<dyn Future<Output = Result<Self::Stream, Self::Error>> + Send>>;

    /// Create a fresh `Ssl` session from this context, wrap `conn` in an `SslStream`, and
    /// return a future that drives the server-side handshake.
    ///
    /// Setup failures are not returned synchronously. They come out of the returned
    /// future, just like handshake failures.
    fn accept(&self, conn: C) -> Self::AcceptFuture {
        let prepared = openssl_impl::ssl::Ssl::new(self)
            .and_then(|ssl| tokio_openssl::SslStream::new(ssl, conn))
            .map_err(openssl_impl::ssl::Error::from);

        match prepared {
            Ok(mut stream) => Box::pin(async move {
                Pin::new(&mut stream).accept().await?;
                Ok(stream)
            }),
            Err(setup_err) => Box::pin(futures_util::future::err(setup_err)),
        }
    }
}
332
333impl<T> Builder<T> {
334    /// Set the size of batches of incoming connections to accept at once
335    ///
336    /// When polling for a new connection, the `TlsListener` will first check
337    /// for incomming connections on the listener that need to start a TLS handshake.
338    /// This specifies the maximum number of connections it will accept before seeing if any
339    /// TLS connections are ready.
340    ///
341    /// Having a limit for this ensures that ready TLS conections aren't starved if there are a
342    /// large number of incoming connections.
343    ///
344    /// Defaults to `DEFAULT_ACCEPT_BATCH_SIZE`.
345    pub fn accept_batch_size(&mut self, size: NonZeroUsize) -> &mut Self {
346        self.accept_batch_size = size;
347        self
348    }
349
350    /// Set the timeout for handshakes.
351    ///
352    /// If a timeout takes longer than `timeout`, then the handshake will be
353    /// aborted and the underlying connection will be dropped.
354    ///
355    /// The default is fairly conservative, to avoid dropping connections. It is
356    /// recommended that you adjust this to meet the specific needs of your use case
357    /// in production deployments.
358    ///
359    /// Defaults to `DEFAULT_HANDSHAKE_TIMEOUT`.
360    pub fn handshake_timeout(&mut self, timeout: Duration) -> &mut Self {
361        self.handshake_timeout = timeout;
362        self
363    }
364
365    /// Create a `TlsListener` from the builder
366    ///
367    /// Actually build the `TlsListener`. The `listener` argument should be
368    /// an implementation of the `AsyncAccept` trait that accepts new connections
369    /// that the `TlsListener` will  encrypt using TLS.
370    pub fn listen<A: AsyncAccept>(&self, listener: A) -> TlsListener<A, T>
371    where
372        T: AsyncTls<A::Connection>,
373    {
374        TlsListener {
375            listener,
376            tls: self.tls.clone(),
377            waiting: FuturesUnordered::new(),
378            accept_batch_size: self.accept_batch_size,
379            timeout: self.handshake_timeout,
380        }
381    }
382}
383
384impl<LE: std::error::Error, TE: std::error::Error, A> Error<LE, TE, A> {
385    /// Get the peer address from the connection that caused the error, if applicable.
386    ///
387    /// This will only return Some for errors that occur after an initial connection
388    /// is established, such as TlsAcceptError and HandshakeTimeout. And only if
389    /// the [`AsyncAccept`] implementation implements [`peer_addr`](AsyncAccept::peer_addr)
390    pub fn peer_addr(&self) -> Option<&A> {
391        match self {
392            Error::TlsAcceptError { peer_addr, .. } | Self::HandshakeTimeout { peer_addr, .. } => {
393                Some(peer_addr)
394            }
395            _ => None,
396        }
397    }
398}
399
400/// Create a new Builder for a TlsListener
401///
402/// `server_config` will be used to configure the TLS sessions.
403pub fn builder<T>(tls: T) -> Builder<T> {
404    Builder {
405        tls,
406        accept_batch_size: DEFAULT_ACCEPT_BATCH_SIZE,
407        handshake_timeout: DEFAULT_HANDSHAKE_TIMEOUT,
408    }
409}
410
411pin_project! {
412    struct Waiting<A, T>
413    where
414        A: AsyncAccept,
415        T: AsyncTls<A::Connection>
416    {
417        #[pin]
418        inner: Timeout<T::AcceptFuture>,
419        peer_addr: Option<A::Address>,
420    }
421}
422
423impl<A, T> Future for Waiting<A, T>
424where
425    A: AsyncAccept,
426    T: AsyncTls<A::Connection>,
427{
428    type Output = Result<(T::Stream, A::Address), TlsListenerError<A, T>>;
429
430    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
431        let mut this = self.project();
432        let res = ready!(this.inner.as_mut().poll(cx));
433        let addr = this
434            .peer_addr
435            .take()
436            .expect("this future has already been polled to completion");
437        match res {
438            // We succesfully got a connection
439            Ok(Ok(conn)) => Poll::Ready(Ok((conn, addr))),
440            // The handshake failed
441            Ok(Err(e)) => Poll::Ready(Err(Error::TlsAcceptError {
442                error: e,
443                peer_addr: addr,
444            })),
445            // The handshake timed out
446            Err(_) => Poll::Ready(Err(Error::HandshakeTimeout { peer_addr: addr })),
447        }
448    }
449}