#[cfg(test)]
use std::collections::HashSet;
use std::{
collections::HashMap,
fmt,
io::BufReader,
str::FromStr,
sync::{Arc, Mutex},
};
use once_cell::sync::Lazy;
use rustls::{
crypto::ring::sign::any_supported_type,
pki_types::{CertificateDer, PrivateKeyDer},
server::{ClientHello, ResolvesServerCert},
sign::CertifiedKey,
};
use sha2::{Digest, Sha256};
use sozu_command::{
certificate::{
get_cn_and_san_attributes, parse_pem, parse_x509, CertificateError, Fingerprint,
},
proto::command::{AddCertificate, CertificateAndKey, ReplaceCertificate, SocketAddress},
};
use crate::router::trie::{Key, KeyValue, TrieNode};
static DEFAULT_CERTIFICATE: Lazy<Option<Arc<CertifiedKey>>> = Lazy::new(|| {
let add = AddCertificate {
certificate: CertificateAndKey {
certificate: include_str!("../assets/certificate.pem").to_string(),
certificate_chain: vec![include_str!("../assets/certificate_chain.pem").to_string()],
key: include_str!("../assets/key.pem").to_string(),
versions: vec![],
names: vec![],
},
address: SocketAddress::new_v4(0, 0, 0, 0, 8080), expired_at: None,
};
CertifiedKeyWrapper::try_from(&add).ok().map(|c| c.inner)
});
#[derive(thiserror::Error, Debug)]
pub enum CertificateResolverError {
#[error("failed to get common name and subject alternate names from pem, {0}")]
InvalidCommonNameAndSubjectAlternateNames(CertificateError),
#[error("invalid private key: {0}")]
InvalidPrivateKey(String),
#[error("empty key")]
EmptyKeys,
#[error("error parsing x509 cert from bytes: {0}")]
ParseX509(CertificateError),
#[error("error parsing pem formated certificate from bytes: {0}")]
ParsePem(CertificateError),
#[error("error parsing overriding names in new certificate: {0}")]
ParseOverridingNames(CertificateError),
}
#[derive(Clone, Debug)]
pub struct CertifiedKeyWrapper {
inner: Arc<CertifiedKey>,
names: Vec<String>,
expiration: i64,
fingerprint: Fingerprint,
}
impl TryFrom<&AddCertificate> for CertifiedKeyWrapper {
type Error = CertificateResolverError;
fn try_from(add: &AddCertificate) -> Result<Self, Self::Error> {
let cert = add.certificate.clone();
let pem =
parse_pem(cert.certificate.as_bytes()).map_err(CertificateResolverError::ParsePem)?;
let x509 = parse_x509(&pem.contents).map_err(CertificateResolverError::ParseX509)?;
let overriding_names = if add.certificate.names.is_empty() {
get_cn_and_san_attributes(&x509)
} else {
add.certificate.names.clone()
};
let expiration = add
.expired_at
.unwrap_or(x509.validity().not_after.timestamp());
let fingerprint = Fingerprint(Sha256::digest(&pem.contents).iter().cloned().collect());
let mut chain = vec![CertificateDer::from(pem.contents)];
for cert in &cert.certificate_chain {
let chain_link = parse_pem(cert.as_bytes())
.map_err(CertificateResolverError::ParsePem)?
.contents;
chain.push(CertificateDer::from(chain_link));
}
let mut key_reader = BufReader::new(cert.key.as_bytes());
let item = match rustls_pemfile::read_one(&mut key_reader)
.map_err(|_| CertificateResolverError::EmptyKeys)?
{
Some(item) => item,
None => return Err(CertificateResolverError::EmptyKeys),
};
let private_key = match item {
rustls_pemfile::Item::Pkcs1Key(rsa_key) => PrivateKeyDer::from(rsa_key),
rustls_pemfile::Item::Pkcs8Key(pkcs8_key) => PrivateKeyDer::from(pkcs8_key),
rustls_pemfile::Item::Sec1Key(ec_key) => PrivateKeyDer::from(ec_key),
_ => return Err(CertificateResolverError::EmptyKeys),
};
match any_supported_type(&private_key) {
Ok(signing_key) => {
let stored_certificate = CertifiedKeyWrapper {
inner: Arc::new(CertifiedKey::new(chain, signing_key)),
names: overriding_names,
expiration,
fingerprint,
};
Ok(stored_certificate)
}
Err(sign_error) => Err(CertificateResolverError::InvalidPrivateKey(
sign_error.to_string(),
)),
}
}
}
#[derive(Default, Debug)]
pub struct CertificateResolver {
pub domains: TrieNode<Fingerprint>,
certificates: HashMap<Fingerprint, CertifiedKeyWrapper>,
name_fingerprint_idx: HashMap<String, Vec<(Fingerprint, i64)>>,
}
impl CertificateResolver {
pub fn get_certificate(&self, fingerprint: &Fingerprint) -> Option<CertifiedKeyWrapper> {
self.certificates.get(fingerprint).map(ToOwned::to_owned)
}
pub fn add_certificate(
&mut self,
add: &AddCertificate,
) -> Result<Fingerprint, CertificateResolverError> {
let cert_to_add = CertifiedKeyWrapper::try_from(add)?;
trace!("Certificate Resolver: adding certificate {:?}", cert_to_add);
if self.certificates.contains_key(&cert_to_add.fingerprint) {
return Ok(cert_to_add.fingerprint);
}
for new_name in &cert_to_add.names {
let fingerprints_for_this_name = self
.name_fingerprint_idx
.entry(new_name.to_owned())
.or_default();
fingerprints_for_this_name
.push((cert_to_add.fingerprint.clone(), cert_to_add.expiration));
fingerprints_for_this_name.sort_by_key(|t| t.1);
let longest_lived_cert = match fingerprints_for_this_name.last() {
Some(cert) => cert,
None => {
error!("no fingerprint for this name, this should not happen");
continue;
}
};
self.domains.remove(&new_name.to_owned().into_bytes());
self.domains.insert(
new_name.to_owned().into_bytes(),
longest_lived_cert.0.to_owned(),
);
}
self.certificates
.insert(cert_to_add.fingerprint.to_owned(), cert_to_add.clone());
trace!("{:#?}", self);
Ok(cert_to_add.fingerprint)
}
pub fn remove_certificate(
&mut self,
fingerprint: &Fingerprint,
) -> Result<(), CertificateResolverError> {
if let Some(certificate_to_remove) = self.get_certificate(fingerprint) {
for name in certificate_to_remove.names {
self.domains.domain_remove(&name.clone().into_bytes());
if let Some(fingerprints_and_exp) = self.name_fingerprint_idx.get_mut(&name) {
*fingerprints_and_exp = fingerprints_and_exp
.drain(..)
.filter(|t| &t.0 != fingerprint)
.collect();
if let Some(longest_lived_cert) = fingerprints_and_exp.last() {
self.domains
.insert(name.into_bytes(), longest_lived_cert.0.to_owned());
}
}
}
self.certificates.remove(fingerprint);
}
trace!("{:#?}", self);
Ok(())
}
pub fn replace_certificate(
&mut self,
replace: &ReplaceCertificate,
) -> Result<Fingerprint, CertificateResolverError> {
match Fingerprint::from_str(&replace.old_fingerprint) {
Ok(old_fingerprint) => self.remove_certificate(&old_fingerprint)?,
Err(err) => {
error!("failed to parse fingerprint, {}", err);
}
}
self.add_certificate(&AddCertificate {
address: replace.address.to_owned(),
certificate: replace.new_certificate.to_owned(),
expired_at: replace.new_expired_at.to_owned(),
})
}
#[cfg(test)]
fn find_certificates_by_names(
&self,
names: &HashSet<String>,
) -> Result<HashSet<Fingerprint>, CertificateResolverError> {
let mut fingerprints = HashSet::new();
for name in names {
if let Some(fprints) = self.name_fingerprint_idx.get(name) {
fprints.iter().for_each(|fingerprint| {
fingerprints.insert(fingerprint.to_owned().0);
});
}
}
Ok(fingerprints)
}
#[cfg(test)]
fn certificate_names(
&self,
fingerprint: &Fingerprint,
) -> Result<HashSet<String>, CertificateResolverError> {
if let Some(cert) = self.certificates.get(fingerprint) {
return Ok(cert.names.iter().cloned().collect());
}
Ok(HashSet::new())
}
pub fn domain_lookup(
&self,
domain: &[u8],
accept_wildcard: bool,
) -> Option<&KeyValue<Key, Fingerprint>> {
self.domains.domain_lookup(domain, accept_wildcard)
}
}
#[derive(Default)]
pub struct MutexCertificateResolver(pub Mutex<CertificateResolver>);
impl ResolvesServerCert for MutexCertificateResolver {
fn resolve(&self, client_hello: ClientHello) -> Option<Arc<CertifiedKey>> {
let server_name = client_hello.server_name();
let sigschemes = client_hello.signature_schemes();
if server_name.is_none() {
error!("cannot look up certificate: no SNI from session");
return None;
}
let name: &str = server_name.unwrap();
trace!(
"trying to resolve name: {:?} for signature scheme: {:?}",
name,
sigschemes
);
if let Ok(ref mut resolver) = self.0.try_lock() {
if let Some((_, fingerprint)) = resolver.domains.domain_lookup(name.as_bytes(), true) {
trace!(
"looking for certificate for {:?} with fingerprint {:?}",
name,
fingerprint
);
let cert = resolver
.certificates
.get(fingerprint)
.map(|cert| cert.inner.clone());
trace!("Found for fingerprint {}: {}", fingerprint, cert.is_some());
return cert;
}
}
debug!("Default certificate is used for {}", name);
incr!("tls.default_cert_used");
DEFAULT_CERTIFICATE.clone()
}
}
impl fmt::Debug for MutexCertificateResolver {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str("MutexWrappedCertificateResolver")
}
}
#[cfg(test)]
mod tests {
use std::{
collections::HashSet,
error::Error,
time::{Duration, SystemTime},
};
use super::CertificateResolver;
use sozu_command::proto::command::{AddCertificate, CertificateAndKey, SocketAddress};
#[test]
fn lifecycle() -> Result<(), Box<dyn Error + Send + Sync>> {
let address = SocketAddress::new_v4(127, 0, 0, 1, 8080);
let mut resolver = CertificateResolver::default();
let certificate_and_key = CertificateAndKey {
certificate: String::from(include_str!("../assets/certificate.pem")),
key: String::from(include_str!("../assets/key.pem")),
..Default::default()
};
let fingerprint = resolver
.add_certificate(&AddCertificate {
address,
certificate: certificate_and_key,
expired_at: None,
})
.expect("could not add certificate");
if resolver.get_certificate(&fingerprint).is_none() {
return Err("failed to retrieve certificate".into());
}
let names = resolver.certificate_names(&fingerprint)?;
if let Err(err) = resolver.remove_certificate(&fingerprint) {
return Err(format!("the certificate was not removed, {err}").into());
}
if resolver.get_certificate(&fingerprint).is_some() {
return Err("We have retrieved the certificate that should be deleted".into());
}
if !resolver.find_certificates_by_names(&names)?.is_empty() {
return Err(
"The certificate should be deleted but one of its names is in the index".into(),
);
}
Ok(())
}
#[test]
fn name_override() -> Result<(), Box<dyn Error + Send + Sync>> {
let address = SocketAddress::new_v4(127, 0, 0, 1, 8080);
let mut resolver = CertificateResolver::default();
let certificate_and_key = CertificateAndKey {
certificate: String::from(include_str!("../assets/certificate.pem")),
key: String::from(include_str!("../assets/key.pem")),
names: vec!["localhost".into(), "lolcatho.st".into()],
..Default::default()
};
let fingerprint = resolver.add_certificate(&AddCertificate {
address,
certificate: certificate_and_key,
expired_at: None,
})?;
if resolver.get_certificate(&fingerprint).is_none() {
return Err("failed to retrieve certificate".into());
}
let mut lolcat = HashSet::new();
lolcat.insert(String::from("lolcatho.st"));
if resolver.find_certificates_by_names(&lolcat)?.is_empty()
|| resolver.get_certificate(&fingerprint).is_none()
{
return Err("failed to retrieve certificate with custom names".into());
}
if let Err(err) = resolver.remove_certificate(&fingerprint) {
return Err(format!("the certificate could not be removed, {err}").into());
}
let names = resolver.certificate_names(&fingerprint)?;
if !resolver.find_certificates_by_names(&names)?.is_empty()
&& resolver.get_certificate(&fingerprint).is_some()
{
return Err("We have retrieved the certificate that should be deleted".into());
}
Ok(())
}
#[test]
fn keep_resolving_with_wildcard() -> Result<(), Box<dyn Error + Send + Sync>> {
let address = SocketAddress::new_v4(127, 0, 0, 1, 8080);
let mut resolver = CertificateResolver::default();
let wildcard_example_org = CertificateAndKey {
certificate: String::from(include_str!("../assets/tests/certificate-3.pem")),
key: String::from(include_str!("../assets/tests/key.pem")),
..Default::default()
};
let wildcard_example_org_fingerprint = resolver.add_certificate(&AddCertificate {
address: address.clone(),
certificate: wildcard_example_org,
expired_at: Some(
(SystemTime::now().duration_since(SystemTime::UNIX_EPOCH)?
+ Duration::from_secs(1 * 365 * 24 * 3600))
.as_secs() as i64,
),
})?;
if resolver
.get_certificate(&wildcard_example_org_fingerprint)
.is_none()
{
return Err("could not load the 2-year-valid certificate".into());
}
let www_example_org = CertificateAndKey {
certificate: String::from(include_str!("../assets/tests/certificate-2.pem")),
key: String::from(include_str!("../assets/tests/key.pem")),
..Default::default()
};
let www_example_org_fingerprint = resolver.add_certificate(&AddCertificate {
address,
certificate: www_example_org,
expired_at: Some(
(SystemTime::now().duration_since(SystemTime::UNIX_EPOCH)?
+ Duration::from_secs(2 * 365 * 24 * 3600))
.as_secs() as i64,
),
..Default::default()
})?;
let www_example_org = resolver
.domain_lookup("www.example.org".as_bytes(), true)
.expect("there should be a www.example.org cert");
assert_eq!(www_example_org.1, www_example_org_fingerprint);
let test_example_org = resolver
.domain_lookup("test.example.org".as_bytes(), true)
.expect("there should be a test.example.org cert");
assert_eq!(test_example_org.1, wildcard_example_org_fingerprint);
let example_org = resolver
.domain_lookup("example.org".as_bytes(), true)
.expect("there should be a example.org cert");
assert_eq!(example_org.1, www_example_org_fingerprint);
resolver
.remove_certificate(&www_example_org_fingerprint)
.expect("should be able to remove the 2-year certificate");
let should_be_wildcard_fingerprint = resolver
.domain_lookup("www.example.org".as_bytes(), true)
.expect("there should be a www.example.org cert");
assert_eq!(
should_be_wildcard_fingerprint.1,
wildcard_example_org_fingerprint
);
assert!(resolver
.domain_lookup("example.org".as_bytes(), true)
.is_none());
Ok(())
}
#[test]
fn resolve_the_longer_lived_cert() -> Result<(), Box<dyn Error + Send + Sync>> {
let address = SocketAddress::new_v4(127, 0, 0, 1, 8080);
let mut resolver = CertificateResolver::default();
let certificate_and_key_2y = CertificateAndKey {
certificate: String::from(include_str!("../assets/tests/certificate-2y.pem")),
key: String::from(include_str!("../assets/tests/key-2y.pem")),
..Default::default()
};
let fingerprint_2y = resolver.add_certificate(&AddCertificate {
address: address.clone(),
certificate: certificate_and_key_2y,
expired_at: None,
})?;
if resolver.get_certificate(&fingerprint_2y).is_none() {
return Err("could not load the 2-year-valid certificate".into());
}
let certificate_and_key_1y = CertificateAndKey {
certificate: String::from(include_str!("../assets/tests/certificate-1y.pem")),
key: String::from(include_str!("../assets/tests/key-1y.pem")),
..Default::default()
};
let fingerprint_1y = resolver.add_certificate(&AddCertificate {
address,
certificate: certificate_and_key_1y,
..Default::default()
})?;
let localhost_cert = resolver
.domain_lookup("localhost".as_bytes(), true)
.expect("there should be a localhost cert");
assert_eq!(localhost_cert.1, fingerprint_2y);
resolver
.remove_certificate(&fingerprint_2y)
.expect("should be able to remove the 2-year certificate");
let localhost_cert = resolver
.domain_lookup("localhost".as_bytes(), true)
.expect("there should be a localhost cert");
assert_eq!(localhost_cert.1, fingerprint_1y);
Ok(())
}
#[test]
fn expiration_override() -> Result<(), Box<dyn Error + Send + Sync>> {
let address = SocketAddress::new_v4(127, 0, 0, 1, 8080);
let mut resolver = CertificateResolver::default();
let certificate_and_key_1y = CertificateAndKey {
certificate: String::from(include_str!("../assets/tests/certificate-1y.pem")),
key: String::from(include_str!("../assets/tests/key-1y.pem")),
..Default::default()
};
let fingerprint_1y_overriden = resolver.add_certificate(&AddCertificate {
address: address.clone(),
certificate: certificate_and_key_1y,
expired_at: Some(
(SystemTime::now().duration_since(SystemTime::UNIX_EPOCH)?
+ Duration::from_secs(3 * 365 * 24 * 3600))
.as_secs() as i64,
),
})?;
if resolver
.get_certificate(&fingerprint_1y_overriden)
.is_none()
{
return Err("failed to retrieve certificate".into());
}
let certificate_and_key_2y = CertificateAndKey {
certificate: String::from(include_str!("../assets/tests/certificate-2y.pem")),
key: String::from(include_str!("../assets/tests/key-2y.pem")),
..Default::default()
};
let fingerprint_2y = resolver.add_certificate(&AddCertificate {
address,
certificate: certificate_and_key_2y,
expired_at: None,
})?;
let localhost_cert = resolver
.domain_lookup("localhost".as_bytes(), true)
.expect("there should be a localhost cert");
assert_eq!(localhost_cert.1, fingerprint_1y_overriden);
resolver
.remove_certificate(&fingerprint_1y_overriden)
.expect("should be able to remove the 1-year (3-year-overriden) certificate");
let localhost_cert = resolver
.domain_lookup("localhost".as_bytes(), true)
.expect("there should be a localhost cert");
assert_eq!(localhost_cert.1, fingerprint_2y);
Ok(())
}
}