1#![deny(missing_docs)]
2#![cfg_attr(docsrs, feature(doc_cfg))]
3
4use futures_util::stream::{FuturesUnordered, Stream, StreamExt, TryStreamExt};
20use pin_project_lite::pin_project;
21#[cfg(feature = "rt")]
22pub use spawning_handshake::SpawningHandshakes;
23use std::fmt::Debug;
24use std::future::{poll_fn, Future};
25use std::num::NonZeroUsize;
26use std::pin::Pin;
27use std::task::{ready, Context, Poll};
28use std::time::Duration;
29use thiserror::Error;
30use tokio::io::{AsyncRead, AsyncWrite};
31use tokio::time::{timeout, Timeout};
32#[cfg(feature = "native-tls")]
33pub use tokio_native_tls as native_tls;
34#[cfg(feature = "openssl")]
35pub use tokio_openssl as openssl;
36#[cfg(feature = "rustls-core")]
37pub use tokio_rustls as rustls;
38
39mod accept;
40#[cfg(feature = "tokio-net")]
41mod net;
42#[cfg(feature = "rt")]
43mod spawning_handshake;
44
45pub use accept::*;
46
47#[cfg(feature = "axum")]
48mod axum;
49
/// Default number of connections accepted from the underlying listener in one
/// batch before the in-flight TLS handshakes are polled.
pub const DEFAULT_ACCEPT_BATCH_SIZE: NonZeroUsize = match NonZeroUsize::new(64) {
    Some(size) => size,
    // Evaluated at compile time: a zero here would be a build error, not UB,
    // which is why this replaces `unsafe { NonZeroUsize::new_unchecked(..) }`.
    None => panic!("DEFAULT_ACCEPT_BATCH_SIZE must be non-zero"),
};
/// Default maximum duration of a single TLS handshake (10 seconds).
pub const DEFAULT_HANDSHAKE_TIMEOUT: Duration = Duration::from_secs(10);
54
/// A TLS acceptor that can asynchronously perform a server-side TLS handshake
/// on a connection of type `C`.
///
/// Implementations must be `Clone` so that a listener (or builder) can hold its
/// own copy of the acceptor.
pub trait AsyncTls<C: AsyncRead + AsyncWrite>: Clone {
    /// The TLS stream produced once the handshake on a `C` succeeds.
    type Stream;
    /// Error returned when the TLS handshake fails.
    type Error: std::error::Error;
    /// Future that performs the handshake and resolves to the TLS stream or an error.
    type AcceptFuture: Future<Output = Result<Self::Stream, Self::Error>>;

