1#[cfg(test)]
7use std::collections::HashSet;
8use std::{
9 collections::HashMap,
10 fmt,
11 io::BufReader,
12 str::FromStr,
13 sync::{Arc, LazyLock, Mutex},
14};
15
16use rustls::{
17 crypto::ring::sign::any_supported_type,
18 pki_types::{CertificateDer, PrivateKeyDer},
19 server::{ClientHello, ResolvesServerCert},
20 sign::CertifiedKey,
21};
22use sha2::{Digest, Sha256};
23use sozu_command::{
24 certificate::{
25 CertificateError, Fingerprint, get_cn_and_san_attributes, parse_pem, parse_x509,
26 },
27 proto::command::{AddCertificate, CertificateAndKey, ReplaceCertificate, SocketAddress},
28};
29
30use crate::router::pattern_trie::{Key, KeyValue, TrieNode};
31
32static DEFAULT_CERTIFICATE: LazyLock<Option<Arc<CertifiedKey>>> = LazyLock::new(|| {
36 let add = AddCertificate {
37 certificate: CertificateAndKey {
38 certificate: include_str!("../assets/certificate.pem").to_string(),
39 certificate_chain: vec![include_str!("../assets/certificate_chain.pem").to_string()],
40 key: include_str!("../assets/key.pem").to_string(),
41 versions: vec![],
42 names: vec![],
43 },
44 address: SocketAddress::new_v4(0, 0, 0, 0, 8080), expired_at: None,
46 };
47 CertifiedKeyWrapper::try_from(&add).ok().map(|c| c.inner)
48});
49
/// Errors produced while converting certificate material into a
/// [`CertifiedKeyWrapper`] or while updating a [`CertificateResolver`].
#[derive(thiserror::Error, Debug)]
pub enum CertificateResolverError {
    /// The common name / subject alternative names could not be extracted.
    /// NOTE(review): not constructed anywhere in this file.
    #[error("failed to get common name and subject alternate names from pem, {0}")]
    InvalidCommonNameAndSubjectAlternateNames(CertificateError),
    /// The private key was decoded but no supported signing key could be built from it.
    #[error("invalid private key: {0}")]
    InvalidPrivateKey(String),
    /// The key PEM could not be read, held no item, or its first item is not a
    /// PKCS#1, PKCS#8 or SEC1 private key.
    #[error("empty key")]
    EmptyKeys,
    /// The decoded leaf certificate is not a parsable X.509 certificate.
    #[error("error parsing x509 cert from bytes: {0}")]
    ParseX509(CertificateError),
    /// The leaf certificate or one link of its chain is not valid PEM.
    #[error("error parsing pem formated certificate from bytes: {0}")]
    ParsePem(CertificateError),
    /// The overriding names of a new certificate could not be parsed.
    /// NOTE(review): not constructed anywhere in this file.
    #[error("error parsing overriding names in new certificate: {0}")]
    ParseOverridingNames(CertificateError),
}
65
/// A rustls [`CertifiedKey`] bundled with the metadata the resolver needs to
/// index it: the names it serves, its expiration and its fingerprint.
#[derive(Clone, Debug)]
pub struct CertifiedKeyWrapper {
    /// Certificate chain and signing key, shared with rustls through an `Arc`.
    inner: Arc<CertifiedKey>,
    /// Names this certificate is registered under: the names given in the
    /// request when there are any, otherwise the CN and SAN of the certificate.
    names: Vec<String>,
    /// Expiration timestamp: the `expired_at` override of the request when
    /// present, otherwise the `not_after` timestamp of the certificate.
    expiration: i64,
    /// SHA-256 digest of the decoded leaf certificate.
    fingerprint: Fingerprint,
}
77
78impl TryFrom<&AddCertificate> for CertifiedKeyWrapper {
81 type Error = CertificateResolverError;
82
83 fn try_from(add: &AddCertificate) -> Result<Self, Self::Error> {
84 let cert = add.certificate.clone();
85
86 let pem =
87 parse_pem(cert.certificate.as_bytes()).map_err(CertificateResolverError::ParsePem)?;
88
89 let x509 = parse_x509(&pem.contents).map_err(CertificateResolverError::ParseX509)?;
90
91 let overriding_names = if add.certificate.names.is_empty() {
92 get_cn_and_san_attributes(&x509)
93 } else {
94 add.certificate.names.clone()
95 };
96
97 let expiration = add
98 .expired_at
99 .unwrap_or(x509.validity().not_after.timestamp());
100
101 let fingerprint = Fingerprint(Sha256::digest(&pem.contents).iter().cloned().collect());
102
103 let mut chain = vec![CertificateDer::from(pem.contents)];
104 for cert in &cert.certificate_chain {
105 let chain_link = parse_pem(cert.as_bytes())
106 .map_err(CertificateResolverError::ParsePem)?
107 .contents;
108
109 chain.push(CertificateDer::from(chain_link));
110 }
111
112 let mut key_reader = BufReader::new(cert.key.as_bytes());
113
114 let item = match rustls_pemfile::read_one(&mut key_reader)
115 .map_err(|_| CertificateResolverError::EmptyKeys)?
116 {
117 Some(item) => item,
118 None => return Err(CertificateResolverError::EmptyKeys),
119 };
120
121 let private_key = match item {
122 rustls_pemfile::Item::Pkcs1Key(rsa_key) => PrivateKeyDer::from(rsa_key),
123 rustls_pemfile::Item::Pkcs8Key(pkcs8_key) => PrivateKeyDer::from(pkcs8_key),
124 rustls_pemfile::Item::Sec1Key(ec_key) => PrivateKeyDer::from(ec_key),
125 _ => return Err(CertificateResolverError::EmptyKeys),
126 };
127
128 match any_supported_type(&private_key) {
129 Ok(signing_key) => {
130 let stored_certificate = CertifiedKeyWrapper {
131 inner: Arc::new(CertifiedKey::new(chain, signing_key)),
132 names: overriding_names,
133 expiration,
134 fingerprint,
135 };
136 Ok(stored_certificate)
137 }
138 Err(sign_error) => Err(CertificateResolverError::InvalidPrivateKey(
139 sign_error.to_string(),
140 )),
141 }
142 }
143}
144
/// Stores certificates and resolves a domain name to the certificate to serve.
///
/// For each name, the domain trie points at the certificate with the latest
/// expiration among those registered for that name.
#[derive(Default, Debug)]
pub struct CertificateResolver {
    /// Domain name -> fingerprint of the certificate currently served for it.
    pub domains: TrieNode<Fingerprint>,
    /// Every registered certificate, keyed by fingerprint.
    certificates: HashMap<Fingerprint, CertifiedKeyWrapper>,
    /// For each name, the `(fingerprint, expiration)` of every certificate
    /// registered for it, kept sorted by ascending expiration.
    name_fingerprint_idx: HashMap<String, Vec<(Fingerprint, i64)>>,
}
162
163impl CertificateResolver {
164 pub fn get_certificate(&self, fingerprint: &Fingerprint) -> Option<CertifiedKeyWrapper> {
166 self.certificates.get(fingerprint).map(ToOwned::to_owned)
167 }
168
169 pub fn add_certificate(
172 &mut self,
173 add: &AddCertificate,
174 ) -> Result<Fingerprint, CertificateResolverError> {
175 let cert_to_add = CertifiedKeyWrapper::try_from(add)?;
176
177 trace!("Certificate Resolver: adding certificate {:?}", cert_to_add);
178
179 if self.certificates.contains_key(&cert_to_add.fingerprint) {
180 return Ok(cert_to_add.fingerprint);
181 }
182
183 for new_name in &cert_to_add.names {
184 let fingerprints_for_this_name = self
185 .name_fingerprint_idx
186 .entry(new_name.to_owned())
187 .or_default();
188
189 fingerprints_for_this_name
190 .push((cert_to_add.fingerprint.clone(), cert_to_add.expiration));
191
192 fingerprints_for_this_name.sort_by_key(|t| t.1);
194
195 let longest_lived_cert = match fingerprints_for_this_name.last() {
196 Some(cert) => cert,
197 None => {
198 error!("no fingerprint for this name, this should not happen");
199 continue;
200 }
201 };
202
203 self.domains.remove(&new_name.to_owned().into_bytes());
205 self.domains.insert(
206 new_name.to_owned().into_bytes(),
207 longest_lived_cert.0.to_owned(),
208 );
209 }
210
211 self.certificates
212 .insert(cert_to_add.fingerprint.to_owned(), cert_to_add.clone());
213
214 trace!("{:#?}", self);
215
216 Ok(cert_to_add.fingerprint)
217 }
218
219 pub fn remove_certificate(
222 &mut self,
223 fingerprint: &Fingerprint,
224 ) -> Result<(), CertificateResolverError> {
225 if let Some(certificate_to_remove) = self.get_certificate(fingerprint) {
226 for name in certificate_to_remove.names {
227 self.domains.domain_remove(&name.clone().into_bytes());
228
229 if let Some(fingerprints_and_exp) = self.name_fingerprint_idx.get_mut(&name) {
230 *fingerprints_and_exp = fingerprints_and_exp
232 .drain(..)
233 .filter(|t| &t.0 != fingerprint)
234 .collect();
235
236 if let Some(longest_lived_cert) = fingerprints_and_exp.last() {
238 self.domains
239 .insert(name.into_bytes(), longest_lived_cert.0.to_owned());
240 }
241 }
242 }
243
244 self.certificates.remove(fingerprint);
245 }
246 trace!("{:#?}", self);
247
248 Ok(())
249 }
250
251 pub fn replace_certificate(
255 &mut self,
256 replace: &ReplaceCertificate,
257 ) -> Result<Fingerprint, CertificateResolverError> {
258 match Fingerprint::from_str(&replace.old_fingerprint) {
259 Ok(old_fingerprint) => self.remove_certificate(&old_fingerprint)?,
260 Err(err) => {
261 error!("failed to parse fingerprint, {}", err);
262 }
263 }
264
265 self.add_certificate(&AddCertificate {
266 address: replace.address.to_owned(),
267 certificate: replace.new_certificate.to_owned(),
268 expired_at: replace.new_expired_at.to_owned(),
269 })
270 }
271
272 #[cfg(test)]
275 fn find_certificates_by_names(
276 &self,
277 names: &HashSet<String>,
278 ) -> Result<HashSet<Fingerprint>, CertificateResolverError> {
279 let mut fingerprints = HashSet::new();
280 for name in names {
281 if let Some(fprints) = self.name_fingerprint_idx.get(name) {
282 fprints.iter().for_each(|fingerprint| {
283 fingerprints.insert(fingerprint.to_owned().0);
284 });
285 }
286 }
287
288 Ok(fingerprints)
289 }
290
291 #[cfg(test)]
294 fn certificate_names(
295 &self,
296 fingerprint: &Fingerprint,
297 ) -> Result<HashSet<String>, CertificateResolverError> {
298 if let Some(cert) = self.certificates.get(fingerprint) {
299 return Ok(cert.names.iter().cloned().collect());
300 }
301 Ok(HashSet::new())
302 }
303
    /// Look up `domain` in the domain trie.
    ///
    /// Delegates to the trie's own `domain_lookup`, forwarding `accept_wildcard`
    /// unchanged, and returns the matching entry (key and fingerprint), if any.
    pub fn domain_lookup(
        &self,
        domain: &[u8],
        accept_wildcard: bool,
    ) -> Option<&KeyValue<Key, Fingerprint>> {
        self.domains.domain_lookup(domain, accept_wildcard)
    }
311}
312
/// A [`CertificateResolver`] behind a [`Mutex`].
///
/// This is the type that implements rustls' [`ResolvesServerCert`] in this file.
#[derive(Default)]
pub struct MutexCertificateResolver(pub Mutex<CertificateResolver>);
318
319impl ResolvesServerCert for MutexCertificateResolver {
320 fn resolve(&self, client_hello: ClientHello) -> Option<Arc<CertifiedKey>> {
321 let server_name = client_hello.server_name();
322 let sigschemes = client_hello.signature_schemes();
323
324 if server_name.is_none() {
325 error!("cannot look up certificate: no SNI from session");
326 return None;
327 }
328
329 let name: &str = server_name.unwrap();
330 trace!(
331 "trying to resolve name: {:?} for signature scheme: {:?}",
332 name, sigschemes
333 );
334 if let Ok(ref mut resolver) = self.0.try_lock() {
335 if let Some((_, fingerprint)) = resolver.domains.domain_lookup(name.as_bytes(), true) {
337 trace!(
338 "looking for certificate for {:?} with fingerprint {:?}",
339 name, fingerprint
340 );
341
342 let cert = resolver
343 .certificates
344 .get(fingerprint)
345 .map(|cert| cert.inner.clone());
346
347 trace!("Found for fingerprint {}: {}", fingerprint, cert.is_some());
348 return cert;
349 }
350 }
351
352 debug!("Default certificate is used for {}", name);
356 incr!("tls.default_cert_used");
357 DEFAULT_CERTIFICATE.clone()
358 }
359}
360
361impl fmt::Debug for MutexCertificateResolver {
362 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
363 f.write_str("MutexWrappedCertificateResolver")
364 }
365}
366
367#[cfg(test)]
371mod tests {
372 use std::{
373 collections::HashSet,
374 error::Error,
375 time::{Duration, SystemTime},
376 };
377
378 use sozu_command::proto::command::{AddCertificate, CertificateAndKey, SocketAddress};
380
381 use super::CertificateResolver;
382
383 #[test]
384 fn lifecycle() -> Result<(), Box<dyn Error + Send + Sync>> {
385 let address = SocketAddress::new_v4(127, 0, 0, 1, 8080);
386 let mut resolver = CertificateResolver::default();
387 let certificate_and_key = CertificateAndKey {
388 certificate: String::from(include_str!("../assets/certificate.pem")),
389 key: String::from(include_str!("../assets/key.pem")),
390 ..Default::default()
391 };
392
393 let fingerprint = resolver
394 .add_certificate(&AddCertificate {
395 address,
396 certificate: certificate_and_key,
397 expired_at: None,
398 })
399 .expect("could not add certificate");
400
401 if resolver.get_certificate(&fingerprint).is_none() {
402 return Err("failed to retrieve certificate".into());
403 }
404
405 let names = resolver.certificate_names(&fingerprint)?;
407
408 if let Err(err) = resolver.remove_certificate(&fingerprint) {
409 return Err(format!("the certificate was not removed, {err}").into());
410 }
411
412 if resolver.get_certificate(&fingerprint).is_some() {
413 return Err("We have retrieved the certificate that should be deleted".into());
414 }
415
416 if !resolver.find_certificates_by_names(&names)?.is_empty() {
417 return Err(
418 "The certificate should be deleted but one of its names is in the index".into(),
419 );
420 }
421
422 Ok(())
423 }
424
425 #[test]
426 fn name_override() -> Result<(), Box<dyn Error + Send + Sync>> {
427 let address = SocketAddress::new_v4(127, 0, 0, 1, 8080);
428 let mut resolver = CertificateResolver::default();
429 let certificate_and_key = CertificateAndKey {
430 certificate: String::from(include_str!("../assets/certificate.pem")),
431 key: String::from(include_str!("../assets/key.pem")),
432 names: vec!["localhost".into(), "lolcatho.st".into()],
433 ..Default::default()
434 };
435
436 let fingerprint = resolver.add_certificate(&AddCertificate {
437 address,
438 certificate: certificate_and_key,
439 expired_at: None,
440 })?;
441
442 if resolver.get_certificate(&fingerprint).is_none() {
443 return Err("failed to retrieve certificate".into());
444 }
445
446 let mut lolcat = HashSet::new();
447 lolcat.insert(String::from("lolcatho.st"));
448 if resolver.find_certificates_by_names(&lolcat)?.is_empty()
449 || resolver.get_certificate(&fingerprint).is_none()
450 {
451 return Err("failed to retrieve certificate with custom names".into());
452 }
453
454 if let Err(err) = resolver.remove_certificate(&fingerprint) {
455 return Err(format!("the certificate could not be removed, {err}").into());
456 }
457
458 let names = resolver.certificate_names(&fingerprint)?;
459 if !resolver.find_certificates_by_names(&names)?.is_empty()
460 && resolver.get_certificate(&fingerprint).is_some()
461 {
462 return Err("We have retrieved the certificate that should be deleted".into());
463 }
464
465 Ok(())
466 }
467
468 #[test]
469 fn keep_resolving_with_wildcard() -> Result<(), Box<dyn Error + Send + Sync>> {
470 let address = SocketAddress::new_v4(127, 0, 0, 1, 8080);
471 let mut resolver = CertificateResolver::default();
472
473 let wildcard_example_org = CertificateAndKey {
476 certificate: String::from(include_str!("../assets/tests/certificate-3.pem")),
477 key: String::from(include_str!("../assets/tests/key.pem")),
478 ..Default::default()
479 };
480
481 let wildcard_example_org_fingerprint = resolver.add_certificate(&AddCertificate {
482 address: address.clone(),
483 certificate: wildcard_example_org,
484 expired_at: Some(
485 (SystemTime::now().duration_since(SystemTime::UNIX_EPOCH)?
486 + Duration::from_secs(1 * 365 * 24 * 3600))
487 .as_secs() as i64,
488 ),
489 })?;
490
491 if resolver
492 .get_certificate(&wildcard_example_org_fingerprint)
493 .is_none()
494 {
495 return Err("could not load the 2-year-valid certificate".into());
496 }
497
498 let www_example_org = CertificateAndKey {
502 certificate: String::from(include_str!("../assets/tests/certificate-2.pem")),
503 key: String::from(include_str!("../assets/tests/key.pem")),
504 ..Default::default()
505 };
506
507 let www_example_org_fingerprint = resolver.add_certificate(&AddCertificate {
508 address,
509 certificate: www_example_org,
510 expired_at: Some(
511 (SystemTime::now().duration_since(SystemTime::UNIX_EPOCH)?
512 + Duration::from_secs(2 * 365 * 24 * 3600))
513 .as_secs() as i64,
514 ),
515 ..Default::default()
516 })?;
517
518 let www_example_org = resolver
519 .domain_lookup("www.example.org".as_bytes(), true)
520 .expect("there should be a www.example.org cert");
521 assert_eq!(www_example_org.1, www_example_org_fingerprint);
522
523 let test_example_org = resolver
524 .domain_lookup("test.example.org".as_bytes(), true)
525 .expect("there should be a test.example.org cert");
526 assert_eq!(test_example_org.1, wildcard_example_org_fingerprint);
527
528 let example_org = resolver
529 .domain_lookup("example.org".as_bytes(), true)
530 .expect("there should be a example.org cert");
531 assert_eq!(example_org.1, www_example_org_fingerprint);
532
533 resolver
536 .remove_certificate(&www_example_org_fingerprint)
537 .expect("should be able to remove the 2-year certificate");
538
539 let should_be_wildcard_fingerprint = resolver
540 .domain_lookup("www.example.org".as_bytes(), true)
541 .expect("there should be a www.example.org cert");
542 assert_eq!(
543 should_be_wildcard_fingerprint.1,
544 wildcard_example_org_fingerprint
545 );
546
547 assert!(
548 resolver
549 .domain_lookup("example.org".as_bytes(), true)
550 .is_none()
551 );
552
553 Ok(())
554 }
555
556 #[test]
557 fn resolve_the_longer_lived_cert() -> Result<(), Box<dyn Error + Send + Sync>> {
558 let address = SocketAddress::new_v4(127, 0, 0, 1, 8080);
559 let mut resolver = CertificateResolver::default();
560
561 let certificate_and_key_2y = CertificateAndKey {
564 certificate: String::from(include_str!("../assets/tests/certificate-2y.pem")),
565 key: String::from(include_str!("../assets/tests/key-2y.pem")),
566 ..Default::default()
567 };
568
569 let fingerprint_2y = resolver.add_certificate(&AddCertificate {
570 address: address.clone(),
571 certificate: certificate_and_key_2y,
572 expired_at: None,
573 })?;
574
575 if resolver.get_certificate(&fingerprint_2y).is_none() {
576 return Err("could not load the 2-year-valid certificate".into());
577 }
578
579 let certificate_and_key_1y = CertificateAndKey {
582 certificate: String::from(include_str!("../assets/tests/certificate-1y.pem")),
583 key: String::from(include_str!("../assets/tests/key-1y.pem")),
584 ..Default::default()
585 };
586
587 let fingerprint_1y = resolver.add_certificate(&AddCertificate {
588 address,
589 certificate: certificate_and_key_1y,
590 ..Default::default()
591 })?;
592
593 let localhost_cert = resolver
594 .domain_lookup("localhost".as_bytes(), true)
595 .expect("there should be a localhost cert");
596
597 assert_eq!(localhost_cert.1, fingerprint_2y);
598
599 resolver
603 .remove_certificate(&fingerprint_2y)
604 .expect("should be able to remove the 2-year certificate");
605
606 let localhost_cert = resolver
607 .domain_lookup("localhost".as_bytes(), true)
608 .expect("there should be a localhost cert");
609
610 assert_eq!(localhost_cert.1, fingerprint_1y);
611
612 Ok(())
613 }
614
615 #[test]
616 fn expiration_override() -> Result<(), Box<dyn Error + Send + Sync>> {
617 let address = SocketAddress::new_v4(127, 0, 0, 1, 8080);
618 let mut resolver = CertificateResolver::default();
619
620 let certificate_and_key_1y = CertificateAndKey {
623 certificate: String::from(include_str!("../assets/tests/certificate-1y.pem")),
624 key: String::from(include_str!("../assets/tests/key-1y.pem")),
625 ..Default::default()
626 };
627
628 let fingerprint_1y_overriden = resolver.add_certificate(&AddCertificate {
629 address: address.clone(),
630 certificate: certificate_and_key_1y,
631 expired_at: Some(
632 (SystemTime::now().duration_since(SystemTime::UNIX_EPOCH)?
633 + Duration::from_secs(3 * 365 * 24 * 3600))
634 .as_secs() as i64,
635 ),
636 })?;
637
638 if resolver
639 .get_certificate(&fingerprint_1y_overriden)
640 .is_none()
641 {
642 return Err("failed to retrieve certificate".into());
643 }
644
645 let certificate_and_key_2y = CertificateAndKey {
648 certificate: String::from(include_str!("../assets/tests/certificate-2y.pem")),
649 key: String::from(include_str!("../assets/tests/key-2y.pem")),
650 ..Default::default()
651 };
652
653 let fingerprint_2y = resolver.add_certificate(&AddCertificate {
654 address,
655 certificate: certificate_and_key_2y,
656 expired_at: None,
657 })?;
658
659 let localhost_cert = resolver
660 .domain_lookup("localhost".as_bytes(), true)
661 .expect("there should be a localhost cert");
662
663 assert_eq!(localhost_cert.1, fingerprint_1y_overriden);
664
665 resolver
669 .remove_certificate(&fingerprint_1y_overriden)
670 .expect("should be able to remove the 1-year (3-year-overriden) certificate");
671
672 let localhost_cert = resolver
673 .domain_lookup("localhost".as_bytes(), true)
674 .expect("there should be a localhost cert");
675
676 assert_eq!(localhost_cert.1, fingerprint_2y);
677
678 Ok(())
679 }
680}