1#[cfg(test)]
7use std::collections::HashSet;
8use std::{
9 collections::HashMap,
10 fmt,
11 io::BufReader,
12 str::FromStr,
13 sync::{Arc, LazyLock, Mutex},
14};
15
16use rustls::{
17 crypto::ring::sign::any_supported_type,
18 pki_types::{CertificateDer, PrivateKeyDer},
19 server::{ClientHello, ResolvesServerCert},
20 sign::CertifiedKey,
21};
22use sha2::{Digest, Sha256};
23use sozu_command::{
24 certificate::{
25 get_cn_and_san_attributes, parse_pem, parse_x509, CertificateError, Fingerprint,
26 },
27 proto::command::{AddCertificate, CertificateAndKey, ReplaceCertificate, SocketAddress},
28};
29
30use crate::router::pattern_trie::{Key, KeyValue, TrieNode};
31
32static DEFAULT_CERTIFICATE: LazyLock<Option<Arc<CertifiedKey>>> = LazyLock::new(|| {
36 let add = AddCertificate {
37 certificate: CertificateAndKey {
38 certificate: include_str!("../assets/certificate.pem").to_string(),
39 certificate_chain: vec![include_str!("../assets/certificate_chain.pem").to_string()],
40 key: include_str!("../assets/key.pem").to_string(),
41 versions: vec![],
42 names: vec![],
43 },
44 address: SocketAddress::new_v4(0, 0, 0, 0, 8080), expired_at: None,
46 };
47 CertifiedKeyWrapper::try_from(&add).ok().map(|c| c.inner)
48});
49
50#[derive(thiserror::Error, Debug)]
51pub enum CertificateResolverError {
52 #[error("failed to get common name and subject alternate names from pem, {0}")]
53 InvalidCommonNameAndSubjectAlternateNames(CertificateError),
54 #[error("invalid private key: {0}")]
55 InvalidPrivateKey(String),
56 #[error("empty key")]
57 EmptyKeys,
58 #[error("error parsing x509 cert from bytes: {0}")]
59 ParseX509(CertificateError),
60 #[error("error parsing pem formated certificate from bytes: {0}")]
61 ParsePem(CertificateError),
62 #[error("error parsing overriding names in new certificate: {0}")]
63 ParseOverridingNames(CertificateError),
64}
65
/// A parsed certificate, ready to be handed to rustls, together with the
/// metadata the resolver uses to index it.
#[derive(Clone, Debug)]
pub struct CertifiedKeyWrapper {
    /// Certificate chain (leaf first, then the provided chain) and signing key.
    inner: Arc<CertifiedKey>,
    /// Names this certificate is served for: the names given in the request
    /// or, when none are given, the CN and SAN attributes of the certificate.
    names: Vec<String>,
    /// Expiration as a Unix timestamp in seconds: the request's `expired_at`
    /// override, or else the certificate's `not_after`. Used to prefer the
    /// longest-lived certificate for each name.
    expiration: i64,
    /// SHA-256 digest of the leaf certificate's DER bytes.
    fingerprint: Fingerprint,
}
77
78impl TryFrom<&AddCertificate> for CertifiedKeyWrapper {
81 type Error = CertificateResolverError;
82
83 fn try_from(add: &AddCertificate) -> Result<Self, Self::Error> {
84 let cert = add.certificate.clone();
85
86 let pem =
87 parse_pem(cert.certificate.as_bytes()).map_err(CertificateResolverError::ParsePem)?;
88
89 let x509 = parse_x509(&pem.contents).map_err(CertificateResolverError::ParseX509)?;
90
91 let overriding_names = if add.certificate.names.is_empty() {
92 get_cn_and_san_attributes(&x509)
93 } else {
94 add.certificate.names.clone()
95 };
96
97 let expiration = add
98 .expired_at
99 .unwrap_or(x509.validity().not_after.timestamp());
100
101 let fingerprint = Fingerprint(Sha256::digest(&pem.contents).iter().cloned().collect());
102
103 let mut chain = vec![CertificateDer::from(pem.contents)];
104 for cert in &cert.certificate_chain {
105 let chain_link = parse_pem(cert.as_bytes())
106 .map_err(CertificateResolverError::ParsePem)?
107 .contents;
108
109 chain.push(CertificateDer::from(chain_link));
110 }
111
112 let mut key_reader = BufReader::new(cert.key.as_bytes());
113
114 let item = match rustls_pemfile::read_one(&mut key_reader)
115 .map_err(|_| CertificateResolverError::EmptyKeys)?
116 {
117 Some(item) => item,
118 None => return Err(CertificateResolverError::EmptyKeys),
119 };
120
121 let private_key = match item {
122 rustls_pemfile::Item::Pkcs1Key(rsa_key) => PrivateKeyDer::from(rsa_key),
123 rustls_pemfile::Item::Pkcs8Key(pkcs8_key) => PrivateKeyDer::from(pkcs8_key),
124 rustls_pemfile::Item::Sec1Key(ec_key) => PrivateKeyDer::from(ec_key),
125 _ => return Err(CertificateResolverError::EmptyKeys),
126 };
127
128 match any_supported_type(&private_key) {
129 Ok(signing_key) => {
130 let stored_certificate = CertifiedKeyWrapper {
131 inner: Arc::new(CertifiedKey::new(chain, signing_key)),
132 names: overriding_names,
133 expiration,
134 fingerprint,
135 };
136 Ok(stored_certificate)
137 }
138 Err(sign_error) => Err(CertificateResolverError::InvalidPrivateKey(
139 sign_error.to_string(),
140 )),
141 }
142 }
143}
144
/// Stores TLS certificates and tracks, for each name, which certificate to
/// serve.
#[derive(Default, Debug)]
pub struct CertificateResolver {
    /// Maps each name to the fingerprint of the longest-lived certificate
    /// known for it (looked up with wildcard support by `domain_lookup`).
    pub domains: TrieNode<Fingerprint>,
    /// Every stored certificate, keyed by fingerprint.
    certificates: HashMap<Fingerprint, CertifiedKeyWrapper>,
    /// For each name, the fingerprints of the certificates serving it with
    /// their expiration, kept sorted by ascending expiration (last = longest-lived).
    name_fingerprint_idx: HashMap<String, Vec<(Fingerprint, i64)>>,
}
162
163impl CertificateResolver {
164 pub fn get_certificate(&self, fingerprint: &Fingerprint) -> Option<CertifiedKeyWrapper> {
166 self.certificates.get(fingerprint).map(ToOwned::to_owned)
167 }
168
169 pub fn add_certificate(
172 &mut self,
173 add: &AddCertificate,
174 ) -> Result<Fingerprint, CertificateResolverError> {
175 let cert_to_add = CertifiedKeyWrapper::try_from(add)?;
176
177 trace!("Certificate Resolver: adding certificate {:?}", cert_to_add);
178
179 if self.certificates.contains_key(&cert_to_add.fingerprint) {
180 return Ok(cert_to_add.fingerprint);
181 }
182
183 for new_name in &cert_to_add.names {
184 let fingerprints_for_this_name = self
185 .name_fingerprint_idx
186 .entry(new_name.to_owned())
187 .or_default();
188
189 fingerprints_for_this_name
190 .push((cert_to_add.fingerprint.clone(), cert_to_add.expiration));
191
192 fingerprints_for_this_name.sort_by_key(|t| t.1);
194
195 let longest_lived_cert = match fingerprints_for_this_name.last() {
196 Some(cert) => cert,
197 None => {
198 error!("no fingerprint for this name, this should not happen");
199 continue;
200 }
201 };
202
203 self.domains.remove(&new_name.to_owned().into_bytes());
205 self.domains.insert(
206 new_name.to_owned().into_bytes(),
207 longest_lived_cert.0.to_owned(),
208 );
209 }
210
211 self.certificates
212 .insert(cert_to_add.fingerprint.to_owned(), cert_to_add.clone());
213
214 trace!("{:#?}", self);
215
216 Ok(cert_to_add.fingerprint)
217 }
218
219 pub fn remove_certificate(
222 &mut self,
223 fingerprint: &Fingerprint,
224 ) -> Result<(), CertificateResolverError> {
225 if let Some(certificate_to_remove) = self.get_certificate(fingerprint) {
226 for name in certificate_to_remove.names {
227 self.domains.domain_remove(&name.clone().into_bytes());
228
229 if let Some(fingerprints_and_exp) = self.name_fingerprint_idx.get_mut(&name) {
230 *fingerprints_and_exp = fingerprints_and_exp
232 .drain(..)
233 .filter(|t| &t.0 != fingerprint)
234 .collect();
235
236 if let Some(longest_lived_cert) = fingerprints_and_exp.last() {
238 self.domains
239 .insert(name.into_bytes(), longest_lived_cert.0.to_owned());
240 }
241 }
242 }
243
244 self.certificates.remove(fingerprint);
245 }
246 trace!("{:#?}", self);
247
248 Ok(())
249 }
250
251 pub fn replace_certificate(
255 &mut self,
256 replace: &ReplaceCertificate,
257 ) -> Result<Fingerprint, CertificateResolverError> {
258 match Fingerprint::from_str(&replace.old_fingerprint) {
259 Ok(old_fingerprint) => self.remove_certificate(&old_fingerprint)?,
260 Err(err) => {
261 error!("failed to parse fingerprint, {}", err);
262 }
263 }
264
265 self.add_certificate(&AddCertificate {
266 address: replace.address.to_owned(),
267 certificate: replace.new_certificate.to_owned(),
268 expired_at: replace.new_expired_at.to_owned(),
269 })
270 }
271
272 #[cfg(test)]
275 fn find_certificates_by_names(
276 &self,
277 names: &HashSet<String>,
278 ) -> Result<HashSet<Fingerprint>, CertificateResolverError> {
279 let mut fingerprints = HashSet::new();
280 for name in names {
281 if let Some(fprints) = self.name_fingerprint_idx.get(name) {
282 fprints.iter().for_each(|fingerprint| {
283 fingerprints.insert(fingerprint.to_owned().0);
284 });
285 }
286 }
287
288 Ok(fingerprints)
289 }
290
291 #[cfg(test)]
294 fn certificate_names(
295 &self,
296 fingerprint: &Fingerprint,
297 ) -> Result<HashSet<String>, CertificateResolverError> {
298 if let Some(cert) = self.certificates.get(fingerprint) {
299 return Ok(cert.names.iter().cloned().collect());
300 }
301 Ok(HashSet::new())
302 }
303
304 pub fn domain_lookup(
305 &self,
306 domain: &[u8],
307 accept_wildcard: bool,
308 ) -> Option<&KeyValue<Key, Fingerprint>> {
309 self.domains.domain_lookup(domain, accept_wildcard)
310 }
311}
312
/// A [`CertificateResolver`] behind a [`Mutex`], so that it can be shared and
/// mutated while also serving as rustls' [`ResolvesServerCert`].
#[derive(Default)]
pub struct MutexCertificateResolver(pub Mutex<CertificateResolver>);
318
319impl ResolvesServerCert for MutexCertificateResolver {
320 fn resolve(&self, client_hello: ClientHello) -> Option<Arc<CertifiedKey>> {
321 let server_name = client_hello.server_name();
322 let sigschemes = client_hello.signature_schemes();
323
324 if server_name.is_none() {
325 error!("cannot look up certificate: no SNI from session");
326 return None;
327 }
328
329 let name: &str = server_name.unwrap();
330 trace!(
331 "trying to resolve name: {:?} for signature scheme: {:?}",
332 name,
333 sigschemes
334 );
335 if let Ok(ref mut resolver) = self.0.try_lock() {
336 if let Some((_, fingerprint)) = resolver.domains.domain_lookup(name.as_bytes(), true) {
338 trace!(
339 "looking for certificate for {:?} with fingerprint {:?}",
340 name,
341 fingerprint
342 );
343
344 let cert = resolver
345 .certificates
346 .get(fingerprint)
347 .map(|cert| cert.inner.clone());
348
349 trace!("Found for fingerprint {}: {}", fingerprint, cert.is_some());
350 return cert;
351 }
352 }
353
354 debug!("Default certificate is used for {}", name);
358 incr!("tls.default_cert_used");
359 DEFAULT_CERTIFICATE.clone()
360 }
361}
362
363impl fmt::Debug for MutexCertificateResolver {
364 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
365 f.write_str("MutexWrappedCertificateResolver")
366 }
367}
368
#[cfg(test)]
mod tests {
    use std::{
        collections::HashSet,
        error::Error,
        time::{Duration, SystemTime},
    };

