use std::{
borrow::ToOwned,
collections::{HashMap, HashSet},
convert::From,
io::BufReader,
str::FromStr,
sync::{Arc, Mutex},
};
use once_cell::sync::Lazy;
use rustls::{
server::{ClientHello, ResolvesServerCert},
sign::CertifiedKey,
Certificate, PrivateKey,
};
use sha2::{Digest, Sha256};
use sozu_command::{
certificate::{
get_cn_and_san_attributes, parse_pem, parse_x509, CertificateError, Fingerprint,
},
proto::command::{AddCertificate, CertificateAndKey, ReplaceCertificate},
};
use crate::router::trie::{Key, KeyValue, TrieNode};
static DEFAULT_CERTIFICATE: Lazy<Option<Arc<CertifiedKey>>> = Lazy::new(|| {
let certificate_and_key = CertificateAndKey {
certificate: include_str!("../assets/certificate.pem").to_string(),
certificate_chain: vec![include_str!("../assets/certificate_chain.pem").to_string()],
key: include_str!("../assets/key.pem").to_string(),
versions: vec![],
names: vec![],
};
CertificateResolver::parse(&certificate_and_key)
.ok()
.map(|c| c.inner)
});
/// Storage of TLS certificates addressed by their fingerprint.
pub trait ResolveCertificate {
    /// Error returned by the fallible operations.
    type Error;

    /// Returns the stored certificate with this fingerprint, if any.
    fn get_certificate(&self, fingerprint: &Fingerprint) -> Option<CertifiedKeyWrapper>;

    /// Stores the certificate described by `opts` and returns its fingerprint.
    fn add_certificate(&mut self, opts: &AddCertificate) -> Result<Fingerprint, Self::Error>;

    /// Removes the certificate with the given fingerprint.
    fn remove_certificate(&mut self, opts: &Fingerprint) -> Result<(), Self::Error>;

    /// Swaps `opts.old_fingerprint` for `opts.new_certificate`.
    ///
    /// The old certificate is removed first, then the new one is added. An old
    /// fingerprint that does not parse is logged and skipped, and the new
    /// certificate is still added. A removal error aborts before the add.
    /// Because removal happens first, a failing add leaves neither certificate
    /// in place.
    fn replace_certificate(
        &mut self,
        opts: &ReplaceCertificate,
    ) -> Result<Fingerprint, Self::Error> {
        match Fingerprint::from_str(&opts.old_fingerprint) {
            Err(err) => {
                error!("failed to parse fingerprint, {}", err);
            }
            Ok(previous) => self.remove_certificate(&previous)?,
        }

        let replacement = AddCertificate {
            address: opts.address.clone(),
            certificate: opts.new_certificate.clone(),
            expired_at: opts.new_expired_at,
        };
        self.add_certificate(&replacement)
    }
}
/// Values supplied alongside an `AddCertificate` request that take precedence
/// over the ones read from the certificate itself.
#[derive(Clone, Debug)]
pub struct CertificateOverride {
    /// Names to serve the certificate for, instead of its CN/SANs.
    pub names: Option<HashSet<String>>,
    /// Expiration timestamp used instead of the certificate's `notAfter`.
    pub expiration: Option<i64>,
}

impl From<&AddCertificate> for CertificateOverride {
    /// Copies the request's names and expiration. An empty names list means
    /// "no names override".
    fn from(opts: &AddCertificate) -> Self {
        let requested_names = &opts.certificate.names;
        Self {
            names: (!requested_names.is_empty())
                .then(|| requested_names.iter().cloned().collect()),
            expiration: opts.expired_at,
        }
    }
}
/// Cheaply clonable handle (an `Arc`) on a parsed certificate chain and its
/// signing key.
#[derive(Clone)]
pub struct CertifiedKeyWrapper {
    inner: Arc<CertifiedKey>,
}