    /// Starts a TLS handshake on the underlying `stream`.
    ///
    /// The returned future must be polled to drive the handshake to completion.
    fn accept(&self, stream: C) -> Self::AcceptFuture;
}
69
pin_project! {
    /// Wraps a listener (anything implementing [`AsyncAccept`]) so that every
    /// accepted connection goes through a TLS handshake before it is yielded.
    ///
    /// Handshakes for several connections are driven concurrently, and each one
    /// is bounded by a timeout. Create one with [`TlsListener::new`] for the
    /// default settings, or with [`builder`] to configure the accept batch size
    /// and the handshake timeout.
    ///
    /// As a [`Stream`], it yields the TLS stream together with the address the
    /// listener reported for that connection.
    pub struct TlsListener<A: AsyncAccept, T: AsyncTls<A::Connection>> {
        // The underlying listener; pinned because its `poll_accept` is called
        // through `Pin<&mut A>`.
        #[pin]
        listener: A,
        // TLS acceptor used to start a handshake on each new connection.
        tls: T,
        // In-flight handshakes, each wrapped in the handshake timeout.
        waiting: FuturesUnordered<Waiting<A, T>>,
        // Maximum number of connections accepted per batch in `poll_accept`.
        accept_batch_size: NonZeroUsize,
        // Timeout applied to each TLS handshake.
        timeout: Duration,
    }
}
106
#[derive(Clone)]
/// Configures and creates [`TlsListener`]s; obtain one with [`builder`].
pub struct Builder<T> {
    // TLS acceptor cloned into every listener created by `listen`.
    tls: T,
    // Connections accepted per batch before in-flight handshakes are polled.
    accept_batch_size: NonZeroUsize,
    // Maximum duration of each TLS handshake.
    handshake_timeout: Duration,
}
114
#[derive(Debug, Error)]
/// Errors produced by a [`TlsListener`].
///
/// `LE` is the listener's error type, `TE` the TLS acceptor's error type and
/// `Addr` the address type reported by the listener.
#[non_exhaustive]
pub enum Error<LE: std::error::Error, TE: std::error::Error, Addr> {
    /// The underlying listener failed to accept a connection.
    #[error("{0}")]
    ListenerError(#[source] LE),
    /// The TLS handshake for an accepted connection failed.
    #[error("{error}")]
    #[non_exhaustive]
    TlsAcceptError {
        /// The error returned by the TLS acceptor's handshake future.
        #[source]
        error: TE,

        /// The address the listener reported for the failed connection.
        peer_addr: Addr,
    },
    /// The TLS handshake did not finish within the handshake timeout.
    #[error("Timeout during TLS handshake")]
    #[non_exhaustive]
    HandshakeTimeout {
        /// The address the listener reported for the timed-out connection.
        peer_addr: Addr,
    },
}
141
142impl<A: AsyncAccept, T> TlsListener<A, T>
143where
144 T: AsyncTls<A::Connection>,
145{
146 pub fn new(tls: T, listener: A) -> Self {
148 builder(tls).listen(listener)
149 }
150}
151
/// The [`Error`] type produced by a `TlsListener<A, T>`: the listener error and
/// address types come from `A`, the handshake error type from the `AsyncTls`
/// implementation `T`.
type TlsListenerError<A, T> = Error<
    <A as AsyncAccept>::Error,
    <T as AsyncTls<<A as AsyncAccept>::Connection>>::Error,
    <A as AsyncAccept>::Address,
>;
159
160impl<A, T> TlsListener<A, T>
161where
162 A: AsyncAccept,
163 T: AsyncTls<A::Connection>,
164{
165 pub fn poll_accept(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<<Self as Stream>::Item> {
170 let mut this = self.project();
171
172 loop {
173 let mut empty_listener = false;
174 for _ in 0..this.accept_batch_size.get() {
175 match this.listener.as_mut().poll_accept(cx) {
176 Poll::Pending => {
177 empty_listener = true;
178 break;
179 }
180 Poll::Ready(Ok((conn, addr))) => {
181 this.waiting.push(Waiting {
182 inner: timeout(*this.timeout, this.tls.accept(conn)),
183 peer_addr: Some(addr),
184 });
185 }
186 Poll::Ready(Err(e)) => {
187 return Poll::Ready(Err(Error::ListenerError(e)));
188 }
189 }
190 }
191
192 match this.waiting.poll_next_unpin(cx) {
193 Poll::Ready(Some(result)) => return Poll::Ready(result),
194 Poll::Ready(None) | Poll::Pending => {
197 if empty_listener {
198 return Poll::Pending;
199 }
200 }
201 }
202 }
203 }
204
205 pub fn accept(&mut self) -> impl Future<Output = <Self as Stream>::Item> + '_
213 where
214 Self: Unpin,
215 {
216 let mut pinned = Pin::new(self);
217 poll_fn(move |cx| pinned.as_mut().poll_accept(cx))
218 }
219
220 pub fn replace_acceptor(&mut self, acceptor: T) {
224 self.tls = acceptor;
225 }
226
227 pub fn replace_acceptor_pin(self: Pin<&mut Self>, acceptor: T) {
233 *self.project().tls = acceptor;
234 }
235
236 pub fn connections(self) -> impl Stream<Item = Result<T::Stream, TlsListenerError<A, T>>> {
243 self.map_ok(|(conn, _addr)| conn)
244 }
245
246 pub fn listener(&self) -> &A {
251 &self.listener
252 }
253
254 pub fn local_addr(&self) -> Result<A::Address, A::Error>
256 where
257 A: AsyncListener,
258 {
259 self.listener.local_addr()
260 }
261}
262
263impl<A, T> Stream for TlsListener<A, T>
264where
265 A: AsyncAccept,
266 T: AsyncTls<A::Connection>,
267{
268 type Item = Result<(T::Stream, A::Address), TlsListenerError<A, T>>;
269
270 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
271 self.poll_accept(cx).map(Some)
272 }
273}
274
#[cfg(feature = "rustls-core")]
#[cfg_attr(docsrs, doc(cfg(feature = "rustls-core")))]
impl<C: AsyncRead + AsyncWrite + Unpin> AsyncTls<C> for tokio_rustls::TlsAcceptor {
    type AcceptFuture = tokio_rustls::Accept<C>;
    type Error = std::io::Error;
    type Stream = tokio_rustls::server::TlsStream<C>;

    /// Starts a rustls server handshake on `conn`.
    fn accept(&self, conn: C) -> Self::AcceptFuture {
        // Call the inherent `TlsAcceptor::accept` by its full path; a plain
        // method call would share its name with this trait method.
        let acceptor: &tokio_rustls::TlsAcceptor = self;
        tokio_rustls::TlsAcceptor::accept(acceptor, conn)
    }
}
286
#[cfg(feature = "native-tls")]
#[cfg_attr(docsrs, doc(cfg(feature = "native-tls")))]
impl<C> AsyncTls<C> for tokio_native_tls::TlsAcceptor
where
    C: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
    type Stream = tokio_native_tls::TlsStream<C>;
    type Error = tokio_native_tls::native_tls::Error;
    type AcceptFuture = Pin<Box<dyn Future<Output = Result<Self::Stream, Self::Error>> + Send>>;

