trz_gateway_common/security_configuration/certificate/
tls_server.rs1use std::sync::Arc;
2use std::time::SystemTime;
3
4use nameth::NamedEnumValues as _;
5use nameth::nameth;
6use openssl::error::ErrorStack;
7use openssl::pkey::PKey;
8use openssl::pkey::Private;
9use openssl::x509::X509;
10use rustls::ServerConfig;
11use rustls::pki_types::CertificateDer;
12use rustls::pki_types::PrivateKeyDer;
13use rustls::server::ClientHello;
14use rustls::server::ResolvesServerCert;
15use rustls::sign::CertifiedKey;
16use tracing::Level;
17use tracing::debug;
18use tracing::info;
19use tracing::info_span;
20use tracing::warn;
21
22use super::CertificateConfig;
23use crate::certificate_info::CertificateInfo;
24use crate::certificate_info::X509CertificateInfo;
25use crate::crypto_provider::crypto_provider;
26use crate::security_configuration::certificate::display_x509_certificate;
27use crate::x509::time::asn1_to_system_time;
28
29pub trait ToTlsServer: CertificateConfig + Sized {
31 fn to_tls_server(self) -> Result<Arc<ServerConfig>, ToTlsServerError<Self::Error>> {
32 to_tls_server_impl(self)
33 }
34}
35
36impl<T: CertificateConfig> ToTlsServer for T {}
37
38fn to_tls_server_impl<T: CertificateConfig>(
39 certificate_config: T,
40) -> Result<Arc<ServerConfig>, ToTlsServerError<T::Error>> {
41 let _span = info_span!("Setup TLS server certificate").entered();
42 let server_config = ServerConfig::builder().with_no_client_auth();
43 let mut server_config = if certificate_config.is_dynamic() {
44 server_config.with_cert_resolver(Arc::new(ServerCertificateResolver {
45 state: Default::default(),
46 certificate_config,
47 }))
48 } else {
49 let (certificate_chain, private_key) = build_single_cert::<T>(
50 &*certificate_config
51 .certificate()
52 .map_err(ToTlsServerError::Certificate)?,
53 &certificate_config
54 .intermediates()
55 .map_err(ToTlsServerError::Intermediates)?,
56 )?;
57 server_config
58 .with_single_cert(certificate_chain, private_key)
59 .map_err(ToTlsServerError::ServerConfig)?
60 };
61 server_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
62 Ok(Arc::new(server_config))
63}
64
65fn build_single_cert<T: CertificateConfig>(
66 certificate: &X509CertificateInfo,
67 intermediates: &[X509],
68) -> Result<(Vec<CertificateDer<'static>>, PrivateKeyDer<'static>), ToTlsServerError<T::Error>> {
69 let mut certificate_chain = vec![];
70 {
71 log_server_certiticate(certificate);
72 let certificate = certificate.certificate.to_der();
73 let certificate = certificate.map_err(ToTlsServerError::CertificateToDer)?;
74 certificate_chain.push(certificate.into());
75 }
76 for intermediate in intermediates.iter() {
77 info!(
78 "Intermediate certificate: {:?} issued by {:?}",
79 intermediate.subject_name(),
80 intermediate.issuer_name()
81 );
82 debug!(
83 "Intermediate certificate details: {}",
84 display_x509_certificate(intermediate)
85 );
86 let intermediate = intermediate.to_der();
87 let intermediate = intermediate.map_err(ToTlsServerError::IntermediateToDer)?;
88 certificate_chain.push(intermediate.into());
89 }
90
91 let private_key = certificate
92 .private_key
93 .private_key_to_der()
94 .map_err(ToTlsServerError::PrivateKeyToDer)?
95 .try_into()
96 .map_err(ToTlsServerError::ToPrivateKey)?;
97
98 Ok((certificate_chain, private_key))
99}
100
101struct ServerCertificateResolver<T> {
102 certificate_config: T,
103 state: std::sync::Mutex<Option<CertResolverState>>,
104}
105
106struct CertResolverState {
107 certified_key: Arc<CertifiedKey>,
108 certificate: Arc<X509CertificateInfo>,
109 intermediates: Arc<Vec<X509>>,
110}
111
112impl<T> std::fmt::Debug for ServerCertificateResolver<T> {
113 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
114 f.debug_struct("CertResolver").finish()
115 }
116}
117
118impl<T: CertificateConfig> ResolvesServerCert for ServerCertificateResolver<T> {
119 fn resolve(&self, client_hello: ClientHello<'_>) -> Option<Arc<CertifiedKey>> {
120 let _span = info_span!(
121 "Resolve server certificate",
122 host = client_hello.server_name()
123 )
124 .entered();
125 let mut state = self.state.lock().unwrap();
126 match self.resolve_impl(&mut state) {
127 Ok(certified_key) => Some(certified_key),
128 Err(error) => {
129 warn!("Failed to resolve server certificate: {error}");
130 if let Some(state) = &*state {
131 info!("Reuse stale cached server certificate");
132 Some(state.certified_key.clone())
133 } else {
134 None
135 }
136 }
137 }
138 }
139}
140
141impl<T: CertificateConfig> ServerCertificateResolver<T> {
142 fn resolve_impl(
143 &self,
144 state: &mut Option<CertResolverState>,
145 ) -> Result<Arc<CertifiedKey>, ToTlsServerError<T::Error>> {
146 let certificate = self
147 .certificate_config
148 .certificate()
149 .map_err(ToTlsServerError::Certificate)?;
150 let intermediates = self
151 .certificate_config
152 .intermediates()
153 .map_err(ToTlsServerError::Intermediates)?;
154
155 if let Some(state) = state
156 && Arc::ptr_eq(&certificate, &state.certificate)
157 && Arc::ptr_eq(&intermediates, &state.intermediates)
158 {
159 debug!("Reuse cached server certificate");
160 return Ok(state.certified_key.clone());
161 }
162
163 log_server_certiticate(&certificate);
164 let certified_key = self.make_certified_key(&certificate, &intermediates)?;
165 *state = Some(CertResolverState {
166 certified_key: certified_key.clone(),
167 certificate,
168 intermediates,
169 });
170 return Ok(certified_key);
171 }
172
173 fn make_certified_key(
174 &self,
175 certificate: &X509CertificateInfo,
176 intermediates: &[X509],
177 ) -> Result<Arc<CertifiedKey>, ToTlsServerError<T::Error>> {
178 let (certificate_chain, private_key) = build_single_cert::<T>(certificate, intermediates)?;
179 let certified_key =
180 CertifiedKey::from_der(certificate_chain, private_key, crypto_provider())
181 .map_err(ToTlsServerError::CertifiedKey)?;
182 Ok(Arc::new(certified_key))
183 }
184}
185
186#[nameth]
187#[derive(thiserror::Error, Debug)]
188pub enum ToTlsServerError<E: std::error::Error> {
189 #[error("[{n}] {0}", n = self.name())]
190 Certificate(E),
191
192 #[error("[{n}] {0}", n = self.name())]
193 CertificateToDer(ErrorStack),
194
195 #[error("[{n}] {0}", n = self.name())]
196 Intermediates(E),
197
198 #[error("[{n}] {0}", n = self.name())]
199 IntermediateToDer(ErrorStack),
200
201 #[error("[{n}] {0}", n = self.name())]
202 PrivateKeyToDer(ErrorStack),
203
204 #[error("[{n}] {0}", n = self.name())]
205 ToPrivateKey(&'static str),
206
207 #[error("[{n}] {0}", n = self.name())]
208 ServerConfig(rustls::Error),
209
210 #[error("[{n}] {0}", n = self.name())]
211 CertifiedKey(rustls::Error),
212}
213
214fn log_server_certiticate(certificate: &CertificateInfo<X509, PKey<Private>>) {
215 if !tracing::enabled!(Level::INFO) {
216 return;
217 }
218 let now = SystemTime::now();
219 let subject = certificate.certificate.subject_name();
220 let issuer = certificate.certificate.issuer_name();
221 let not_after = certificate.certificate.not_after();
222 let expiration =
223 match asn1_to_system_time(not_after).map(|not_after| not_after.duration_since(now)) {
224 Ok(Ok(expiration)) => humantime::format_duration(expiration).to_string(),
225 Err(error) => format!("Err: {error}"),
226 Ok(Err(error)) => format!("Err: {error}"),
227 };
228 info! { "Server certificate: {subject:?} issued by {issuer:?} expires {not_after} ({expiration})" };
229 debug!("Server certificate details: {}", certificate.display());
230}