    use super::CertificateResolver;

    use sozu_command::proto::command::{AddCertificate, CertificateAndKey, SocketAddress};

    /// Unix timestamp, in seconds, `years` (of 365 days) from now.
    fn timestamp_in_years(years: u64) -> Result<i64, Box<dyn Error + Send + Sync>> {
        let in_the_future = SystemTime::now().duration_since(SystemTime::UNIX_EPOCH)?
            + Duration::from_secs(years * 365 * 24 * 3600);
        Ok(i64::try_from(in_the_future.as_secs())?)
    }

    #[test]
    fn lifecycle() -> Result<(), Box<dyn Error + Send + Sync>> {
        let address = SocketAddress::new_v4(127, 0, 0, 1, 8080);
        let mut resolver = CertificateResolver::default();
        let certificate_and_key = CertificateAndKey {
            certificate: String::from(include_str!("../assets/certificate.pem")),
            key: String::from(include_str!("../assets/key.pem")),
            ..Default::default()
        };

        let fingerprint = resolver
            .add_certificate(&AddCertificate {
                address,
                certificate: certificate_and_key,
                expired_at: None,
            })
            .expect("could not add certificate");

        if resolver.get_certificate(&fingerprint).is_none() {
            return Err("failed to retrieve certificate".into());
        }

        let names = resolver.certificate_names(&fingerprint)?;

        if let Err(err) = resolver.remove_certificate(&fingerprint) {
            return Err(format!("the certificate was not removed, {err}").into());
        }

        if resolver.get_certificate(&fingerprint).is_some() {
            return Err("We have retrieved the certificate that should be deleted".into());
        }

        if !resolver.find_certificates_by_names(&names)?.is_empty() {
            return Err(
                "The certificate should be deleted but one of its names is in the index".into(),
            );
        }

        Ok(())
    }