impl CertifiedKeyWrapper {
    /// Bytes of the leaf certificate, i.e. the first entry of the chain.
    ///
    /// These are the decoded contents held by a `rustls::Certificate` (DER,
    /// per rustls), not PEM text, despite the name. Panics if the chain is
    /// empty; `CertificateResolver::parse` always puts the leaf first.
    fn pem_bytes(&self) -> &[u8] {
        let leaf = &self.inner.cert[0];
        leaf.0.as_slice()
    }
}
#[derive(thiserror::Error, Debug)]
pub enum CertificateResolverError {
#[error("failed to get common name and subject alternate names from pem, {0}")]
InvalidCommonNameAndSubjectAlternateNames(CertificateError),
#[error("invalid private key: {0}")]
InvalidPrivateKey(String),
#[error("empty key")]
EmptyKeys,
#[error("certificate error: {0}")]
CertificateError(CertificateError),
}
impl From<CertificateError> for CertificateResolverError {
fn from(value: CertificateError) -> Self {
Self::CertificateError(value)
}
}
/// In-memory store of TLS certificates, indexed by fingerprint and by name.
#[derive(Default)]
pub struct CertificateResolver {
    /// Name (possibly a wildcard) -> fingerprint of the certificate served for
    /// it. `resolve` looks the SNI name up here. Holds one fingerprint per
    /// name, even when several stored certificates cover that name.
    pub domains: TrieNode<Fingerprint>,
    /// Every stored certificate, keyed by the SHA-256 of its leaf certificate bytes.
    certificates: HashMap<Fingerprint, CertifiedKeyWrapper>,
    /// Name -> fingerprints of every stored certificate indexed under that name.
    name_fingerprint_idx: HashMap<String, HashSet<Fingerprint>>,
    /// Per-fingerprint names/expiration supplied with `AddCertificate`; they take
    /// precedence over the values read from the certificate.
    overrides: HashMap<Fingerprint, CertificateOverride>,
}
impl ResolveCertificate for CertificateResolver {
    type Error = CertificateResolverError;

    /// Returns a handle on the stored certificate with this fingerprint.
    fn get_certificate(&self, fingerprint: &Fingerprint) -> Option<CertifiedKeyWrapper> {
        self.certificates.get(fingerprint).map(ToOwned::to_owned)
    }

    /// Parses and, when `should_insert` accepts it, stores a certificate.
    ///
    /// Certificates superseded by the new one (see `should_insert`) are
    /// dropped together with their overrides. The candidate's fingerprint is
    /// returned in every `Ok` case, including when it was not stored. When it
    /// is not stored, its previous override (if any) is restored.
    ///
    /// # Errors
    ///
    /// Fails when the certificate, chain or key cannot be parsed, or when the
    /// names of a certificate cannot be extracted.
    fn add_certificate(&mut self, opts: &AddCertificate) -> Result<Fingerprint, Self::Error> {
        let certificate_to_add = Self::parse(&opts.certificate)?;
        let fingerprint = fingerprint(certificate_to_add.pem_bytes());

        // `should_insert` and `certificate_names` read the candidate's
        // override, so it has to be installed before deciding. Keep the
        // previous value for rollback: a stored certificate's override must
        // keep describing the names it is indexed under. Otherwise
        // `remove_certificate` would later clean up the wrong names.
        let previous_override =
            if !opts.certificate.names.is_empty() || opts.expired_at.is_some() {
                self.overrides
                    .insert(fingerprint.to_owned(), CertificateOverride::from(opts))
            } else {
                self.overrides.remove(&fingerprint)
            };

        let decision = match self.should_insert(&fingerprint, &certificate_to_add) {
            Ok((false, _)) => Ok(None),
            Ok((true, certificates_to_remove)) => match self.get_names_override(&fingerprint) {
                Some(names) => Ok(Some((names, certificates_to_remove))),
                None => self
                    .certificate_names(certificate_to_add.pem_bytes())
                    .map(|names| Some((names, certificates_to_remove))),
            },
            Err(err) => Err(err),
        };

        let (new_names, certificates_to_remove) = match decision {
            Ok(Some(accepted)) => accepted,
            rejected_or_failed => {
                match previous_override {
                    Some(previous) => {
                        self.overrides.insert(fingerprint.to_owned(), previous);
                    }
                    None => {
                        self.overrides.remove(&fingerprint);
                    }
                }
                return rejected_or_failed.map(|_| fingerprint);
            }
        };

        // Retire superseded certificates before indexing the new one. Their
        // names are a subset of `new_names` (see `should_insert`), so every
        // trie entry removed here is re-created below, pointing at the new
        // certificate. Removing explicitly means the result does not depend
        // on whether `TrieNode::insert` replaces an existing key.
        for (old_fingerprint, old_names) in &certificates_to_remove {
            for name in old_names {
                if let Some(fingerprints) = self.name_fingerprint_idx.get_mut(name) {
                    fingerprints.remove(old_fingerprint);
                }
                self.domains.domain_remove(&name.to_owned().into_bytes());
            }
            self.certificates.remove(old_fingerprint);
            self.overrides.remove(old_fingerprint);
        }

        self.certificates
            .insert(fingerprint.to_owned(), certificate_to_add);
        for new_name in new_names {
            self.domains
                .insert(new_name.to_owned().into_bytes(), fingerprint.to_owned());
            self.name_fingerprint_idx
                .entry(new_name)
                .or_default()
                .insert(fingerprint.to_owned());
        }

        Ok(fingerprint)
    }

