rustls_acme/
axum.rs

1use crate::futures_rustls::rustls::ServerConfig;
2use crate::{AcmeAccept, AcmeAcceptor};
3use futures::prelude::*;
4use futures_rustls::Accept;
5use std::io;
6use std::io::ErrorKind;
7use std::pin::Pin;
8use std::sync::Arc;
9use std::task::{Context, Poll};
10use tokio::io::{AsyncRead, AsyncWrite};
11use tokio_util::compat::{Compat, FuturesAsyncReadCompatExt, TokioAsyncReadCompatExt};
12
13#[derive(Clone)]
14pub struct AxumAcceptor {
15    acme_acceptor: AcmeAcceptor,
16    config: Arc<ServerConfig>,
17}
18
19impl AxumAcceptor {
20    pub fn new(acme_acceptor: AcmeAcceptor, config: Arc<ServerConfig>) -> Self {
21        Self { acme_acceptor, config }
22    }
23}
24
25impl<I: AsyncRead + AsyncWrite + Unpin + Send + 'static, S: Send + 'static> axum_server::accept::Accept<I, S> for AxumAcceptor {
26    type Stream = Compat<futures_rustls::server::TlsStream<Compat<I>>>;
27    type Service = S;
28    type Future = AxumAccept<I, S>;
29
30    fn accept(&self, stream: I, service: S) -> Self::Future {
31        let acme_accept = self.acme_acceptor.accept(stream.compat());
32        Self::Future {
33            config: self.config.clone(),
34            acme_accept,
35            tls_accept: None,
36            service: Some(service),
37        }
38    }
39}
40
41pub struct AxumAccept<I: AsyncRead + AsyncWrite + Unpin + Send + 'static, S: Send + 'static> {
42    config: Arc<ServerConfig>,
43    acme_accept: AcmeAccept<Compat<I>>,
44    tls_accept: Option<Accept<Compat<I>>>,
45    service: Option<S>,
46}
47
48impl<I: AsyncRead + AsyncWrite + Unpin + Send + 'static, S: Send + 'static> Unpin for AxumAccept<I, S> {}
49
50impl<I: AsyncRead + AsyncWrite + Unpin + Send + 'static, S: Send + 'static> Future for AxumAccept<I, S> {
51    type Output = io::Result<(Compat<futures_rustls::server::TlsStream<Compat<I>>>, S)>;
52
53    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
54        loop {
55            if let Some(tls_accept) = &mut self.tls_accept {
56                return match Pin::new(&mut *tls_accept).poll(cx) {
57                    Poll::Ready(Ok(tls)) => Poll::Ready(Ok((tls.compat(), self.service.take().unwrap()))),
58                    Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
59                    Poll::Pending => Poll::Pending,
60                };
61            }
62            return match Pin::new(&mut self.acme_accept).poll(cx) {
63                Poll::Ready(Ok(Some(start_handshake))) => {
64                    let config = self.config.clone();
65                    self.tls_accept = Some(start_handshake.into_stream(config));
66                    continue;
67                }
68                Poll::Ready(Ok(None)) => Poll::Ready(Err(io::Error::new(ErrorKind::Other, "TLS-ALPN-01 validation request"))),
69                Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
70                Poll::Pending => Poll::Pending,
71            };
72        }
73    }
74}