    #[test]
    fn name_override() -> Result<(), Box<dyn Error + Send + Sync>> {
        let address = SocketAddress::new_v4(127, 0, 0, 1, 8080);
        let mut resolver = CertificateResolver::default();
        let certificate_and_key = CertificateAndKey {
            certificate: String::from(include_str!("../assets/certificate.pem")),
            key: String::from(include_str!("../assets/key.pem")),
            names: vec!["localhost".into(), "lolcatho.st".into()],
            ..Default::default()
        };

        let fingerprint = resolver.add_certificate(&AddCertificate {
            address,
            certificate: certificate_and_key,
            expired_at: None,
        })?;

        if resolver.get_certificate(&fingerprint).is_none() {
            return Err("failed to retrieve certificate".into());
        }

        let mut lolcat = HashSet::new();
        lolcat.insert(String::from("lolcatho.st"));
        if !resolver
            .find_certificates_by_names(&lolcat)?
            .contains(&fingerprint)
            || resolver.get_certificate(&fingerprint).is_none()
        {
            return Err("failed to retrieve certificate with custom names".into());
        }

        // Read the names while the certificate is still stored: once it is
        // removed, `certificate_names` returns an empty set and the final
        // check would be vacuous.
        let names = resolver.certificate_names(&fingerprint)?;
        let expected: HashSet<String> = ["localhost", "lolcatho.st"]
            .into_iter()
            .map(String::from)
            .collect();
        if names != expected {
            return Err("the overriding names were not applied".into());
        }

        if let Err(err) = resolver.remove_certificate(&fingerprint) {
            return Err(format!("the certificate could not be removed, {err}").into());
        }

        if !resolver.find_certificates_by_names(&names)?.is_empty()
            || resolver.get_certificate(&fingerprint).is_some()
        {
            return Err("We have retrieved the certificate that should be deleted".into());
        }

        Ok(())
    }