    /// Removes the certificate with this fingerprint; unknown fingerprints
    /// are a no-op.
    ///
    /// For each of its names: if no other stored certificate is indexed under
    /// the name, the name leaves both the trie and the name index. Otherwise
    /// the trie entry is re-pointed at one of the remaining certificates (an
    /// arbitrary one), so SNI lookups never land on a removed fingerprint.
    ///
    /// # Errors
    ///
    /// Fails when the names of the stored certificate cannot be extracted.
    fn remove_certificate(&mut self, fingerprint: &Fingerprint) -> Result<(), Self::Error> {
        let certificate_to_remove = match self.get_certificate(fingerprint) {
            Some(certificate) => certificate,
            None => return Ok(()),
        };
        let names = match self.get_names_override(fingerprint) {
            Some(names) => names,
            None => self.certificate_names(certificate_to_remove.pem_bytes())?,
        };

        for name in &names {
            let replacement = match self.name_fingerprint_idx.get_mut(name) {
                Some(fingerprints) => {
                    fingerprints.remove(fingerprint);
                    fingerprints.iter().next().cloned()
                }
                None => continue,
            };

            let key = name.to_owned().into_bytes();
            self.domains.domain_remove(&key);
            match replacement {
                Some(remaining) => {
                    self.domains.insert(key, remaining);
                }
                None => {
                    self.name_fingerprint_idx.remove(name);
                }
            }
        }

        self.certificates.remove(fingerprint);
        self.overrides.remove(fingerprint);
        Ok(())
    }
}
/// SHA-256 digest of `bytes`, used as the identifier of a certificate.
fn fingerprint(bytes: &[u8]) -> Fingerprint {
    let digest = Sha256::digest(bytes);
    Fingerprint(digest.to_vec())
}
impl CertificateResolver {
fn find_certificates_by_names(
&self,
names: &HashSet<String>,
) -> Result<HashSet<Fingerprint>, CertificateResolverError> {
let mut fingerprints = HashSet::new();
for name in names {
if let Some(fprints) = self.name_fingerprint_idx.get(name) {
fprints.iter().for_each(|fingerprint| {
fingerprints.insert(fingerprint.to_owned());
});
}
}
Ok(fingerprints)
}
fn certificate_names(
&self,
pem_bytes: &[u8],
) -> Result<HashSet<String>, CertificateResolverError> {
let fingerprint = fingerprint(pem_bytes);
if let Some(certificate_override) = self.overrides.get(&fingerprint) {
if let Some(names) = &certificate_override.names {
return Ok(names.to_owned());
}
}
get_cn_and_san_attributes(pem_bytes)
.map_err(CertificateResolverError::InvalidCommonNameAndSubjectAlternateNames)
}
fn parse(
certificate_and_key: &CertificateAndKey,
) -> Result<CertifiedKeyWrapper, CertificateResolverError> {
let certificate_pem =
sozu_command::certificate::parse_pem(certificate_and_key.certificate.as_bytes())?;
let mut chain = vec![Certificate(certificate_pem.contents)];
for cert in &certificate_and_key.certificate_chain {
let chain_link = parse_pem(cert.as_bytes())?.contents;
chain.push(Certificate(chain_link));
}
let mut key_reader = BufReader::new(certificate_and_key.key.as_bytes());
let item = match rustls_pemfile::read_one(&mut key_reader)
.map_err(|_| CertificateResolverError::EmptyKeys)?
{
Some(item) => item,
None => return Err(CertificateResolverError::EmptyKeys),
};
let private_key = match item {
rustls_pemfile::Item::RSAKey(rsa_key) => PrivateKey(rsa_key),
rustls_pemfile::Item::PKCS8Key(pkcs8_key) => PrivateKey(pkcs8_key),
rustls_pemfile::Item::ECKey(ec_key) => PrivateKey(ec_key),
_ => return Err(CertificateResolverError::EmptyKeys),
};
match rustls::sign::any_supported_type(&private_key) {
Ok(signing_key) => {
let stored_certificate = CertifiedKeyWrapper {
inner: Arc::new(CertifiedKey::new(chain, signing_key)),
};
Ok(stored_certificate)
}
Err(sign_error) => Err(CertificateResolverError::InvalidPrivateKey(
sign_error.to_string(),
)),
}
}
}
impl CertificateResolver {
    /// Decides whether `candidate_cert` should be stored, and which stored
    /// certificates it supersedes.
    ///
    /// Only stored certificates sharing at least one name with the candidate
    /// are considered. One of them is superseded when all its names are also
    /// candidate names and it expires strictly before the candidate.
    /// Expiration overrides win over the certificates' `notAfter`. The
    /// candidate is accepted when it supersedes something, or when it covers a
    /// name that none of the considered certificates cover.
    ///
    /// Returns `(accepted, superseded)`; `superseded` maps each superseded
    /// fingerprint to the names it is indexed under.
    ///
    /// # Errors
    ///
    /// Fails when a certificate cannot be parsed as X.509 or its names cannot
    /// be extracted.
    fn should_insert(
        &self,
        fingerprint: &Fingerprint,
        candidate_cert: &CertifiedKeyWrapper,
    ) -> Result<(bool, HashMap<Fingerprint, HashSet<String>>), CertificateResolverError> {
        let candidate_x509 = parse_x509(candidate_cert.pem_bytes())?;
        let candidate_names = match self.get_names_override(fingerprint) {
            Some(names) => names,
            None => self.certificate_names(candidate_cert.pem_bytes())?,
        };
        let candidate_expiration = self
            .get_expiration_override(fingerprint)
            .unwrap_or_else(|| candidate_x509.validity().not_after.timestamp());

        let mut superseded = HashMap::new();
        let mut covered_names = HashSet::new();
        for stored_fingerprint in self.find_certificates_by_names(&candidate_names)? {
            let stored_cert = match self.get_certificate(&stored_fingerprint) {
                Some(cert) => cert,
                None => continue,
            };
            let stored_x509 = parse_x509(stored_cert.pem_bytes())?;
            let stored_names = match self.get_names_override(&stored_fingerprint) {
                Some(names) => names,
                None => self.certificate_names(stored_cert.pem_bytes())?,
            };
            let stored_expiration = self
                .get_expiration_override(&stored_fingerprint)
                .unwrap_or_else(|| stored_x509.validity().not_after.timestamp());

            let within_candidate = stored_names.is_subset(&candidate_names);
            if within_candidate && stored_expiration < candidate_expiration {
                superseded.insert(stored_fingerprint, stored_names.clone());
            }
            covered_names.extend(stored_names);
        }

        let adds_new_names = !candidate_names.is_subset(&covered_names);
        Ok((!superseded.is_empty() || adds_new_names, superseded))
    }