    /// Starts a native-tls server handshake on `conn`.
    fn accept(&self, conn: C) -> Self::AcceptFuture {
        // `AcceptFuture` is a boxed trait object whose lifetime is not tied to
        // `self`, so the future owns a clone of the acceptor instead of
        // borrowing it.
        let acceptor = self.clone();
        let handshake = async move {
            tokio_native_tls::TlsAcceptor::accept(&acceptor, conn).await
        };
        Box::pin(handshake)
    }
}
302
#[cfg(feature = "openssl")]
#[cfg_attr(docsrs, doc(cfg(feature = "openssl")))]
impl<C> AsyncTls<C> for openssl_impl::ssl::SslContext
where
    C: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
    type Stream = tokio_openssl::SslStream<C>;
    type Error = openssl_impl::ssl::Error;
    type AcceptFuture = Pin<Box<dyn Future<Output = Result<Self::Stream, Self::Error>> + Send>>;

    /// Creates an `Ssl` session from this context, wraps `conn` in an
    /// `SslStream`, and returns a future that runs the server handshake.
    ///
    /// Setup failures are reported through an already-completed future.
    fn accept(&self, conn: C) -> Self::AcceptFuture {
        // Both setup steps fail with the same error type, so they can be chained.
        let setup = openssl_impl::ssl::Ssl::new(self)
            .and_then(|session| tokio_openssl::SslStream::new(session, conn));
        let mut tls_stream = match setup {
            Ok(stream) => stream,
            Err(setup_error) => return Box::pin(futures_util::future::err(setup_error.into())),
        };
        Box::pin(async move {
            Pin::new(&mut tls_stream).accept().await?;
            Ok(tls_stream)
        })
    }
}
332
333impl<T> Builder<T> {
334 pub fn accept_batch_size(&mut self, size: NonZeroUsize) -> &mut Self {
346 self.accept_batch_size = size;
347 self
348 }
349
350 pub fn handshake_timeout(&mut self, timeout: Duration) -> &mut Self {
361 self.handshake_timeout = timeout;
362 self
363 }
364
365 pub fn listen<A: AsyncAccept>(&self, listener: A) -> TlsListener<A, T>
371 where
372 T: AsyncTls<A::Connection>,
373 {
374 TlsListener {
375 listener,
376 tls: self.tls.clone(),
377 waiting: FuturesUnordered::new(),
378 accept_batch_size: self.accept_batch_size,
379 timeout: self.handshake_timeout,
380 }
381 }
382}
383
384impl<LE: std::error::Error, TE: std::error::Error, A> Error<LE, TE, A> {
385 pub fn peer_addr(&self) -> Option<&A> {
391 match self {
392 Error::TlsAcceptError { peer_addr, .. } | Self::HandshakeTimeout { peer_addr, .. } => {
393 Some(peer_addr)
394 }
395 _ => None,
396 }
397 }
398}
399
400pub fn builder<T>(tls: T) -> Builder<T> {
404 Builder {
405 tls,
406 accept_batch_size: DEFAULT_ACCEPT_BATCH_SIZE,
407 handshake_timeout: DEFAULT_HANDSHAKE_TIMEOUT,
408 }
409}
410
pin_project! {
    /// One in-flight TLS handshake: the acceptor's future wrapped in the
    /// handshake timeout, paired with the address the listener reported for
    /// the connection.
    struct Waiting<A, T>
    where
        A: AsyncAccept,
        T: AsyncTls<A::Connection>
    {
        // Handshake future bounded by the handshake timeout.
        #[pin]
        inner: Timeout<T::AcceptFuture>,
        // `Some` until the future completes; `take`n so the address can be
        // moved into the output.
        peer_addr: Option<A::Address>,
    }
}
422
423impl<A, T> Future for Waiting<A, T>
424where
425 A: AsyncAccept,
426 T: AsyncTls<A::Connection>,
427{
428 type Output = Result<(T::Stream, A::Address), TlsListenerError<A, T>>;
429
430 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
431 let mut this = self.project();
432 let res = ready!(this.inner.as_mut().poll(cx));
433 let addr = this
434 .peer_addr
435 .take()
436 .expect("this future has already been polled to completion");
437 match res {
438 Ok(Ok(conn)) => Poll::Ready(Ok((conn, addr))),
440 Ok(Err(e)) => Poll::Ready(Err(Error::TlsAcceptError {
442 error: e,
443 peer_addr: addr,
444 })),
445 Err(_) => Poll::Ready(Err(Error::HandshakeTimeout { peer_addr: addr })),
447 }
448 }
449}