1use std::{fmt, str::FromStr};
2
3use hex::{FromHex, FromHexError};
4use serde::de::{self, Visitor};
5use sha2::{Digest, Sha256};
6use x509_parser::{
7 certificate::X509Certificate,
8 extensions::{GeneralName, ParsedExtension},
9 oid_registry::{OID_X509_COMMON_NAME, OID_X509_EXT_SUBJECT_ALT_NAME},
10 parse_x509_certificate,
11 pem::{Pem, parse_x509_pem},
12};
13
14use crate::{
15 config::{Config, ConfigError},
16 proto::command::{CertificateAndKey, TlsVersion},
17};
18
/// Length in bytes of a SHA-256 digest, i.e. of every fingerprint this module computes.
///
/// No `#[allow(dead_code)]` is needed: the constant is referenced from `debug_assert_eq!`
/// invocations, whose arguments are still type-checked (and therefore "used") in release builds.
const SHA256_FINGERPRINT_LEN: usize = 32;
30
31#[derive(thiserror::Error, Debug)]
35pub enum CertificateError {
36 #[error("Could not parse PEM certificate from bytes: {0}")]
37 ParsePEMCertificate(String),
38 #[error("Could not parse X509 certificate from bytes: {0}")]
39 ParseX509Certificate(String),
40 #[error("failed to parse tls version '{0}'")]
41 InvalidTlsVersion(String),
42 #[error("failed to parse fingerprint, {0}")]
43 InvalidFingerprint(FromHexError),
44 #[error("could not load file on path {path}: {error}")]
45 LoadFile { path: String, error: ConfigError },
46 #[error("Failed at decoding the hex encoded certificate: {0}")]
47 DecodeError(FromHexError),
48}
49
50pub fn parse_pem(certificate: &[u8]) -> Result<Pem, CertificateError> {
56 let (_, pem) = parse_x509_pem(certificate)
57 .map_err(|err| CertificateError::ParsePEMCertificate(err.to_string()))?;
58
59 Ok(pem)
60}
61
62pub fn parse_x509(pem_bytes: &[u8]) -> Result<X509Certificate<'_>, CertificateError> {
64 parse_x509_certificate(pem_bytes)
65 .map_err(|nom_e| CertificateError::ParseX509Certificate(nom_e.to_string()))
66 .map(|t| t.1)
67}
68
69pub fn get_cn_and_san_attributes(x509: &X509Certificate) -> Vec<String> {
86 let mut names: Vec<String> = Vec::new();
87 let mut san_dns_seen = false;
88
89 for extension in x509.extensions() {
90 if extension.oid == OID_X509_EXT_SUBJECT_ALT_NAME {
91 if let ParsedExtension::SubjectAlternativeName(san) = extension.parsed_extension() {
92 for name in &san.general_names {
93 if let GeneralName::DNSName(name) = name {
94 san_dns_seen = true;
95 names.push(name.to_string());
96 }
97 }
98 }
99 }
100 }
101
102 debug_assert_eq!(
108 san_dns_seen,
109 !names.is_empty(),
110 "SAN dNSName presence must match the collected-names state before CN fallback"
111 );
112
113 if !san_dns_seen {
114 for name in x509.subject().iter_by_oid(&OID_X509_COMMON_NAME) {
115 names.push(
116 name.as_str()
117 .map(String::from)
118 .unwrap_or_else(|_| String::from_utf8_lossy(name.as_slice()).to_string()),
119 );
120 }
121 }
122 let before_dedup = names.len();
123 names.dedup();
124 debug_assert!(
127 names.len() <= before_dedup,
128 "dedup must not grow the identity list"
129 );
130 names
131}
132
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses a PEM fixture and returns the identities extracted from its certificate.
    fn names_from_fixture(pem_text: &str) -> Vec<String> {
        let pem = parse_pem(pem_text.as_bytes()).expect("parse PEM");
        let x509 = parse_x509(&pem.contents).expect("parse x509");
        get_cn_and_san_attributes(&x509)
    }

    /// A certificate whose CN differs from its SAN dNSName must only report the SAN name.
    #[test]
    fn san_dns_present_excludes_cn() {
        let names = names_from_fixture(include_str!("../../lib/assets/cn-ne-san-cert.pem"));
        assert_eq!(names, vec![String::from("tenant-a.example")]);
    }

    /// Without any SAN dNSName the subject CN is used.
    #[test]
    fn cn_used_when_san_absent() {
        let names = names_from_fixture(include_str!("../../lib/assets/certificate.pem"));
        assert_eq!(names, vec![String::from("lolcatho.st")]);
    }

    /// Every SAN dNSName is reported exactly once.
    #[test]
    fn san_dns_present_cn_is_san_member() {
        let names = names_from_fixture(include_str!("../../lib/assets/multi-sni-cert.pem"));
        for expected in [
            "foo.example.com",
            "bar.example.com",
            "baz.example.com",
            "localhost",
        ] {
            assert!(names.contains(&String::from(expected)));
        }
        assert_eq!(names.len(), 4);
    }
}
178
179impl FromStr for TlsVersion {
183 type Err = CertificateError;
184
185 fn from_str(s: &str) -> Result<Self, Self::Err> {
186 match s {
187 "SSL_V2" => Ok(TlsVersion::SslV2),
188 "SSL_V3" => Ok(TlsVersion::SslV3),
189 "TLSv1" => Ok(TlsVersion::TlsV10),
190 "TLS_V11" => Ok(TlsVersion::TlsV11),
191 "TLS_V12" => Ok(TlsVersion::TlsV12),
192 "TLS_V13" => Ok(TlsVersion::TlsV13),
193 _ => Err(CertificateError::InvalidTlsVersion(s.to_string())),
194 }
195 }
196}
197
/// Raw certificate fingerprint bytes (a SHA-256 digest when computed by this module).
///
/// `Debug`, `Display`, `Serialize` and `Deserialize` are implemented by hand so that the bytes
/// are shown as a hexadecimal string rather than a list of integers.
#[derive(Default, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Fingerprint(pub Vec<u8>);
205
206impl FromStr for Fingerprint {
207 type Err = CertificateError;
208
209 fn from_str(s: &str) -> Result<Self, Self::Err> {
210 hex::decode(s)
211 .map_err(CertificateError::InvalidFingerprint)
212 .map(Fingerprint)
213 }
214}
215
216impl fmt::Debug for Fingerprint {
217 fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
218 write!(f, "CertificateFingerprint({})", hex::encode(&self.0))
219 }
220}
221
222impl fmt::Display for Fingerprint {
223 fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
224 write!(f, "{}", hex::encode(&self.0))
225 }
226}
227
228impl serde::Serialize for Fingerprint {
229 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
230 where
231 S: serde::Serializer,
232 {
233 serializer.serialize_str(&hex::encode(&self.0))
234 }
235}
236
237struct FingerprintVisitor;
238
239impl Visitor<'_> for FingerprintVisitor {
240 type Value = Fingerprint;
241
242 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
243 formatter.write_str("the certificate fingerprint must be in hexadecimal format")
244 }
245
246 fn visit_str<E>(self, value: &str) -> Result<Fingerprint, E>
247 where
248 E: de::Error,
249 {
250 FromHex::from_hex(value)
251 .map_err(|e| E::custom(format!("could not deserialize hex: {e:?}")))
252 .map(Fingerprint)
253 }
254}
255
256impl<'de> serde::Deserialize<'de> for Fingerprint {
257 fn deserialize<D>(deserializer: D) -> Result<Fingerprint, D::Error>
258 where
259 D: serde::de::Deserializer<'de>,
260 {
261 deserializer.deserialize_str(FingerprintVisitor {})
262 }
263}
264
265pub fn calculate_fingerprint_from_der(certificate: &[u8]) -> Vec<u8> {
267 let fingerprint: Vec<u8> = Sha256::digest(certificate).iter().cloned().collect();
268 debug_assert_eq!(
272 fingerprint.len(),
273 SHA256_FINGERPRINT_LEN,
274 "SHA-256 fingerprint must be exactly 32 bytes"
275 );
276 fingerprint
277}
278
279pub fn calculate_fingerprint(certificate: &[u8]) -> Result<Vec<u8>, CertificateError> {
281 let parsed_certificate = parse_pem(certificate)?;
282 let fingerprint = calculate_fingerprint_from_der(&parsed_certificate.contents);
283 debug_assert_eq!(
287 fingerprint.len(),
288 SHA256_FINGERPRINT_LEN,
289 "PEM fingerprint must be a 32-byte SHA-256 digest"
290 );
291 debug_assert_eq!(
292 fingerprint,
293 calculate_fingerprint_from_der(&parsed_certificate.contents),
294 "fingerprint must be a deterministic function of the DER contents"
295 );
296 Ok(fingerprint)
297}
298
/// Splits a concatenated PEM chain into one string per certificate.
///
/// Each returned entry runs from the end of the previous `-----END CERTIFICATE-----` marker
/// (or the start of `chain`) through the next marker, with surrounding whitespace trimmed.
/// Any text after the last marker is dropped; input without a marker yields an empty vector.
pub fn split_certificate_chain(chain: String) -> Vec<String> {
    const END_MARKER: &str = "-----END CERTIFICATE-----";

    // `split_inclusive` keeps each marker attached to the text before it; only a trailing
    // remainder (text after the last marker) fails the `ends_with` filter.
    let certificates: Vec<String> = chain
        .split_inclusive(END_MARKER)
        .filter(|block| block.ends_with(END_MARKER))
        .map(|block| {
            debug_assert!(
                block.contains(END_MARKER),
                "each split block must contain its END CERTIFICATE marker"
            );
            block.trim().to_string()
        })
        .collect();

    debug_assert_eq!(
        certificates.len(),
        chain.matches(END_MARKER).count(),
        "split must yield exactly one certificate per END marker"
    );
    certificates
}
334
335pub fn get_fingerprint_from_certificate_path(
336 certificate_path: &str,
337) -> Result<Fingerprint, CertificateError> {
338 let bytes =
339 Config::load_file_bytes(certificate_path).map_err(|e| CertificateError::LoadFile {
340 path: certificate_path.to_string(),
341 error: e,
342 })?;
343
344 let parsed_bytes = calculate_fingerprint(&bytes)?;
345
346 debug_assert_eq!(
350 parsed_bytes.len(),
351 SHA256_FINGERPRINT_LEN,
352 "fingerprint loaded from a certificate path must be 32 bytes"
353 );
354 Ok(Fingerprint(parsed_bytes))
355}
356
357pub fn decode_fingerprint(fingerprint: &str) -> Result<Fingerprint, CertificateError> {
358 let bytes = hex::decode(fingerprint).map_err(CertificateError::DecodeError)?;
359 Ok(Fingerprint(bytes))
360}
361
362pub fn load_full_certificate(
363 certificate_path: &str,
364 certificate_chain_path: &str,
365 key_path: &str,
366 versions: Vec<TlsVersion>,
367 names: Vec<String>,
368) -> Result<CertificateAndKey, CertificateError> {
369 let certificate =
370 Config::load_file(certificate_path).map_err(|e| CertificateError::LoadFile {
371 path: certificate_path.to_string(),
372 error: e,
373 })?;
374
375 let certificate_chain = Config::load_file(certificate_chain_path)
376 .map(split_certificate_chain)
377 .map_err(|e| CertificateError::LoadFile {
378 path: certificate_chain_path.to_string(),
379 error: e,
380 })?;
381
382 let key = Config::load_file(key_path).map_err(|e| CertificateError::LoadFile {
383 path: key_path.to_string(),
384 error: e,
385 })?;
386
387 let versions_len = versions.len();
388 let names_len = names.len();
389 let versions: Vec<i32> = versions.iter().map(|v| *v as i32).collect();
390
391 debug_assert_eq!(
394 versions.len(),
395 versions_len,
396 "version encoding must preserve the input cardinality"
397 );
398
399 let built = CertificateAndKey {
400 certificate,
401 certificate_chain,
402 key,
403 versions,
404 names,
405 };
406
407 debug_assert_eq!(
411 built.names.len(),
412 names_len,
413 "names must be carried through the builder unchanged"
414 );
415 Ok(built)
416}
417
418impl CertificateAndKey {
419 pub fn fingerprint(&self) -> Result<Fingerprint, CertificateError> {
420 let pem = parse_pem(self.certificate.as_bytes())?;
421 let fingerprint = Fingerprint(Sha256::digest(&pem.contents).iter().cloned().collect());
422 debug_assert_eq!(
427 fingerprint.0.len(),
428 SHA256_FINGERPRINT_LEN,
429 "CertificateAndKey fingerprint must be 32 bytes"
430 );
431 debug_assert_eq!(
432 fingerprint.0,
433 calculate_fingerprint_from_der(&pem.contents),
434 "method and free-function fingerprints must agree on the same DER"
435 );
436 Ok(fingerprint)
437 }
438
439 pub fn get_overriding_names(&self) -> Result<Vec<String>, CertificateError> {
440 if self.names.is_empty() {
441 let pem = parse_pem(self.certificate.as_bytes())?;
442 let x509 = parse_x509(&pem.contents)?;
443
444 let overriding_names = get_cn_and_san_attributes(&x509);
445
446 Ok(overriding_names.into_iter().collect())
447 } else {
448 let names = self.names.to_owned();
449 debug_assert_eq!(
453 names, self.names,
454 "explicit names must be returned unchanged when present"
455 );
456 Ok(names)
457 }
458 }
459
460 pub fn apply_overriding_names(&mut self) -> Result<(), CertificateError> {
461 let resolved = self.get_overriding_names()?;
462 self.names = resolved.clone();
463 debug_assert_eq!(
467 self.names, resolved,
468 "applied names must equal the resolved set"
469 );
470 Ok(())
471 }
472}