1use std::{fmt, str::FromStr};
2
3use hex::{FromHex, FromHexError};
4use serde::de::{self, Visitor};
5use sha2::{Digest, Sha256};
6use x509_parser::{
7 certificate::X509Certificate,
8 extensions::{GeneralName, ParsedExtension},
9 oid_registry::{OID_X509_COMMON_NAME, OID_X509_EXT_SUBJECT_ALT_NAME},
10 parse_x509_certificate,
11 pem::{parse_x509_pem, Pem},
12};
13
14use crate::{
15 config::{Config, ConfigError},
16 proto::command::{CertificateAndKey, TlsVersion},
17};
18
/// Errors returned by the certificate helpers in this file.
#[derive(thiserror::Error, Debug)]
pub enum CertificateError {
    /// `parse_x509_pem` rejected the input (see `parse_pem`); holds the parser's error text.
    #[error("Could not parse PEM certificate from bytes: {0}")]
    ParsePEMCertificate(String),
    /// `parse_x509_certificate` rejected the input (see `parse_x509`); holds the parser's
    /// error text.
    #[error("Could not parse X509 certificate from bytes: {0}")]
    ParseX509Certificate(String),
    /// The string given to `TlsVersion::from_str` is not one of the accepted version names.
    #[error("failed to parse tls version '{0}'")]
    InvalidTlsVersion(String),
    /// Hex decoding failed while parsing a `Fingerprint` with `FromStr`.
    #[error("failed to parse fingerprint, {0}")]
    InvalidFingerprint(FromHexError),
    /// Loading a file through `Config` failed; `path` is the file that could not be loaded
    /// and `error` is the underlying configuration error.
    #[error("could not load file on path {path}: {error}")]
    LoadFile { path: String, error: ConfigError },
    /// Hex decoding failed in `decode_fingerprint`.
    // NOTE(review): the message mentions a "certificate" although `decode_fingerprint`
    // decodes a fingerprint; left unchanged because callers may match on the text.
    #[error("Failed at decoding the hex encoded certificate: {0}")]
    DecodeError(FromHexError),
}
37
38pub fn parse_pem(certificate: &[u8]) -> Result<Pem, CertificateError> {
44 let (_, pem) = parse_x509_pem(certificate)
45 .map_err(|err| CertificateError::ParsePEMCertificate(err.to_string()))?;
46
47 Ok(pem)
48}
49
50pub fn parse_x509(pem_bytes: &[u8]) -> Result<X509Certificate, CertificateError> {
52 parse_x509_certificate(pem_bytes)
53 .map_err(|nom_e| CertificateError::ParseX509Certificate(nom_e.to_string()))
54 .map(|t| t.1)
55}
56
57pub fn get_cn_and_san_attributes(x509: &X509Certificate) -> Vec<String> {
63 let mut names: Vec<String> = Vec::new();
64 for name in x509.subject().iter_by_oid(&OID_X509_COMMON_NAME) {
65 names.push(
66 name.as_str()
67 .map(String::from)
68 .unwrap_or_else(|_| String::from_utf8_lossy(name.as_slice()).to_string()),
69 );
70 }
71
72 for extension in x509.extensions() {
73 if extension.oid == OID_X509_EXT_SUBJECT_ALT_NAME {
74 if let ParsedExtension::SubjectAlternativeName(san) = extension.parsed_extension() {
75 for name in &san.general_names {
76 if let GeneralName::DNSName(name) = name {
77 names.push(name.to_string());
78 }
79 }
80 }
81 }
82 }
83 names.dedup();
84 names
85}
86
87impl FromStr for TlsVersion {
91 type Err = CertificateError;
92
93 fn from_str(s: &str) -> Result<Self, Self::Err> {
94 match s {
95 "SSL_V2" => Ok(TlsVersion::SslV2),
96 "SSL_V3" => Ok(TlsVersion::SslV3),
97 "TLSv1" => Ok(TlsVersion::TlsV10),
98 "TLS_V11" => Ok(TlsVersion::TlsV11),
99 "TLS_V12" => Ok(TlsVersion::TlsV12),
100 "TLS_V13" => Ok(TlsVersion::TlsV13),
101 _ => Err(CertificateError::InvalidTlsVersion(s.to_string())),
102 }
103 }
104}
105
/// Raw bytes of a certificate fingerprint.
///
/// Values built by `calculate_fingerprint` / `CertificateAndKey::fingerprint` hold the SHA-256
/// digest of the certificate's PEM-decoded contents. Values decoded from hex (`FromStr`,
/// `decode_fingerprint`, serde) accept any byte length. The `Default` value is an empty
/// byte vector. `Display`, `Debug` and serde use the hex encoding of the bytes.
#[derive(Clone, PartialEq, Eq, Hash, PartialOrd, Ord, Default)]
pub struct Fingerprint(pub Vec<u8>);
113
114impl FromStr for Fingerprint {
115 type Err = CertificateError;
116
117 fn from_str(s: &str) -> Result<Self, Self::Err> {
118 hex::decode(s)
119 .map_err(CertificateError::InvalidFingerprint)
120 .map(Fingerprint)
121 }
122}
123
124impl fmt::Debug for Fingerprint {
125 fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
126 write!(f, "CertificateFingerprint({})", hex::encode(&self.0))
127 }
128}
129
130impl fmt::Display for Fingerprint {
131 fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
132 write!(f, "{}", hex::encode(&self.0))
133 }
134}
135
136impl serde::Serialize for Fingerprint {
137 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
138 where
139 S: serde::Serializer,
140 {
141 serializer.serialize_str(&hex::encode(&self.0))
142 }
143}
144
145struct FingerprintVisitor;
146
147impl<'de> Visitor<'de> for FingerprintVisitor {
148 type Value = Fingerprint;
149
150 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
151 formatter.write_str("the certificate fingerprint must be in hexadecimal format")
152 }
153
154 fn visit_str<E>(self, value: &str) -> Result<Fingerprint, E>
155 where
156 E: de::Error,
157 {
158 FromHex::from_hex(value)
159 .map_err(|e| E::custom(format!("could not deserialize hex: {e:?}")))
160 .map(Fingerprint)
161 }
162}
163
164impl<'de> serde::Deserialize<'de> for Fingerprint {
165 fn deserialize<D>(deserializer: D) -> Result<Fingerprint, D::Error>
166 where
167 D: serde::de::Deserializer<'de>,
168 {
169 deserializer.deserialize_str(FingerprintVisitor {})
170 }
171}
172
173pub fn calculate_fingerprint_from_der(certificate: &[u8]) -> Vec<u8> {
175 Sha256::digest(certificate).iter().cloned().collect()
176}
177
178pub fn calculate_fingerprint(certificate: &[u8]) -> Result<Vec<u8>, CertificateError> {
180 let parsed_certificate = parse_pem(certificate)?;
181 let fingerprint = calculate_fingerprint_from_der(&parsed_certificate.contents);
182 Ok(fingerprint)
183}
184
/// Split a PEM bundle into its individual certificates.
///
/// Each `-----END CERTIFICATE-----` marker closes one entry. An entry is the text from just
/// after the previous marker (or the start) up to and including its marker, with surrounding
/// whitespace trimmed. Entries are returned in order. Text after the last marker is discarded.
pub fn split_certificate_chain(chain: String) -> Vec<String> {
    const END_MARKER: &str = "-----END CERTIFICATE-----";

    chain
        .split_inclusive(END_MARKER)
        // Only the final piece can lack the marker; it is trailing text and is dropped.
        .filter(|piece| piece.ends_with(END_MARKER))
        .map(|piece| piece.trim().to_string())
        .collect()
}
201
202pub fn get_fingerprint_from_certificate_path(
203 certificate_path: &str,
204) -> Result<Fingerprint, CertificateError> {
205 let bytes =
206 Config::load_file_bytes(certificate_path).map_err(|e| CertificateError::LoadFile {
207 path: certificate_path.to_string(),
208 error: e,
209 })?;
210
211 let parsed_bytes = calculate_fingerprint(&bytes)?;
212
213 Ok(Fingerprint(parsed_bytes))
214}
215
216pub fn decode_fingerprint(fingerprint: &str) -> Result<Fingerprint, CertificateError> {
217 let bytes = hex::decode(fingerprint).map_err(CertificateError::DecodeError)?;
218 Ok(Fingerprint(bytes))
219}
220
221pub fn load_full_certificate(
222 certificate_path: &str,
223 certificate_chain_path: &str,
224 key_path: &str,
225 versions: Vec<TlsVersion>,
226 names: Vec<String>,
227) -> Result<CertificateAndKey, CertificateError> {
228 let certificate =
229 Config::load_file(certificate_path).map_err(|e| CertificateError::LoadFile {
230 path: certificate_path.to_string(),
231 error: e,
232 })?;
233
234 let certificate_chain = Config::load_file(certificate_chain_path)
235 .map(split_certificate_chain)
236 .map_err(|e| CertificateError::LoadFile {
237 path: certificate_chain_path.to_string(),
238 error: e,
239 })?;
240
241 let key = Config::load_file(key_path).map_err(|e| CertificateError::LoadFile {
242 path: key_path.to_string(),
243 error: e,
244 })?;
245
246 let versions = versions.iter().map(|v| *v as i32).collect();
247
248 Ok(CertificateAndKey {
249 certificate,
250 certificate_chain,
251 key,
252 versions,
253 names,
254 })
255}
256
257impl CertificateAndKey {
258 pub fn fingerprint(&self) -> Result<Fingerprint, CertificateError> {
259 let pem = parse_pem(self.certificate.as_bytes())?;
260 let fingerprint = Fingerprint(Sha256::digest(pem.contents).iter().cloned().collect());
261 Ok(fingerprint)
262 }
263
264 pub fn get_overriding_names(&self) -> Result<Vec<String>, CertificateError> {
265 if self.names.is_empty() {
266 let pem = parse_pem(self.certificate.as_bytes())?;
267 let x509 = parse_x509(&pem.contents)?;
268
269 let overriding_names = get_cn_and_san_attributes(&x509);
270
271 Ok(overriding_names.into_iter().collect())
272 } else {
273 Ok(self.names.to_owned())
274 }
275 }
276
277 pub fn apply_overriding_names(&mut self) -> Result<(), CertificateError> {
278 self.names = self.get_overriding_names()?;
279 Ok(())
280 }
281}