// tokio_rustls_acme/axum.rs

1use crate::{AcmeAccept, AcmeAcceptor};
2use rustls::ServerConfig;
3use std::future::Future;
4use std::io;
5use std::io::ErrorKind;
6use std::pin::Pin;
7use std::sync::Arc;
8use std::task::{Context, Poll};
9use tokio::io::{AsyncRead, AsyncWrite};
10use tokio_rustls::Accept;
11
12#[derive(Clone)]
13pub struct AxumAcceptor {
14    acme_acceptor: AcmeAcceptor,
15    config: Arc<ServerConfig>,
16}
17
18impl AxumAcceptor {
19    pub fn new(acme_acceptor: AcmeAcceptor, config: Arc<ServerConfig>) -> Self {
20        Self {
21            acme_acceptor,
22            config,
23        }
24    }
25}
26
27impl<I: AsyncRead + AsyncWrite + Unpin + Send + 'static, S: Send + 'static>
28    axum_server::accept::Accept<I, S> for AxumAcceptor
29{
30    type Stream = tokio_rustls::server::TlsStream<I>;
31    type Service = S;
32    type Future = AxumAccept<I, S>;
33
34    fn accept(&self, stream: I, service: S) -> Self::Future {
35        let acme_accept = self.acme_acceptor.accept(stream);
36        Self::Future {
37            config: self.config.clone(),
38            acme_accept,
39            tls_accept: None,
40            service: Some(service),
41        }
42    }
43}
44
45pub struct AxumAccept<I: AsyncRead + AsyncWrite + Unpin + Send + 'static, S: Send + 'static> {
46    config: Arc<ServerConfig>,
47    acme_accept: AcmeAccept<I>,
48    tls_accept: Option<Accept<I>>,
49    service: Option<S>,
50}
51
52impl<I: AsyncRead + AsyncWrite + Unpin + Send + 'static, S: Send + 'static> Unpin
53    for AxumAccept<I, S>
54{
55}
56
57impl<I: AsyncRead + AsyncWrite + Unpin + Send + 'static, S: Send + 'static> Future
58    for AxumAccept<I, S>
59{
60    type Output = io::Result<(tokio_rustls::server::TlsStream<I>, S)>;
61
62    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
63        loop {
64            if let Some(tls_accept) = &mut self.tls_accept {
65                return match Pin::new(&mut *tls_accept).poll(cx) {
66                    Poll::Ready(Ok(tls)) => Poll::Ready(Ok((tls, self.service.take().unwrap()))),
67                    Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
68                    Poll::Pending => Poll::Pending,
69                };
70            }
71            return match Pin::new(&mut self.acme_accept).poll(cx) {
72                Poll::Ready(Ok(Some(start_handshake))) => {
73                    let config = self.config.clone();
74                    self.tls_accept = Some(start_handshake.into_stream(config));
75                    continue;
76                }
77                Poll::Ready(Ok(None)) => Poll::Ready(Err(io::Error::new(
78                    ErrorKind::Other,
79                    "TLS-ALPN-01 validation request",
80                ))),
81                Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
82                Poll::Pending => Poll::Pending,
83            };
84        }
85    }
86}