    #[test]
    fn keep_resolving_with_wildcard() -> Result<(), Box<dyn Error + Send + Sync>> {
        let address = SocketAddress::new_v4(127, 0, 0, 1, 8080);
        let mut resolver = CertificateResolver::default();

        // Shorter-lived (1 year) certificate; per the lookups below it serves
        // test.example.org, and www.example.org once the other one is removed.
        let wildcard_example_org = CertificateAndKey {
            certificate: String::from(include_str!("../assets/tests/certificate-3.pem")),
            key: String::from(include_str!("../assets/tests/key.pem")),
            ..Default::default()
        };

        let wildcard_example_org_fingerprint = resolver.add_certificate(&AddCertificate {
            address: address.clone(),
            certificate: wildcard_example_org,
            expired_at: Some(timestamp_in_years(1)?),
        })?;

        if resolver
            .get_certificate(&wildcard_example_org_fingerprint)
            .is_none()
        {
            return Err("could not load the wildcard certificate".into());
        }

        // Longer-lived (2 years) certificate serving www.example.org and example.org.
        let www_example_org = CertificateAndKey {
            certificate: String::from(include_str!("../assets/tests/certificate-2.pem")),
            key: String::from(include_str!("../assets/tests/key.pem")),
            ..Default::default()
        };

        let www_example_org_fingerprint = resolver.add_certificate(&AddCertificate {
            address,
            certificate: www_example_org,
            expired_at: Some(timestamp_in_years(2)?),
        })?;

        let www_example_org = resolver
            .domain_lookup("www.example.org".as_bytes(), true)
            .expect("there should be a www.example.org cert");
        assert_eq!(www_example_org.1, www_example_org_fingerprint);

        let test_example_org = resolver
            .domain_lookup("test.example.org".as_bytes(), true)
            .expect("there should be a test.example.org cert");
        assert_eq!(test_example_org.1, wildcard_example_org_fingerprint);

        let example_org = resolver
            .domain_lookup("example.org".as_bytes(), true)
            .expect("there should be a example.org cert");
        assert_eq!(example_org.1, www_example_org_fingerprint);

        // Removing the longer-lived certificate must hand www.example.org back
        // to the wildcard one, and leave example.org unserved.
        resolver
            .remove_certificate(&www_example_org_fingerprint)
            .expect("should be able to remove the 2-year certificate");

        let should_be_wildcard_fingerprint = resolver
            .domain_lookup("www.example.org".as_bytes(), true)
            .expect("there should be a www.example.org cert");
        assert_eq!(
            should_be_wildcard_fingerprint.1,
            wildcard_example_org_fingerprint
        );

        assert!(resolver
            .domain_lookup("example.org".as_bytes(), true)
            .is_none());

        Ok(())
    }