    /// Expiration override registered for `fingerprint`, if any.
    fn get_expiration_override(&self, fingerprint: &Fingerprint) -> Option<i64> {
        self.overrides
            .get(fingerprint)
            .and_then(|certificate_override| certificate_override.expiration)
    }

    /// Names override registered for `fingerprint`, if any (cloned).
    fn get_names_override(&self, fingerprint: &Fingerprint) -> Option<HashSet<String>> {
        self.overrides.get(fingerprint)?.names.clone()
    }

    /// Looks `domain` up in the name trie, optionally matching wildcard entries.
    pub fn domain_lookup(
        &self,
        domain: &[u8],
        accept_wildcard: bool,
    ) -> Option<&KeyValue<Key, Fingerprint>> {
        let trie = &self.domains;
        trie.domain_lookup(domain, accept_wildcard)
    }
}
/// [`CertificateResolver`] behind a mutex, so that it can be handed to rustls
/// as the server certificate resolver.
#[derive(Default)]
pub struct MutexWrappedCertificateResolver(pub Mutex<CertificateResolver>);

impl ResolvesServerCert for MutexWrappedCertificateResolver {
    /// Picks the certificate for a TLS handshake from the client's SNI name.
    ///
    /// - No SNI: logs an error and returns `None`.
    /// - A stored certificate matches the name (wildcards accepted): returns it.
    /// - Otherwise — no match, a match whose fingerprint has no stored
    ///   certificate, or a lock that cannot be taken without blocking —
    ///   returns the default certificate.
    fn resolve(&self, client_hello: ClientHello) -> Option<Arc<CertifiedKey>> {
        let name: &str = match client_hello.server_name() {
            Some(name) => name,
            None => {
                error!("cannot look up certificate: no SNI from session");
                return None;
            }
        };
        let sigschemes = client_hello.signature_schemes();
        trace!(
            "trying to resolve name: {:?} for signature scheme: {:?}",
            name,
            sigschemes
        );

        // `try_lock` never blocks: if the lock is already held or poisoned,
        // the default certificate is served.
        if let Ok(resolver) = self.0.try_lock() {
            if let Some((_, fingerprint)) = resolver.domains.domain_lookup(name.as_bytes(), true) {
                trace!(
                    "looking for certificate for {:?} with fingerprint {:?}",
                    name,
                    fingerprint
                );
                let cert = resolver
                    .certificates
                    .get(fingerprint)
                    .map(|cert| cert.inner.clone());
                trace!("Found for fingerprint {}: {}", fingerprint, cert.is_some());
                if cert.is_some() {
                    return cert;
                }
                // The trie points at a fingerprint with no stored certificate:
                // fall back to the default one instead of failing the handshake.
            }
        }

        debug!("Default certificate is used for {}", name);
        incr!("tls.default_cert_used");
        DEFAULT_CERTIFICATE.clone()
    }
}
#[cfg(test)]
mod tests {
    use std::{
        collections::HashSet,
        error::Error,
        time::{Duration, SystemTime},
    };

