1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
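//! The [`CertifiedKeyLoader`], which builds a [`rustls::sign::CertifiedKey`] from a private key
//! and a certificate chain obtained through pluggable readers.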

use futures_util::TryFutureExt as _;

/// Builds a [`rustls::sign::CertifiedKey`] from a private key and a certificate chain that are
/// obtained through the configured readers.
pub struct CertifiedKeyLoader<KeyProvider, KeyReader, CertsReader> {
    /// Converts the raw private key returned by `key_reader` into a key usable by rustls.
    pub key_provider: KeyProvider,
    /// Source of the private key (presumably file-backed — depends on the reader type).
    pub key_reader: KeyReader,
    /// Source of the certificate chain (presumably file-backed — depends on the reader type).
    pub certs_reader: CertsReader,
}

impl<KeyProvider, KeyReader, CertsReader> std::fmt::Debug
    for CertifiedKeyLoader<KeyProvider, KeyReader, CertsReader>
{
    /// Renders every field as an opaque placeholder, so the type parameters need no `Debug`
    /// bound and nothing inside the providers/readers is printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        const PLACEHOLDER: &str = "...";

        let mut builder = f.debug_struct("CertifiedKeyLoader");
        for field_name in ["key_provider", "key_reader", "certs_reader"] {
            builder.field(field_name, &PLACEHOLDER);
        }
        builder.finish()
    }
}

/// An error that can occur while loading the data.
#[derive(Debug, thiserror::Error)]
pub enum CertifiedKeyLoaderError<ReadKey, ReadCerts> {
    /// Reading the key failed.
    #[error("reading key: {0}")]
    ReadKey(ReadKey),
    /// Reading the certificate failed.
    #[error("reading certs: {0}")]
    ReadCerts(ReadCerts),
    /// Key processing failed.
    #[error("loading key: {0}")]
    LoadKey(rustls::Error),
}

/// Every call to `load` reads the certificate chain and the private key concurrently, converts
/// the key through the configured [`rustls::crypto::KeyProvider`], and returns a newly built
/// [`rustls::sign::CertifiedKey`].
#[async_trait::async_trait]
impl<KeyProvider, KeyReader, CertsReader> rustls_cert_reloadable::Loader
    for CertifiedKeyLoader<KeyProvider, KeyReader, CertsReader>
where
    KeyProvider: rustls::crypto::KeyProvider,
    KeyReader: rustls_cert_read::ReadKey + Send,
    CertsReader: rustls_cert_read::ReadCerts + Send,
    KeyReader::Error: std::error::Error + Send + 'static,
    CertsReader::Error: std::error::Error + Send + 'static,
{
    type Value = rustls::sign::CertifiedKey;
    type Error = CertifiedKeyLoaderError<KeyReader::Error, CertsReader::Error>;

    async fn load(&mut self) -> Result<Self::Value, Self::Error> {
        // Tag each reader's failure with the variant naming which input failed.
        let pending_key = self
            .key_reader
            .read_key()
            .map_err(CertifiedKeyLoaderError::ReadKey);
        let pending_certs = self
            .certs_reader
            .read_certs()
            .map_err(CertifiedKeyLoaderError::ReadCerts);

        // Drive both reads at once; `try_join` short-circuits on the first error it observes.
        let (cert_chain, raw_key) =
            futures_util::future::try_join(pending_certs, pending_key).await?;

        // The reader hands back raw key material; the provider turns it into the key
        // `CertifiedKey` expects.
        let signing_key = match self.key_provider.load_private_key(raw_key) {
            Ok(loaded) => loaded,
            Err(cause) => return Err(CertifiedKeyLoaderError::LoadKey(cause)),
        };

        Ok(rustls::sign::CertifiedKey::new(cert_chain, signing_key))
    }
}