    #[test]
    fn resolve_the_longer_lived_cert() -> Result<(), Box<dyn Error + Send + Sync>> {
        let address = SocketAddress::new_v4(127, 0, 0, 1, 8080);
        let mut resolver = CertificateResolver::default();

        let certificate_and_key_2y = CertificateAndKey {
            certificate: String::from(include_str!("../assets/tests/certificate-2y.pem")),
            key: String::from(include_str!("../assets/tests/key-2y.pem")),
            ..Default::default()
        };

        let fingerprint_2y = resolver.add_certificate(&AddCertificate {
            address: address.clone(),
            certificate: certificate_and_key_2y,
            expired_at: None,
        })?;

        if resolver.get_certificate(&fingerprint_2y).is_none() {
            return Err("could not load the 2-year-valid certificate".into());
        }

        let certificate_and_key_1y = CertificateAndKey {
            certificate: String::from(include_str!("../assets/tests/certificate-1y.pem")),
            key: String::from(include_str!("../assets/tests/key-1y.pem")),
            ..Default::default()
        };

        let fingerprint_1y = resolver.add_certificate(&AddCertificate {
            address,
            certificate: certificate_and_key_1y,
            expired_at: None,
        })?;

        let localhost_cert = resolver
            .domain_lookup("localhost".as_bytes(), true)
            .expect("there should be a localhost cert");

        assert_eq!(localhost_cert.1, fingerprint_2y);

        // Once the longer-lived certificate is gone, the 1-year one takes over.
        resolver
            .remove_certificate(&fingerprint_2y)
            .expect("should be able to remove the 2-year certificate");

        let localhost_cert = resolver
            .domain_lookup("localhost".as_bytes(), true)
            .expect("there should be a localhost cert");

        assert_eq!(localhost_cert.1, fingerprint_1y);

        Ok(())
    }

    #[test]
    fn expiration_override() -> Result<(), Box<dyn Error + Send + Sync>> {
        let address = SocketAddress::new_v4(127, 0, 0, 1, 8080);
        let mut resolver = CertificateResolver::default();

        // A 1-year certificate whose expiration is overridden to 3 years.
        let certificate_and_key_1y = CertificateAndKey {
            certificate: String::from(include_str!("../assets/tests/certificate-1y.pem")),
            key: String::from(include_str!("../assets/tests/key-1y.pem")),
            ..Default::default()
        };

        let fingerprint_1y_overriden = resolver.add_certificate(&AddCertificate {
            address: address.clone(),
            certificate: certificate_and_key_1y,
            expired_at: Some(timestamp_in_years(3)?),
        })?;

        if resolver
            .get_certificate(&fingerprint_1y_overriden)
            .is_none()
        {
            return Err("failed to retrieve certificate".into());
        }

        let certificate_and_key_2y = CertificateAndKey {
            certificate: String::from(include_str!("../assets/tests/certificate-2y.pem")),
            key: String::from(include_str!("../assets/tests/key-2y.pem")),
            ..Default::default()
        };

        let fingerprint_2y = resolver.add_certificate(&AddCertificate {
            address,
            certificate: certificate_and_key_2y,
            expired_at: None,
        })?;

        let localhost_cert = resolver
            .domain_lookup("localhost".as_bytes(), true)
            .expect("there should be a localhost cert");

        assert_eq!(localhost_cert.1, fingerprint_1y_overriden);

        resolver
            .remove_certificate(&fingerprint_1y_overriden)
            .expect("should be able to remove the 1-year (3-year-overriden) certificate");

        let localhost_cert = resolver
            .domain_lookup("localhost".as_bytes(), true)
            .expect("there should be a localhost cert");

        assert_eq!(localhost_cert.1, fingerprint_2y);

        Ok(())
    }
}