    use super::{fingerprint, CertificateResolver, ResolveCertificate};
    use rand::{seq::SliceRandom, thread_rng};
    use sozu_command::{
        certificate::parse_pem,
        proto::command::{AddCertificate, CertificateAndKey},
    };

    /// Add a certificate, read it back, remove it, and check that nothing
    /// references it anymore.
    #[test]
    fn lifecycle() -> Result<(), Box<dyn Error + Send + Sync>> {
        let address = "127.0.0.1:8080".to_string();
        let mut resolver = CertificateResolver::default();
        let certificate_and_key = CertificateAndKey {
            certificate: String::from(include_str!("../assets/certificate.pem")),
            key: String::from(include_str!("../assets/key.pem")),
            ..Default::default()
        };
        let pem = parse_pem(certificate_and_key.certificate.as_bytes())?;
        let fingerprint = resolver.add_certificate(&AddCertificate {
            address,
            certificate: certificate_and_key,
            expired_at: None,
        })?;
        if resolver.get_certificate(&fingerprint).is_none() {
            return Err("failed to retrieve certificate".into());
        }
        if let Err(err) = resolver.remove_certificate(&fingerprint) {
            return Err(format!("failed to remove the certificate, {err}").into());
        }
        let names = resolver.certificate_names(&pem.contents)?;
        // Either a remaining index entry or a retrievable certificate is a failure.
        if !resolver.find_certificates_by_names(&names)?.is_empty()
            || resolver.get_certificate(&fingerprint).is_some()
        {
            return Err("the removed certificate is still referenced".into());
        }
        Ok(())
    }

    /// Same lifecycle, with a names override on the certificate.
    #[test]
    fn name_override() -> Result<(), Box<dyn Error + Send + Sync>> {
        let address = "127.0.0.1:8080".to_string();
        let mut resolver = CertificateResolver::default();
        let certificate_and_key = CertificateAndKey {
            certificate: String::from(include_str!("../assets/certificate.pem")),
            key: String::from(include_str!("../assets/key.pem")),
            names: vec!["localhost".into(), "lolcatho.st".into()],
            ..Default::default()
        };
        let pem = parse_pem(certificate_and_key.certificate.as_bytes())?;
        let fingerprint = resolver.add_certificate(&AddCertificate {
            address,
            certificate: certificate_and_key,
            expired_at: None,
        })?;
        if resolver.get_certificate(&fingerprint).is_none() {
            return Err("failed to retrieve certificate".into());
        }
        let mut lolcat = HashSet::new();
        lolcat.insert(String::from("lolcatho.st"));
        if resolver.find_certificates_by_names(&lolcat)?.is_empty()
            || resolver.get_certificate(&fingerprint).is_none()
        {
            return Err("failed to retrieve certificate with custom names".into());
        }
        if let Err(err) = resolver.remove_certificate(&fingerprint) {
            return Err(format!("failed to remove the certificate, {err}").into());
        }
        let names = resolver.certificate_names(&pem.contents)?;
        // Either a remaining index entry or a retrievable certificate is a failure.
        if !resolver.find_certificates_by_names(&names)?.is_empty()
            || resolver.get_certificate(&fingerprint).is_some()
        {
            return Err("the removed certificate is still referenced".into());
        }
        Ok(())
    }

