1use crate::cert::{deserialize_nebula_certificate_from_pem, NebulaCertificate};
4use ed25519_dalek::VerifyingKey;
5use std::collections::HashMap;
6use std::error::Error;
7use std::fmt::{Display, Formatter};
8use std::time::SystemTime;
9
10#[cfg(feature = "serde_derive")]
11use serde::{Deserialize, Serialize};
12
/// A set of trusted Nebula CA certificates together with a blocklist of certificate
/// fingerprints.
#[derive(Default, Clone)]
#[cfg_attr(feature = "serde_derive", derive(Serialize, Deserialize))]
pub struct NebulaCAPool {
    /// Trusted CA certificates, keyed by the fingerprint string produced by the
    /// certificate's `sha256sum()` (this is the key `add_ca_certificate` inserts under).
    pub cas: HashMap<String, NebulaCertificate>,
    /// Fingerprint strings of blocked certificates; `is_blocklisted` checks a
    /// certificate's `sha256sum()` against this list.
    pub cert_blocklist: Vec<String>,
    /// Set to `true` when a CA certificate that was already expired at insertion time
    /// has been added to the pool. The pool's own methods only ever set this to `true`;
    /// as a public field it may still be reset by callers.
    pub expired: bool,
}
25
26impl NebulaCAPool {
27 pub fn new() -> Self {
29 Self::default()
30 }
31
32 pub fn new_from_pem(bytes: &[u8]) -> Result<Self, Box<dyn Error>> {
38 let pems = pem::parse_many(bytes)?;
39
40 let mut pool = Self::new();
41
42 for cert in pems {
43 match pool.add_ca_certificate(pem::encode(&cert).as_bytes()) {
44 Ok(did_expire) => {
45 if did_expire {
46 pool.expired = true;
47 }
48 }
49 Err(e) => return Err(e),
50 }
51 }
52
53 Ok(pool)
54 }
55
56 pub fn add_ca_certificate(&mut self, bytes: &[u8]) -> Result<bool, Box<dyn Error>> {
60 let cert = deserialize_nebula_certificate_from_pem(bytes)?;
61
62 if !cert.details.is_ca {
63 return Err(CaPoolError::NotACA.into());
64 }
65
66 if !cert.check_signature(&VerifyingKey::from_bytes(&cert.details.public_key)?)? {
67 return Err(CaPoolError::NotSelfSigned.into());
68 }
69
70 let fingerprint = cert.sha256sum()?;
71 let expired = cert.expired(SystemTime::now());
72
73 if expired {
74 self.expired = true;
75 }
76
77 self.cas.insert(fingerprint, cert);
78
79 Ok(expired)
80 }
81
82 pub fn blocklist_fingerprint(&mut self, fingerprint: &str) {
84 self.cert_blocklist.push(fingerprint.to_string());
85 }
86
87 pub fn reset_blocklist(&mut self) {
89 self.cert_blocklist = vec![];
90 }
91
92 pub fn is_blocklisted(&self, cert: &NebulaCertificate) -> bool {
94 let Ok(h) = cert.sha256sum() else { return false };
95 self.cert_blocklist.contains(&h)
96 }
97
98 pub fn get_ca_for_cert(
102 &self,
103 cert: &NebulaCertificate,
104 ) -> Result<Option<&NebulaCertificate>, Box<dyn Error>> {
105 if cert.details.issuer == String::new() {
106 return Err(CaPoolError::NoIssuer.into());
107 }
108
109 Ok(self.cas.get(&cert.details.issuer))
110 }
111
112 pub fn get_fingerprints(&self) -> Vec<&String> {
114 self.cas.keys().collect()
115 }
116}
117
/// Reasons a certificate can be rejected by, or fail a lookup in, the CA pool.
#[derive(Debug)]
#[cfg_attr(feature = "serde_derive", derive(Serialize, Deserialize))]
pub enum CaPoolError {
    /// The certificate offered to the pool is not flagged as a CA.
    NotACA,
    /// The certificate's signature does not verify against its own public key.
    NotSelfSigned,
    /// The certificate being looked up has an empty issuer field.
    NoIssuer,
}

impl Error for CaPoolError {}

#[cfg(not(tarpaulin_include))]
impl Display for CaPoolError {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        // Every variant maps to a fixed message; pick it, then write it out once.
        let message = match *self {
            CaPoolError::NotACA => "Tried to add a non-CA cert to the CA pool",
            CaPoolError::NotSelfSigned => "Tried to add a non-self-signed cert to the CA pool (all CAs must be root certificates)",
            CaPoolError::NoIssuer => "Tried to look up a certificate with a null issuer field",
        };
        f.write_str(message)
    }
}