sozu_command_lib/
certificate.rs

1use std::{fmt, str::FromStr};
2
3use hex::{FromHex, FromHexError};
4use serde::de::{self, Visitor};
5use sha2::{Digest, Sha256};
6use x509_parser::{
7    certificate::X509Certificate,
8    extensions::{GeneralName, ParsedExtension},
9    oid_registry::{OID_X509_COMMON_NAME, OID_X509_EXT_SUBJECT_ALT_NAME},
10    parse_x509_certificate,
11    pem::{parse_x509_pem, Pem},
12};
13
14use crate::{
15    config::{Config, ConfigError},
16    proto::command::{CertificateAndKey, TlsVersion},
17};
18
19// -----------------------------------------------------------------------------
20// CertificateError
21
22#[derive(thiserror::Error, Debug)]
23pub enum CertificateError {
24    #[error("Could not parse PEM certificate from bytes: {0}")]
25    ParsePEMCertificate(String),
26    #[error("Could not parse X509 certificate from bytes: {0}")]
27    ParseX509Certificate(String),
28    #[error("failed to parse tls version '{0}'")]
29    InvalidTlsVersion(String),
30    #[error("failed to parse fingerprint, {0}")]
31    InvalidFingerprint(FromHexError),
32    #[error("could not load file on path {path}: {error}")]
33    LoadFile { path: String, error: ConfigError },
34    #[error("Failed at decoding the hex encoded certificate: {0}")]
35    DecodeError(FromHexError),
36}
37
38// -----------------------------------------------------------------------------
39// parse
40
41/// parse a pem file encoded as binary and convert it into the right structure
42/// (a.k.a [`Pem`])
43pub fn parse_pem(certificate: &[u8]) -> Result<Pem, CertificateError> {
44    let (_, pem) = parse_x509_pem(certificate)
45        .map_err(|err| CertificateError::ParsePEMCertificate(err.to_string()))?;
46
47    Ok(pem)
48}
49
50/// parse x509 certificate from PEM bytes
51pub fn parse_x509(pem_bytes: &[u8]) -> Result<X509Certificate, CertificateError> {
52    parse_x509_certificate(pem_bytes)
53        .map_err(|nom_e| CertificateError::ParseX509Certificate(nom_e.to_string()))
54        .map(|t| t.1)
55}
56
57// -----------------------------------------------------------------------------
58// get_cn_and_san_attributes
59
60/// Retrieve from the x509 the common name (a.k.a `CN`) and the
61/// subject alternate names (a.k.a `SAN`)
62pub fn get_cn_and_san_attributes(x509: &X509Certificate) -> Vec<String> {
63    let mut names: Vec<String> = Vec::new();
64    for name in x509.subject().iter_by_oid(&OID_X509_COMMON_NAME) {
65        names.push(
66            name.as_str()
67                .map(String::from)
68                .unwrap_or_else(|_| String::from_utf8_lossy(name.as_slice()).to_string()),
69        );
70    }
71
72    for extension in x509.extensions() {
73        if extension.oid == OID_X509_EXT_SUBJECT_ALT_NAME {
74            if let ParsedExtension::SubjectAlternativeName(san) = extension.parsed_extension() {
75                for name in &san.general_names {
76                    if let GeneralName::DNSName(name) = name {
77                        names.push(name.to_string());
78                    }
79                }
80            }
81        }
82    }
83    names.dedup();
84    names
85}
86
87// -----------------------------------------------------------------------------
88// TlsVersion
89
90impl FromStr for TlsVersion {
91    type Err = CertificateError;
92
93    fn from_str(s: &str) -> Result<Self, Self::Err> {
94        match s {
95            "SSL_V2" => Ok(TlsVersion::SslV2),
96            "SSL_V3" => Ok(TlsVersion::SslV3),
97            "TLSv1" => Ok(TlsVersion::TlsV10),
98            "TLS_V11" => Ok(TlsVersion::TlsV11),
99            "TLS_V12" => Ok(TlsVersion::TlsV12),
100            "TLS_V13" => Ok(TlsVersion::TlsV13),
101            _ => Err(CertificateError::InvalidTlsVersion(s.to_string())),
102        }
103    }
104}
105
106// -----------------------------------------------------------------------------
107// Fingerprint
108
// FIXME: make fixed size depending on hash algorithm
/// Fingerprint of a TLS certificate, stored as the raw bytes of a digest.
///
/// The helpers in this module compute it as the SHA-256 digest of the DER-encoded
/// certificate; its textual form (display, debug, serde) is hexadecimal.
#[derive(Clone, PartialEq, Eq, Hash, PartialOrd, Ord, Default)]
pub struct Fingerprint(pub Vec<u8>);
113
114impl FromStr for Fingerprint {
115    type Err = CertificateError;
116
117    fn from_str(s: &str) -> Result<Self, Self::Err> {
118        hex::decode(s)
119            .map_err(CertificateError::InvalidFingerprint)
120            .map(Fingerprint)
121    }
122}
123
124impl fmt::Debug for Fingerprint {
125    fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
126        write!(f, "CertificateFingerprint({})", hex::encode(&self.0))
127    }
128}
129
130impl fmt::Display for Fingerprint {
131    fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
132        write!(f, "{}", hex::encode(&self.0))
133    }
134}
135
136impl serde::Serialize for Fingerprint {
137    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
138    where
139        S: serde::Serializer,
140    {
141        serializer.serialize_str(&hex::encode(&self.0))
142    }
143}
144
145struct FingerprintVisitor;
146
147impl<'de> Visitor<'de> for FingerprintVisitor {
148    type Value = Fingerprint;
149
150    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
151        formatter.write_str("the certificate fingerprint must be in hexadecimal format")
152    }
153
154    fn visit_str<E>(self, value: &str) -> Result<Fingerprint, E>
155    where
156        E: de::Error,
157    {
158        FromHex::from_hex(value)
159            .map_err(|e| E::custom(format!("could not deserialize hex: {e:?}")))
160            .map(Fingerprint)
161    }
162}
163
164impl<'de> serde::Deserialize<'de> for Fingerprint {
165    fn deserialize<D>(deserializer: D) -> Result<Fingerprint, D::Error>
166    where
167        D: serde::de::Deserializer<'de>,
168    {
169        deserializer.deserialize_str(FingerprintVisitor {})
170    }
171}
172
173/// Compute fingerprint from decoded pem as binary value
174pub fn calculate_fingerprint_from_der(certificate: &[u8]) -> Vec<u8> {
175    Sha256::digest(certificate).iter().cloned().collect()
176}
177
178/// Compute fingerprint from a certificate that is encoded in pem format
179pub fn calculate_fingerprint(certificate: &[u8]) -> Result<Vec<u8>, CertificateError> {
180    let parsed_certificate = parse_pem(certificate)?;
181    let fingerprint = calculate_fingerprint_from_der(&parsed_certificate.contents);
182    Ok(fingerprint)
183}
184
/// Split a PEM bundle into its individual certificates.
///
/// Each returned entry runs up to and including an `-----END CERTIFICATE-----`
/// marker and is trimmed of surrounding whitespace. Any text after the last
/// marker is discarded; an input without a marker yields an empty vector.
pub fn split_certificate_chain(chain: String) -> Vec<String> {
    const END_MARKER: &str = "-----END CERTIFICATE-----";

    chain
        .split_inclusive(END_MARKER)
        // only pieces terminated by the marker are certificates; the last piece may
        // be leftover text without one
        .filter(|piece| piece.ends_with(END_MARKER))
        .map(|piece| piece.trim().to_string())
        .collect()
}
201
202pub fn get_fingerprint_from_certificate_path(
203    certificate_path: &str,
204) -> Result<Fingerprint, CertificateError> {
205    let bytes =
206        Config::load_file_bytes(certificate_path).map_err(|e| CertificateError::LoadFile {
207            path: certificate_path.to_string(),
208            error: e,
209        })?;
210
211    let parsed_bytes = calculate_fingerprint(&bytes)?;
212
213    Ok(Fingerprint(parsed_bytes))
214}
215
216pub fn decode_fingerprint(fingerprint: &str) -> Result<Fingerprint, CertificateError> {
217    let bytes = hex::decode(fingerprint).map_err(CertificateError::DecodeError)?;
218    Ok(Fingerprint(bytes))
219}
220
221pub fn load_full_certificate(
222    certificate_path: &str,
223    certificate_chain_path: &str,
224    key_path: &str,
225    versions: Vec<TlsVersion>,
226    names: Vec<String>,
227) -> Result<CertificateAndKey, CertificateError> {
228    let certificate =
229        Config::load_file(certificate_path).map_err(|e| CertificateError::LoadFile {
230            path: certificate_path.to_string(),
231            error: e,
232        })?;
233
234    let certificate_chain = Config::load_file(certificate_chain_path)
235        .map(split_certificate_chain)
236        .map_err(|e| CertificateError::LoadFile {
237            path: certificate_chain_path.to_string(),
238            error: e,
239        })?;
240
241    let key = Config::load_file(key_path).map_err(|e| CertificateError::LoadFile {
242        path: key_path.to_string(),
243        error: e,
244    })?;
245
246    let versions = versions.iter().map(|v| *v as i32).collect();
247
248    Ok(CertificateAndKey {
249        certificate,
250        certificate_chain,
251        key,
252        versions,
253        names,
254    })
255}
256
257impl CertificateAndKey {
258    pub fn fingerprint(&self) -> Result<Fingerprint, CertificateError> {
259        let pem = parse_pem(self.certificate.as_bytes())?;
260        let fingerprint = Fingerprint(Sha256::digest(pem.contents).iter().cloned().collect());
261        Ok(fingerprint)
262    }
263
264    pub fn get_overriding_names(&self) -> Result<Vec<String>, CertificateError> {
265        if self.names.is_empty() {
266            let pem = parse_pem(self.certificate.as_bytes())?;
267            let x509 = parse_x509(&pem.contents)?;
268
269            let overriding_names = get_cn_and_san_attributes(&x509);
270
271            Ok(overriding_names.into_iter().collect())
272        } else {
273            Ok(self.names.to_owned())
274        }
275    }
276
277    pub fn apply_overriding_names(&mut self) -> Result<(), CertificateError> {
278        self.names = self.get_overriding_names()?;
279        Ok(())
280    }
281}