tokio_rustls_acme/
state.rs

1use std::convert::Infallible;
2use std::fmt::Debug;
3use std::future::Future;
4use std::pin::Pin;
5use std::sync::Arc;
6use std::task::{Context, Poll};
7use std::time::Duration;
8
9use chrono::{DateTime, TimeZone, Utc};
10use futures::future::try_join_all;
11use futures::{ready, FutureExt, Stream};
12use rcgen::{CertificateParams, DistinguishedName, Error as RcgenError, PKCS_ECDSA_P256_SHA256};
13use rustls::crypto::ring::sign::any_ecdsa_type;
14use rustls::pki_types::{CertificateDer as RustlsCertificate, PrivateKeyDer, PrivatePkcs8KeyDer};
15use rustls::sign::CertifiedKey;
16use thiserror::Error;
17use tokio::io::{AsyncRead, AsyncWrite};
18use tokio::time::Sleep;
19use x509_parser::parse_x509_certificate;
20
21use crate::acceptor::AcmeAcceptor;
22use crate::acme::{
23    Account, AcmeError, Auth, AuthStatus, Directory, Identifier, Order, OrderStatus,
24};
25use crate::{AcmeConfig, Incoming, ResolvesServerCertAcme};
26
27type Timer = std::pin::Pin<Box<Sleep>>;
28type BoxFuture<T> = Pin<Box<dyn Future<Output = T> + Send>>;
29
30pub fn after(d: std::time::Duration) -> Timer {
31    Box::pin(tokio::time::sleep(d))
32}
33
34#[allow(clippy::type_complexity)]
35pub struct AcmeState<EC: Debug = Infallible, EA: Debug = EC> {
36    config: Arc<AcmeConfig<EC, EA>>,
37    resolver: Arc<ResolvesServerCertAcme>,
38    account_key: Option<Vec<u8>>,
39
40    early_action: Option<BoxFuture<Event<EC, EA>>>,
41    load_cert: Option<BoxFuture<Result<Option<Vec<u8>>, EC>>>,
42    load_account: Option<BoxFuture<Result<Option<Vec<u8>>, EA>>>,
43    order: Option<BoxFuture<Result<Vec<u8>, OrderError>>>,
44    backoff_cnt: usize,
45    wait: Option<Timer>,
46}
47
48pub type Event<EC, EA> = Result<EventOk, EventError<EC, EA>>;
49
/// Successful steps reported by the [`AcmeState`] stream.
#[derive(Debug)]
pub enum EventOk {
    /// A certificate loaded from the cache was installed in the resolver.
    DeployedCachedCert,
    /// A freshly ordered certificate was installed in the resolver.
    DeployedNewCert,
    /// A freshly ordered certificate was written to the cache.
    CertCacheStore,
    /// A newly generated account key was written to the cache.
    AccountCacheStore,
}
57
58#[derive(Error, Debug)]
59pub enum EventError<EC: Debug, EA: Debug> {
60    #[error("cert cache load: {0}")]
61    CertCacheLoad(EC),
62    #[error("account cache load: {0}")]
63    AccountCacheLoad(EA),
64    #[error("cert cache store: {0}")]
65    CertCacheStore(EC),
66    #[error("account cache store: {0}")]
67    AccountCacheStore(EA),
68    #[error("cached cert parse: {0}")]
69    CachedCertParse(CertParseError),
70    #[error("order: {0}")]
71    Order(OrderError),
72    #[error("new cert parse: {0}")]
73    NewCertParse(CertParseError),
74}
75
76#[derive(Error, Debug)]
77pub enum OrderError {
78    #[error("acme error: {0}")]
79    Acme(#[from] AcmeError),
80    #[error("certificate generation error: {0}")]
81    Rcgen(#[from] RcgenError),
82    #[error("bad order object: {0:?}")]
83    BadOrder(Order),
84    #[error("bad auth object: {0:?}")]
85    BadAuth(Auth),
86    #[error("authorization for {0} failed too many times")]
87    TooManyAttemptsAuth(String),
88    #[error("order status stayed on processing too long")]
89    ProcessingTimeout(Order),
90}
91
92#[derive(Error, Debug)]
93pub enum CertParseError {
94    #[error("X509 parsing error: {0}")]
95    X509(#[from] x509_parser::nom::Err<x509_parser::error::X509Error>),
96    #[error("expected 2 or more pem, got: {0}")]
97    Pem(#[from] pem::PemError),
98    #[error("expected 2 or more pem, got: {0}")]
99    TooFewPem(usize),
100    #[error("unsupported private key type")]
101    InvalidPrivateKey,
102}
103
104impl<EC: 'static + Debug, EA: 'static + Debug> AcmeState<EC, EA> {
105    pub fn incoming<
106        TCP: AsyncRead + AsyncWrite + Unpin,
107        ETCP,
108        ITCP: Stream<Item = Result<TCP, ETCP>> + Unpin,
109    >(
110        self,
111        tcp_incoming: ITCP,
112        alpn_protocols: Vec<Vec<u8>>,
113    ) -> Incoming<TCP, ETCP, ITCP, EC, EA> {
114        let acceptor = self.acceptor();
115        Incoming::new(tcp_incoming, self, acceptor, alpn_protocols)
116    }
117    pub fn acceptor(&self) -> AcmeAcceptor {
118        AcmeAcceptor::new(self.resolver())
119    }
120
121    #[cfg(feature = "axum")]
122    pub fn axum_acceptor(
123        &self,
124        rustls_config: Arc<rustls::ServerConfig>,
125    ) -> crate::axum::AxumAcceptor {
126        crate::axum::AxumAcceptor::new(self.acceptor(), rustls_config)
127    }
128    pub fn resolver(&self) -> Arc<ResolvesServerCertAcme> {
129        self.resolver.clone()
130    }
131    pub fn new(config: AcmeConfig<EC, EA>) -> Self {
132        let config = Arc::new(config);
133        Self {
134            config: config.clone(),
135            resolver: ResolvesServerCertAcme::new(),
136            account_key: None,
137            early_action: None,
138            load_cert: Some(Box::pin({
139                let config = config.clone();
140                async move {
141                    config
142                        .cache
143                        .load_cert(&config.domains, &config.directory_url)
144                        .await
145                }
146            })),
147            load_account: Some(Box::pin({
148                let config = config;
149                async move {
150                    config
151                        .cache
152                        .load_account(&config.contact, &config.directory_url)
153                        .await
154                }
155            })),
156            order: None,
157            backoff_cnt: 0,
158            wait: None,
159        }
160    }
161    fn parse_cert(pem: &[u8]) -> Result<(CertifiedKey, [DateTime<Utc>; 2]), CertParseError> {
162        let mut pems = pem::parse_many(pem)?;
163        if pems.len() < 2 {
164            return Err(CertParseError::TooFewPem(pems.len()));
165        }
166        let pk_bytes = pems.remove(0).into_contents();
167        let pk_der: PrivatePkcs8KeyDer = pk_bytes.into();
168        let pk: PrivateKeyDer = pk_der.into();
169        let pk = match any_ecdsa_type(&pk) {
170            Ok(pk) => pk,
171            Err(_) => return Err(CertParseError::InvalidPrivateKey),
172        };
173        let cert_chain: Vec<RustlsCertificate> =
174            pems.into_iter().map(|p| p.into_contents().into()).collect();
175        let validity = match parse_x509_certificate(cert_chain[0].as_ref()) {
176            Ok((_, cert)) => {
177                let validity = cert.validity();
178                [validity.not_before, validity.not_after]
179                    .map(|t| Utc.timestamp_opt(t.timestamp(), 0).earliest().unwrap())
180            }
181            Err(err) => return Err(CertParseError::X509(err)),
182        };
183        let cert = CertifiedKey::new(cert_chain, pk);
184        Ok((cert, validity))
185    }
186
187    #[allow(clippy::result_large_err)]
188    fn process_cert(&mut self, pem: Vec<u8>, cached: bool) -> Event<EC, EA> {
189        let (cert, validity) = match (Self::parse_cert(&pem), cached) {
190            (Ok(r), _) => r,
191            (Err(err), cached) => {
192                return match cached {
193                    true => Err(EventError::CachedCertParse(err)),
194                    false => Err(EventError::NewCertParse(err)),
195                }
196            }
197        };
198        self.resolver.set_cert(Arc::new(cert));
199        let wait_duration = (validity[1] - (validity[1] - validity[0]) / 3 - Utc::now())
200            .max(chrono::Duration::zero())
201            .to_std()
202            .unwrap_or_default();
203        self.wait = Some(after(wait_duration));
204        if cached {
205            return Ok(EventOk::DeployedCachedCert);
206        }
207        let config = self.config.clone();
208        self.early_action = Some(Box::pin(async move {
209            match config
210                .cache
211                .store_cert(&config.domains, &config.directory_url, &pem)
212                .await
213            {
214                Ok(()) => Ok(EventOk::CertCacheStore),
215                Err(err) => Err(EventError::CertCacheStore(err)),
216            }
217        }));
218        Event::Ok(EventOk::DeployedNewCert)
219    }
220    async fn order(
221        config: Arc<AcmeConfig<EC, EA>>,
222        resolver: Arc<ResolvesServerCertAcme>,
223        key_pair: Vec<u8>,
224    ) -> Result<Vec<u8>, OrderError> {
225        let directory = Directory::discover(&config.client_config, &config.directory_url).await?;
226        let account = Account::create_with_keypair(
227            &config.client_config,
228            directory,
229            &config.contact,
230            &key_pair,
231        )
232        .await?;
233
234        let mut params = CertificateParams::new(config.domains.clone())?;
235        params.distinguished_name = DistinguishedName::new();
236        let key_pair = rcgen::KeyPair::generate_for(&PKCS_ECDSA_P256_SHA256)?;
237
238        let (order_url, mut order) = account
239            .new_order(&config.client_config, config.domains.clone())
240            .await?;
241        loop {
242            match order.status {
243                OrderStatus::Pending => {
244                    let auth_futures = order
245                        .authorizations
246                        .iter()
247                        .map(|url| Self::authorize(&config, &resolver, &account, url));
248                    try_join_all(auth_futures).await?;
249                    log::info!("completed all authorizations");
250                    order = account.order(&config.client_config, &order_url).await?;
251                }
252                OrderStatus::Processing => {
253                    for i in 0u64..10 {
254                        log::info!("order processing");
255                        after(Duration::from_secs(1u64 << i)).await;
256                        order = account.order(&config.client_config, &order_url).await?;
257                        if order.status != OrderStatus::Processing {
258                            break;
259                        }
260                    }
261                    if order.status == OrderStatus::Processing {
262                        return Err(OrderError::ProcessingTimeout(order));
263                    }
264                }
265                OrderStatus::Ready => {
266                    log::info!("sending csr");
267                    let csr = params.serialize_request(&key_pair)?;
268                    order = account
269                        .finalize(&config.client_config, order.finalize, csr.der().to_vec())
270                        .await?
271                }
272                OrderStatus::Valid { certificate } => {
273                    log::info!("download certificate");
274                    let pem = [
275                        &key_pair.serialize_pem(),
276                        "\n",
277                        &account
278                            .certificate(&config.client_config, certificate)
279                            .await?,
280                    ]
281                    .concat();
282                    return Ok(pem.into_bytes());
283                }
284                OrderStatus::Invalid => return Err(OrderError::BadOrder(order)),
285            }
286        }
287    }
288    async fn authorize(
289        config: &AcmeConfig<EC, EA>,
290        resolver: &ResolvesServerCertAcme,
291        account: &Account,
292        url: &String,
293    ) -> Result<(), OrderError> {
294        let auth = account.auth(&config.client_config, url).await?;
295        let (domain, challenge_url) = match auth.status {
296            AuthStatus::Pending => {
297                let Identifier::Dns(domain) = auth.identifier;
298                log::info!("trigger challenge for {}", &domain);
299                let (challenge, auth_key) =
300                    account.tls_alpn_01(&auth.challenges, domain.clone())?;
301                resolver.set_auth_key(domain.clone(), Arc::new(auth_key));
302                account
303                    .challenge(&config.client_config, &challenge.url)
304                    .await?;
305                (domain, challenge.url.clone())
306            }
307            AuthStatus::Valid => return Ok(()),
308            _ => return Err(OrderError::BadAuth(auth)),
309        };
310        for i in 0u64..5 {
311            after(Duration::from_secs(1u64 << i)).await;
312            let auth = account.auth(&config.client_config, url).await?;
313            match auth.status {
314                AuthStatus::Pending => {
315                    log::info!("authorization for {} still pending", &domain);
316                    account
317                        .challenge(&config.client_config, &challenge_url)
318                        .await?
319                }
320                AuthStatus::Valid => return Ok(()),
321                _ => return Err(OrderError::BadAuth(auth)),
322            }
323        }
324        Err(OrderError::TooManyAttemptsAuth(domain))
325    }
326    fn poll_next_infinite(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Event<EC, EA>> {
327        loop {
328            // queued early action
329            if let Some(early_action) = &mut self.early_action {
330                let result = ready!(early_action.poll_unpin(cx));
331                self.early_action.take();
332                return Poll::Ready(result);
333            }
334
335            // sleep
336            if let Some(timer) = &mut self.wait {
337                ready!(timer.poll_unpin(cx));
338                self.wait.take();
339            }
340
341            // load from cert cache
342            if let Some(load_cert) = &mut self.load_cert {
343                let result = ready!(load_cert.poll_unpin(cx));
344                self.load_cert.take();
345                match result {
346                    Ok(Some(pem)) => {
347                        return Poll::Ready(Self::process_cert(self.get_mut(), pem, true));
348                    }
349                    Ok(None) => {}
350                    Err(err) => return Poll::Ready(Err(EventError::CertCacheLoad(err))),
351                }
352            }
353
354            // load from account cache
355            if let Some(load_account) = &mut self.load_account {
356                let result = ready!(load_account.poll_unpin(cx));
357                self.load_account.take();
358                match result {
359                    Ok(Some(key_pair)) => self.account_key = Some(key_pair),
360                    Ok(None) => {}
361                    Err(err) => return Poll::Ready(Err(EventError::AccountCacheLoad(err))),
362                }
363            }
364
365            // execute order
366            if let Some(order) = &mut self.order {
367                let result = ready!(order.poll_unpin(cx));
368                self.order.take();
369                match result {
370                    Ok(pem) => {
371                        self.backoff_cnt = 0;
372                        return Poll::Ready(Self::process_cert(self.get_mut(), pem, false));
373                    }
374                    Err(err) => {
375                        // TODO: replace key on some errors or high backoff_cnt?
376                        self.wait = Some(after(Duration::from_secs(1 << self.backoff_cnt)));
377                        self.backoff_cnt = (self.backoff_cnt + 1).min(16);
378                        return Poll::Ready(Err(EventError::Order(err)));
379                    }
380                }
381            }
382
383            // schedule order
384            let account_key = match &self.account_key {
385                None => {
386                    let account_key = Account::generate_key_pair();
387                    self.account_key = Some(account_key.clone());
388                    let config = self.config.clone();
389                    let account_key_clone = account_key.clone();
390                    self.early_action = Some(Box::pin(async move {
391                        match config
392                            .cache
393                            .store_account(
394                                &config.contact,
395                                &config.directory_url,
396                                &account_key_clone,
397                            )
398                            .await
399                        {
400                            Ok(()) => Ok(EventOk::AccountCacheStore),
401                            Err(err) => Err(EventError::AccountCacheStore(err)),
402                        }
403                    }));
404                    account_key
405                }
406                Some(account_key) => account_key.clone(),
407            };
408            let config = self.config.clone();
409            let resolver = self.resolver.clone();
410            self.order = Some(Box::pin({
411                Self::order(config.clone(), resolver.clone(), account_key)
412            }));
413        }
414    }
415}
416
417impl<EC: 'static + Debug, EA: 'static + Debug> Stream for AcmeState<EC, EA> {
418    type Item = Event<EC, EA>;
419
420    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
421        Poll::Ready(Some(ready!(self.poll_next_infinite(cx))))
422    }
423}