    /// A certificate with the same names and a later expiration replaces the
    /// stored one.
    #[test]
    fn replacement() -> Result<(), Box<dyn Error + Send + Sync>> {
        let address = "127.0.0.1:8080".to_string();
        let mut resolver = CertificateResolver::default();
        let certificate_and_key_1y = CertificateAndKey {
            certificate: String::from(include_str!("../assets/tests/certificate-1y.pem")),
            key: String::from(include_str!("../assets/tests/key-1y.pem")),
            ..Default::default()
        };
        let pem = parse_pem(certificate_and_key_1y.certificate.as_bytes())?;
        let names_1y = resolver.certificate_names(&pem.contents)?;
        let fingerprint_1y = resolver.add_certificate(&AddCertificate {
            address: address.clone(),
            certificate: certificate_and_key_1y,
            expired_at: None,
        })?;
        if resolver.get_certificate(&fingerprint_1y).is_none() {
            return Err("failed to retrieve certificate".into());
        }
        let certificate_and_key_2y = CertificateAndKey {
            certificate: String::from(include_str!("../assets/tests/certificate-2y.pem")),
            key: String::from(include_str!("../assets/tests/key-2y.pem")),
            ..Default::default()
        };
        let fingerprint_2y = resolver.add_certificate(&AddCertificate {
            address,
            certificate: certificate_and_key_2y,
            expired_at: None,
        })?;
        if resolver.get_certificate(&fingerprint_2y).is_none() {
            return Err("failed to retrieve certificate".into());
        }
        if resolver.get_certificate(&fingerprint_1y).is_some() {
            return Err("certificate must be replaced by the 2y expiration one".into());
        }
        if resolver.get_certificate(&fingerprint_2y).is_none() {
            return Err("certificate must be added instead of the 1y expiration one".into());
        }
        let fingerprints = resolver.find_certificates_by_names(&names_1y)?;
        if fingerprints.contains(&fingerprint_1y) {
            return Err("index must not reference the 1y expiration certificate".into());
        }
        if !fingerprints.contains(&fingerprint_2y) {
            return Err("index have to reference the 2y expiration certificate".into());
        }
        Ok(())
    }

    /// An expiration override keeps a certificate from being replaced by one
    /// with a later natural expiration.
    #[test]
    fn expiration_override() -> Result<(), Box<dyn Error + Send + Sync>> {
        let address = "127.0.0.1:8080".to_string();
        let mut resolver = CertificateResolver::default();
        let certificate_and_key_1y = CertificateAndKey {
            certificate: String::from(include_str!("../assets/tests/certificate-1y.pem")),
            key: String::from(include_str!("../assets/tests/key-1y.pem")),
            ..Default::default()
        };
        let pem = parse_pem(certificate_and_key_1y.certificate.as_bytes())?;
        let names_1y = resolver.certificate_names(&pem.contents)?;
        let fingerprint_1y = resolver.add_certificate(&AddCertificate {
            address: address.clone(),
            certificate: certificate_and_key_1y,
            expired_at: Some(
                (SystemTime::now().duration_since(SystemTime::UNIX_EPOCH)?
                    + Duration::from_secs(3 * 365 * 24 * 3600))
                .as_secs() as i64,
            ),
        })?;
        if resolver.get_certificate(&fingerprint_1y).is_none() {
            return Err("failed to retrieve certificate".into());
        }
        let certificate_and_key_2y = CertificateAndKey {
            certificate: String::from(include_str!("../assets/tests/certificate-2y.pem")),
            key: String::from(include_str!("../assets/tests/key-2y.pem")),
            ..Default::default()
        };
        let fingerprint_2y = resolver.add_certificate(&AddCertificate {
            address,
            certificate: certificate_and_key_2y,
            expired_at: None,
        })?;
        if resolver.get_certificate(&fingerprint_2y).is_some() {
            return Err("certificate should not be loaded".into());
        }
        if resolver.get_certificate(&fingerprint_1y).is_none() {
            return Err("certificate must not be replaced by the 2y expiration one".into());
        }
        if resolver.get_certificate(&fingerprint_2y).is_some() {
            return Err("certificate must not be added instead of the 1y expiration one".into());
        }
        let fingerprints = resolver.find_certificates_by_names(&names_1y)?;
        if !fingerprints.contains(&fingerprint_1y) {
            return Err("index must reference the 1y expiration certificate".into());
        }
        if fingerprints.contains(&fingerprint_2y) {
            return Err("index must not reference the 2y expiration certificate".into());
        }
        Ok(())
    }

