1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
use crate::acme::ACME_TLS_ALPN_NAME;
use crate::ResolvesServerCertUsingAcme;
use async_rustls::rustls::{ServerConfig, Session};
use async_rustls::server::TlsStream;
use futures::{AsyncRead, AsyncWrite};
use std::future::Future;
use std::mem::replace;
use std::ops::DerefMut;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};

/// TLS acceptor that wraps a shared rustls [`ServerConfig`].
///
/// The configuration is stored behind an [`Arc`], so cloning the acceptor
/// only clones the pointer and every clone uses the same configuration.
/// [`TlsAcceptor::new`] prepares that configuration (ALPN entry and
/// certificate resolver) before it is shared.
#[derive(Clone)]
pub struct TlsAcceptor {
    /// Server configuration handed to `async_rustls` for every accepted stream.
    config: Arc<ServerConfig>,
}

impl TlsAcceptor {
    /// Builds an acceptor from `config`, adjusting it in two ways first:
    ///
    /// * `resolver` replaces whatever certificate resolver `config` carried;
    /// * [`ACME_TLS_ALPN_NAME`] is appended to the ALPN protocol list, after
    ///   any protocols the caller already configured.
    ///
    /// The adjusted configuration is then frozen inside an [`Arc`].
    pub fn new(mut config: ServerConfig, resolver: Arc<ResolvesServerCertUsingAcme>) -> Self {
        config.cert_resolver = resolver;
        config.alpn_protocols.push(ACME_TLS_ALPN_NAME.to_vec());
        Self {
            config: Arc::new(config),
        }
    }

    /// Begins accepting a TLS connection on `stream`.
    ///
    /// Returns an [`Accept`] future in its `Accepting` state; nothing is
    /// driven until that future is polled.
    pub fn accept<IO>(&self, stream: IO) -> Accept<IO>
    where
        IO: AsyncRead + AsyncWrite + Unpin,
    {
        // A fresh `async_rustls` acceptor is created per call, sharing the same config.
        let inner = async_rustls::TlsAcceptor::from(Arc::clone(&self.config));
        Accept::Accepting(inner.accept(stream))
    }
}

/// Future returned by [`TlsAcceptor::accept`], modelled as an explicit state machine.
///
/// It resolves to the accepted [`TlsStream`]. When the negotiated ALPN
/// protocol is [`ACME_TLS_ALPN_NAME`], the stream is closed before it is
/// returned.
pub enum Accept<IO> {
    /// The underlying `async_rustls` accept future is still being polled.
    Accepting(async_rustls::Accept<IO>),
    /// The accepted stream negotiated [`ACME_TLS_ALPN_NAME`] and is being closed.
    Closing(TlsStream<IO>),
    /// The closed stream has already been returned; polling in this state panics.
    Closed,
}

impl<IO: AsyncRead + AsyncWrite + Unpin> Future for Accept<IO> {
    type Output = std::io::Result<async_rustls::server::TlsStream<IO>>;

    #[inline]
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        loop {
            match self.deref_mut() {
                Accept::Accepting(accept) => match Pin::new(accept).poll(cx) {
                    Poll::Ready(Ok(tls)) => match tls.get_ref().1.get_alpn_protocol() {
                        Some(ACME_TLS_ALPN_NAME) => self.set(Accept::Closing(tls)),
                        _ => return Poll::Ready(Ok(tls)),
                    },
                    p => return p,
                },
                Accept::Closing(tls) => match Pin::new(tls).poll_close(cx) {
                    Poll::Ready(Ok(())) => match replace(self.get_mut(), Accept::Closed) {
                        Accept::Closing(tls) => return Poll::Ready(Ok(tls)),
                        _ => unreachable!(),
                    },
                    Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
                    Poll::Pending => return Poll::Pending,
                },
                Accept::Closed => panic!("polled after returning closed tls connection"),
            }
        }
    }
}