// trz_gateway_common/security_configuration/certificate/tls_server.rs

1use std::sync::Arc;
2use std::time::SystemTime;
3
4use nameth::NamedEnumValues as _;
5use nameth::nameth;
6use openssl::error::ErrorStack;
7use openssl::pkey::PKey;
8use openssl::pkey::Private;
9use openssl::x509::X509;
10use rustls::ServerConfig;
11use rustls::pki_types::CertificateDer;
12use rustls::pki_types::PrivateKeyDer;
13use rustls::server::ClientHello;
14use rustls::server::ResolvesServerCert;
15use rustls::sign::CertifiedKey;
16use tracing::Level;
17use tracing::debug;
18use tracing::info;
19use tracing::info_span;
20use tracing::warn;
21
22use super::CertificateConfig;
23use crate::certificate_info::CertificateInfo;
24use crate::certificate_info::X509CertificateInfo;
25use crate::crypto_provider::crypto_provider;
26use crate::security_configuration::certificate::display_x509_certificate;
27use crate::x509::time::asn1_to_system_time;
28
29/// Create a RusTLS [ServerConfig] from a [CertificateConfig].
30pub trait ToTlsServer: CertificateConfig + Sized {
31    fn to_tls_server(self) -> Result<Arc<ServerConfig>, ToTlsServerError<Self::Error>> {
32        to_tls_server_impl(self)
33    }
34}
35
36impl<T: CertificateConfig> ToTlsServer for T {}
37
38fn to_tls_server_impl<T: CertificateConfig>(
39    certificate_config: T,
40) -> Result<Arc<ServerConfig>, ToTlsServerError<T::Error>> {
41    let _span = info_span!("Setup TLS server certificate").entered();
42    let server_config = ServerConfig::builder().with_no_client_auth();
43    let mut server_config = if certificate_config.is_dynamic() {
44        server_config.with_cert_resolver(Arc::new(ServerCertificateResolver {
45            state: Default::default(),
46            certificate_config,
47        }))
48    } else {
49        let (certificate_chain, private_key) = build_single_cert::<T>(
50            &*certificate_config
51                .certificate()
52                .map_err(ToTlsServerError::Certificate)?,
53            &certificate_config
54                .intermediates()
55                .map_err(ToTlsServerError::Intermediates)?,
56        )?;
57        server_config
58            .with_single_cert(certificate_chain, private_key)
59            .map_err(ToTlsServerError::ServerConfig)?
60    };
61    server_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
62    Ok(Arc::new(server_config))
63}
64
65fn build_single_cert<T: CertificateConfig>(
66    certificate: &X509CertificateInfo,
67    intermediates: &[X509],
68) -> Result<(Vec<CertificateDer<'static>>, PrivateKeyDer<'static>), ToTlsServerError<T::Error>> {
69    let mut certificate_chain = vec![];
70    {
71        log_server_certiticate(certificate);
72        let certificate = certificate.certificate.to_der();
73        let certificate = certificate.map_err(ToTlsServerError::CertificateToDer)?;
74        certificate_chain.push(certificate.into());
75    }
76    for intermediate in intermediates.iter() {
77        info!(
78            "Intermediate certificate: {:?} issued by {:?}",
79            intermediate.subject_name(),
80            intermediate.issuer_name()
81        );
82        debug!(
83            "Intermediate certificate details: {}",
84            display_x509_certificate(intermediate)
85        );
86        let intermediate = intermediate.to_der();
87        let intermediate = intermediate.map_err(ToTlsServerError::IntermediateToDer)?;
88        certificate_chain.push(intermediate.into());
89    }
90
91    let private_key = certificate
92        .private_key
93        .private_key_to_der()
94        .map_err(ToTlsServerError::PrivateKeyToDer)?
95        .try_into()
96        .map_err(ToTlsServerError::ToPrivateKey)?;
97
98    Ok((certificate_chain, private_key))
99}
100
101struct ServerCertificateResolver<T> {
102    certificate_config: T,
103    state: std::sync::Mutex<Option<CertResolverState>>,
104}
105
106struct CertResolverState {
107    certified_key: Arc<CertifiedKey>,
108    certificate: Arc<X509CertificateInfo>,
109    intermediates: Arc<Vec<X509>>,
110}
111
112impl<T> std::fmt::Debug for ServerCertificateResolver<T> {
113    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
114        f.debug_struct("CertResolver").finish()
115    }
116}
117
118impl<T: CertificateConfig> ResolvesServerCert for ServerCertificateResolver<T> {
119    fn resolve(&self, client_hello: ClientHello<'_>) -> Option<Arc<CertifiedKey>> {
120        let _span = info_span!(
121            "Resolve server certificate",
122            host = client_hello.server_name()
123        )
124        .entered();
125        let mut state = self.state.lock().unwrap();
126        match self.resolve_impl(&mut state) {
127            Ok(certified_key) => Some(certified_key),
128            Err(error) => {
129                warn!("Failed to resolve server certificate: {error}");
130                if let Some(state) = &*state {
131                    info!("Reuse stale cached server certificate");
132                    Some(state.certified_key.clone())
133                } else {
134                    None
135                }
136            }
137        }
138    }
139}
140
141impl<T: CertificateConfig> ServerCertificateResolver<T> {
142    fn resolve_impl(
143        &self,
144        state: &mut Option<CertResolverState>,
145    ) -> Result<Arc<CertifiedKey>, ToTlsServerError<T::Error>> {
146        let certificate = self
147            .certificate_config
148            .certificate()
149            .map_err(ToTlsServerError::Certificate)?;
150        let intermediates = self
151            .certificate_config
152            .intermediates()
153            .map_err(ToTlsServerError::Intermediates)?;
154
155        if let Some(state) = state
156            && Arc::ptr_eq(&certificate, &state.certificate)
157            && Arc::ptr_eq(&intermediates, &state.intermediates)
158        {
159            debug!("Reuse cached server certificate");
160            return Ok(state.certified_key.clone());
161        }
162
163        log_server_certiticate(&certificate);
164        let certified_key = self.make_certified_key(&certificate, &intermediates)?;
165        *state = Some(CertResolverState {
166            certified_key: certified_key.clone(),
167            certificate,
168            intermediates,
169        });
170        return Ok(certified_key);
171    }
172
173    fn make_certified_key(
174        &self,
175        certificate: &X509CertificateInfo,
176        intermediates: &[X509],
177    ) -> Result<Arc<CertifiedKey>, ToTlsServerError<T::Error>> {
178        let (certificate_chain, private_key) = build_single_cert::<T>(certificate, intermediates)?;
179        let certified_key =
180            CertifiedKey::from_der(certificate_chain, private_key, crypto_provider())
181                .map_err(ToTlsServerError::CertifiedKey)?;
182        Ok(Arc::new(certified_key))
183    }
184}
185
186#[nameth]
187#[derive(thiserror::Error, Debug)]
188pub enum ToTlsServerError<E: std::error::Error> {
189    #[error("[{n}] {0}", n = self.name())]
190    Certificate(E),
191
192    #[error("[{n}] {0}", n = self.name())]
193    CertificateToDer(ErrorStack),
194
195    #[error("[{n}] {0}", n = self.name())]
196    Intermediates(E),
197
198    #[error("[{n}] {0}", n = self.name())]
199    IntermediateToDer(ErrorStack),
200
201    #[error("[{n}] {0}", n = self.name())]
202    PrivateKeyToDer(ErrorStack),
203
204    #[error("[{n}] {0}", n = self.name())]
205    ToPrivateKey(&'static str),
206
207    #[error("[{n}] {0}", n = self.name())]
208    ServerConfig(rustls::Error),
209
210    #[error("[{n}] {0}", n = self.name())]
211    CertifiedKey(rustls::Error),
212}
213
214fn log_server_certiticate(certificate: &CertificateInfo<X509, PKey<Private>>) {
215    if !tracing::enabled!(Level::INFO) {
216        return;
217    }
218    let now = SystemTime::now();
219    let subject = certificate.certificate.subject_name();
220    let issuer = certificate.certificate.issuer_name();
221    let not_after = certificate.certificate.not_after();
222    let expiration =
223        match asn1_to_system_time(not_after).map(|not_after| not_after.duration_since(now)) {
224            Ok(Ok(expiration)) => humantime::format_duration(expiration).to_string(),
225            Err(error) => format!("Err: {error}"),
226            Ok(Err(error)) => format!("Err: {error}"),
227        };
228    info! { "Server certificate: {subject:?} issued by {issuer:?} expires {not_after} ({expiration})" };
229    debug!("Server certificate details: {}", certificate.display());
230}