    /// The final index must not depend on the order certificates are added in.
    #[test]
    fn random() -> Result<(), Box<dyn Error + Send + Sync>> {
        let mut certificates = vec![
            CertificateAndKey {
                certificate: include_str!("../assets/tests/certificate-1.pem").to_string(),
                key: include_str!("../assets/tests/key.pem").to_string(),
                ..Default::default()
            },
            CertificateAndKey {
                certificate: include_str!("../assets/tests/certificate-2.pem").to_string(),
                key: include_str!("../assets/tests/key.pem").to_string(),
                ..Default::default()
            },
            CertificateAndKey {
                certificate: include_str!("../assets/tests/certificate-3.pem").to_string(),
                key: include_str!("../assets/tests/key.pem").to_string(),
                ..Default::default()
            },
            CertificateAndKey {
                certificate: include_str!("../assets/tests/certificate-4.pem").to_string(),
                key: include_str!("../assets/tests/key.pem").to_string(),
                ..Default::default()
            },
            CertificateAndKey {
                certificate: include_str!("../assets/tests/certificate-5.pem").to_string(),
                key: include_str!("../assets/tests/key.pem").to_string(),
                ..Default::default()
            },
            CertificateAndKey {
                certificate: include_str!("../assets/tests/certificate-6.pem").to_string(),
                key: include_str!("../assets/tests/key.pem").to_string(),
                ..Default::default()
            },
        ];
        let mut fingerprints = vec![];
        for certificate in &certificates {
            let pem = parse_pem(certificate.certificate.as_bytes())?;
            fingerprints.push(fingerprint(&pem.contents));
        }
        certificates.shuffle(&mut thread_rng());
        let address = "127.0.0.1:8080".to_string();
        let mut resolver = CertificateResolver::default();
        for certificate in &certificates {
            resolver.add_certificate(&AddCertificate {
                address: address.clone(),
                certificate: certificate.to_owned(),
                expired_at: None,
            })?;
        }
        // NOTE(review): the checks below use `&&`, so they only fail when both
        // the count and the membership are wrong. Tightening them to `||`
        // depends on the names in the test assets, which are not visible
        // here — confirm before changing.
        let mut names = HashSet::new();
        names.insert("example.org".to_string());
        let fprints = resolver.find_certificates_by_names(&names)?;
        if 1 != fprints.len() && !fprints.contains(&fingerprints[1]) {
            return Err("domain 'example.org' resolve to the wrong certificate".into());
        }
        let mut names = HashSet::new();
        names.insert("*.example.org".to_string());
        let fprints = resolver.find_certificates_by_names(&names)?;
        if 1 != fprints.len() && !fprints.contains(&fingerprints[2]) {
            return Err("domain '*.example.org' resolve to the wrong certificate".into());
        }
        let mut names = HashSet::new();
        names.insert("clever-cloud.com".to_string());
        let fprints = resolver.find_certificates_by_names(&names)?;
        if 1 != fprints.len() && !fprints.contains(&fingerprints[4]) {
            return Err("domain 'clever-cloud.com' resolve to the wrong certificate".into());
        }
        let mut names = HashSet::new();
        names.insert("*.clever-cloud.com".to_string());
        let fprints = resolver.find_certificates_by_names(&names)?;
        if 1 != fprints.len() && !fprints.contains(&fingerprints[5]) {
            return Err("domain '*.clever-cloud.com' resolve to the wrong certificate".into());
        }
        Ok(())
    }
}