trz_gateway_common/security_configuration/certificate/tls_server.rs

1use std::sync::Arc;
2use std::time::SystemTime;
3
4use nameth::NamedEnumValues as _;
5use nameth::nameth;
6use openssl::error::ErrorStack;
7use openssl::pkey::PKey;
8use openssl::pkey::Private;
9use openssl::x509::X509;
10use rustls::ServerConfig;
11use rustls::pki_types::CertificateDer;
12use rustls::pki_types::PrivateKeyDer;
13use rustls::server::ClientHello;
14use rustls::server::ResolvesServerCert;
15use rustls::sign::CertifiedKey;
16use tracing::Level;
17use tracing::debug;
18use tracing::info;
19use tracing::info_span;
20use tracing::warn;
21
22use super::CertificateConfig;
23use crate::certificate_info::CertificateInfo;
24use crate::certificate_info::X509CertificateInfo;
25use crate::crypto_provider::crypto_provider;
26use crate::security_configuration::certificate::display_x509_certificate;
27use crate::x509::time::asn1_to_system_time;
28
29/// Create a RusTLS [ServerConfig] from a [CertificateConfig].
30pub trait ToTlsServer: CertificateConfig + Sized {
31    fn to_tls_server(self) -> Result<Arc<ServerConfig>, ToTlsServerError<Self::Error>> {
32        to_tls_server_impl(self)
33    }
34}
35
36impl<T: CertificateConfig> ToTlsServer for T {}
37
38fn to_tls_server_impl<T: CertificateConfig>(
39    certificate_config: T,
40) -> Result<Arc<ServerConfig>, ToTlsServerError<T::Error>> {
41    let _span = info_span!("Setup TLS server certificate").entered();
42    let server_config = ServerConfig::builder().with_no_client_auth();
43    let mut server_config = if certificate_config.is_dynamic() {
44        server_config.with_cert_resolver(Arc::new(ServerCertificateResolver {
45            state: Default::default(),
46            certificate_config,
47        }))
48    } else {
49        let (certificate_chain, private_key) = build_single_cert::<T>(
50            &*certificate_config
51                .certificate()
52                .map_err(ToTlsServerError::Certificate)?,
53            &certificate_config
54                .intermediates()
55                .map_err(ToTlsServerError::Intermediates)?,
56        )?;
57        server_config
58            .with_single_cert(certificate_chain, private_key)
59            .map_err(ToTlsServerError::ServerConfig)?
60    };
61    server_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
62    Ok(Arc::new(server_config))
63}
64
65fn build_single_cert<T: CertificateConfig>(
66    certificate: &X509CertificateInfo,
67    intermediates: &[X509],
68) -> Result<(Vec<CertificateDer<'static>>, PrivateKeyDer<'static>), ToTlsServerError<T::Error>> {
69    let mut certificate_chain = vec![];
70    {
71        log_server_certiticate(certificate);
72        let certificate = certificate.certificate.to_der();
73        let certificate = certificate.map_err(ToTlsServerError::CertificateToDer)?;
74        certificate_chain.push(certificate.into());
75    }
76    for intermediate in intermediates.iter() {
77        info!(
78            "Intermediate certificate: {:?} issued by {:?}",
79            intermediate.subject_name(),
80            intermediate.issuer_name()
81        );
82        debug!(
83            "Intermediate certificate details: {}",
84            display_x509_certificate(intermediate)
85        );
86        let intermediate = intermediate.to_der();
87        let intermediate = intermediate.map_err(ToTlsServerError::IntermediateToDer)?;
88        certificate_chain.push(intermediate.into());
89    }
90
91    let private_key = certificate
92        .private_key
93        .private_key_to_der()
94        .map_err(ToTlsServerError::PrivateKeyToDer)?
95        .try_into()
96        .map_err(ToTlsServerError::ToPrivateKey)?;
97
98    Ok((certificate_chain, private_key))
99}
100
101struct ServerCertificateResolver<T> {
102    certificate_config: T,
103    state: std::sync::Mutex<Option<CertResolverState>>,
104}
105
106struct CertResolverState {
107    certified_key: Arc<CertifiedKey>,
108    certificate: Arc<X509CertificateInfo>,
109    intermediates: Arc<Vec<X509>>,
110}
111
112impl<T> std::fmt::Debug for ServerCertificateResolver<T> {
113    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
114        f.debug_struct("CertResolver").finish()
115    }
116}
117
118impl<T: CertificateConfig> ResolvesServerCert for ServerCertificateResolver<T> {
119    fn resolve(&self, client_hello: ClientHello<'_>) -> Option<Arc<CertifiedKey>> {
120        let _span = info_span!(
121            "Resolve server certificate",
122            host = client_hello.server_name()
123        )
124        .entered();
125        let mut state = self.state.lock().unwrap();
126        match self.resolve_impl(&mut state) {
127            Ok(certified_key) => Some(certified_key),
128            Err(error) => {
129                warn!("Failed to resolve server certificate: {error}");
130                if let Some(state) = &*state {
131                    info!("Reuse stale cached server certificate");
132                    Some(state.certified_key.clone())
133                } else {
134                    None
135                }
136            }
137        }
138    }
139}
140
141impl<T: CertificateConfig> ServerCertificateResolver<T> {
142    fn resolve_impl(
143        &self,
144        state: &mut Option<CertResolverState>,
145    ) -> Result<Arc<CertifiedKey>, ToTlsServerError<T::Error>> {
146        let certificate = self
147            .certificate_config
148            .certificate()
149            .map_err(ToTlsServerError::Certificate)?;
150        let intermediates = self
151            .certificate_config
152            .intermediates()
153            .map_err(ToTlsServerError::Intermediates)?;
154
155        if let Some(state) = state {
156            if Arc::ptr_eq(&certificate, &state.certificate)
157                && Arc::ptr_eq(&intermediates, &state.intermediates)
158            {
159                debug!("Reuse cached server certificate");
160                return Ok(state.certified_key.clone());
161            }
162        }
163
164        log_server_certiticate(&certificate);
165        let certified_key = self.make_certified_key(&certificate, &intermediates)?;
166        *state = Some(CertResolverState {
167            certified_key: certified_key.clone(),
168            certificate,
169            intermediates,
170        });
171        return Ok(certified_key);
172    }
173
174    fn make_certified_key(
175        &self,
176        certificate: &X509CertificateInfo,
177        intermediates: &[X509],
178    ) -> Result<Arc<CertifiedKey>, ToTlsServerError<T::Error>> {
179        let (certificate_chain, private_key) = build_single_cert::<T>(certificate, intermediates)?;
180        let certified_key =
181            CertifiedKey::from_der(certificate_chain, private_key, crypto_provider())
182                .map_err(ToTlsServerError::CertifiedKey)?;
183        Ok(Arc::new(certified_key))
184    }
185}
186
187#[nameth]
188#[derive(thiserror::Error, Debug)]
189pub enum ToTlsServerError<E: std::error::Error> {
190    #[error("[{n}] {0}", n = self.name())]
191    Certificate(E),
192
193    #[error("[{n}] {0}", n = self.name())]
194    CertificateToDer(ErrorStack),
195
196    #[error("[{n}] {0}", n = self.name())]
197    Intermediates(E),
198
199    #[error("[{n}] {0}", n = self.name())]
200    IntermediateToDer(ErrorStack),
201
202    #[error("[{n}] {0}", n = self.name())]
203    PrivateKeyToDer(ErrorStack),
204
205    #[error("[{n}] {0}", n = self.name())]
206    ToPrivateKey(&'static str),
207
208    #[error("[{n}] {0}", n = self.name())]
209    ServerConfig(rustls::Error),
210
211    #[error("[{n}] {0}", n = self.name())]
212    CertifiedKey(rustls::Error),
213}
214
215fn log_server_certiticate(certificate: &CertificateInfo<X509, PKey<Private>>) {
216    if !tracing::enabled!(Level::INFO) {
217        return;
218    }
219    let now = SystemTime::now();
220    let subject = certificate.certificate.subject_name();
221    let issuer = certificate.certificate.issuer_name();
222    let not_after = certificate.certificate.not_after();
223    let expiration =
224        match asn1_to_system_time(not_after).map(|not_after| not_after.duration_since(now)) {
225            Ok(Ok(expiration)) => humantime::format_duration(expiration).to_string(),
226            Err(error) => format!("Err: {error}"),
227            Ok(Err(error)) => format!("Err: {error}"),
228        };
229    info! { "Server certificate: {subject:?} issued by {issuer:?} expires {not_after} ({expiration})" };
230    debug!("Server certificate details: {}", certificate.display());
231}