trifid_pki/
ca.rs

//! Structs to represent a pool of CAs and blocklisted certificates
2
3use crate::cert::{deserialize_nebula_certificate_from_pem, NebulaCertificate};
4use ed25519_dalek::VerifyingKey;
5use std::collections::HashMap;
6use std::error::Error;
7use std::fmt::{Display, Formatter};
8use std::time::SystemTime;
9
10#[cfg(feature = "serde_derive")]
11use serde::{Deserialize, Serialize};
12
13/// A pool of trusted CA certificates, and certificates that should be blocked.
14/// This is equivalent to the `pki` section in a typical Nebula config.yml.
15#[derive(Default, Clone)]
16#[cfg_attr(feature = "serde_derive", derive(Serialize, Deserialize))]
17pub struct NebulaCAPool {
18    /// The list of CA root certificates that should be trusted.
19    pub cas: HashMap<String, NebulaCertificate>,
20    /// The list of blocklisted certificate fingerprints
21    pub cert_blocklist: Vec<String>,
22    /// True if any of the member CAs certificates are expired. Must be handled.
23    pub expired: bool,
24}
25
26impl NebulaCAPool {
27    /// Create a new, blank CA pool
28    pub fn new() -> Self {
29        Self::default()
30    }
31
32    /// Create a new CA pool from a set of PEM encoded CA certificates.
33    /// If any of the certificates are expired, the pool will **still be returned**, with the expired flag set.
34    /// This must be handled properly.
35    /// # Errors
36    /// This function will return an error if PEM data provided was invalid.
37    pub fn new_from_pem(bytes: &[u8]) -> Result<Self, Box<dyn Error>> {
38        let pems = pem::parse_many(bytes)?;
39
40        let mut pool = Self::new();
41
42        for cert in pems {
43            match pool.add_ca_certificate(pem::encode(&cert).as_bytes()) {
44                Ok(did_expire) => {
45                    if did_expire {
46                        pool.expired = true;
47                    }
48                }
49                Err(e) => return Err(e),
50            }
51        }
52
53        Ok(pool)
54    }
55
56    /// Add a given CA certificate to the CA pool. If the certificate is expired, it will **still be added** - the return value will be `true` instead of `false`
57    /// # Errors
58    /// This function will return an error if the certificate is invalid in any way.
59    pub fn add_ca_certificate(&mut self, bytes: &[u8]) -> Result<bool, Box<dyn Error>> {
60        let cert = deserialize_nebula_certificate_from_pem(bytes)?;
61
62        if !cert.details.is_ca {
63            return Err(CaPoolError::NotACA.into());
64        }
65
66        if !cert.check_signature(&VerifyingKey::from_bytes(&cert.details.public_key)?)? {
67            return Err(CaPoolError::NotSelfSigned.into());
68        }
69
70        let fingerprint = cert.sha256sum()?;
71        let expired = cert.expired(SystemTime::now());
72
73        if expired {
74            self.expired = true;
75        }
76
77        self.cas.insert(fingerprint, cert);
78
79        Ok(expired)
80    }
81
82    /// Blocklist the given certificate in the CA pool
83    pub fn blocklist_fingerprint(&mut self, fingerprint: &str) {
84        self.cert_blocklist.push(fingerprint.to_string());
85    }
86
87    /// Clears the list of blocklisted fingerprints
88    pub fn reset_blocklist(&mut self) {
89        self.cert_blocklist = vec![];
90    }
91
92    /// Checks if the given certificate is blocklisted
93    pub fn is_blocklisted(&self, cert: &NebulaCertificate) -> bool {
94        let Ok(h) = cert.sha256sum() else { return false };
95        self.cert_blocklist.contains(&h)
96    }
97
98    /// Gets the CA certificate used to sign the given certificate
99    /// # Errors
100    /// This function will return an error if the certificate does not have an issuer attached (it is self-signed)
101    pub fn get_ca_for_cert(
102        &self,
103        cert: &NebulaCertificate,
104    ) -> Result<Option<&NebulaCertificate>, Box<dyn Error>> {
105        if cert.details.issuer == String::new() {
106            return Err(CaPoolError::NoIssuer.into());
107        }
108
109        Ok(self.cas.get(&cert.details.issuer))
110    }
111
112    /// Get a list of trusted CA fingerprints
113    pub fn get_fingerprints(&self) -> Vec<&String> {
114        self.cas.keys().collect()
115    }
116}
117
/// The failures that can occur when working with a CA pool.
#[derive(Debug)]
#[cfg_attr(feature = "serde_derive", derive(Serialize, Deserialize))]
pub enum CaPoolError {
    /// A certificate that is not a CA was offered to the CA pool.
    NotACA,
    /// A certificate that is not self-signed was offered to the CA pool
    /// (all CAs must be root certificates).
    NotSelfSigned,
    /// A lookup was attempted for a certificate that has no issuer field.
    NoIssuer,
}

#[cfg(not(tarpaulin_include))]
impl Display for CaPoolError {
    /// Writes the human-readable description of the error variant.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            Self::NotACA => "Tried to add a non-CA cert to the CA pool",
            Self::NotSelfSigned => {
                "Tried to add a non-self-signed cert to the CA pool (all CAs must be root certificates)"
            }
            Self::NoIssuer => "Tried to look up a certificate with a null issuer field",
        };
        f.write_str(message)
    }
}

impl Error for CaPoolError {}