trz_gateway_common/security_configuration/trusted_store/
tls_client.rs1use std::sync::Arc;
2
3use nameth::NamedEnumValues as _;
4use nameth::NamedType as _;
5use nameth::nameth;
6use rustls::ClientConfig;
7use rustls::client::WebPkiServerVerifier;
8use rustls::client::danger::HandshakeSignatureValid;
9use rustls::client::danger::ServerCertVerified;
10use rustls::pki_types::CertificateDer;
11use rustls::pki_types::ServerName;
12use rustls::pki_types::UnixTime;
13use rustls::server::VerifierBuilderError;
14use tracing::info;
15use tracing::info_span;
16
17use super::TrustedStoreConfig;
18use super::root_cert_store::ToRootCertStore;
19use super::root_cert_store::ToRootCertStoreError;
20use crate::security_configuration::custom_server_certificate_verifier::CustomServerCertificateVerifier;
21
22pub trait ToTlsClient: TrustedStoreConfig + Sized {
26 fn to_tls_client(
27 &self,
28 server_certificate_verifier: impl CustomServerCertificateVerifier + 'static,
29 ) -> Result<ClientConfig, ToTlsClientError<Self::Error>> {
30 to_tls_client_impl(self, server_certificate_verifier)
31 }
32}
33
34impl<T: TrustedStoreConfig> ToTlsClient for T {}
35
36fn to_tls_client_impl<T, V>(
37 trusted_store_config: &T,
38 server_certificate_verifier: V,
39) -> Result<ClientConfig, ToTlsClientError<T::Error>>
40where
41 T: TrustedStoreConfig,
42 V: CustomServerCertificateVerifier + 'static,
43{
44 let _span = info_span!("Setup TLS client").entered();
45 let root_store = Arc::new(trusted_store_config.to_root_cert_store()?);
46 let builder = if V::has_custom_logic() {
47 info!("Use root certificates + custom server certificate verifier");
48 ClientConfig::builder()
49 .dangerous()
50 .with_custom_certificate_verifier(Arc::new(CustomWebPkiServerVerifier {
51 custom: server_certificate_verifier,
52 chain: WebPkiServerVerifier::builder(root_store).build()?,
53 }))
54 } else {
55 info!("Use root certificates");
56 ClientConfig::builder().with_root_certificates(root_store)
57 };
58 Ok(builder.with_no_client_auth())
59}
60
61#[nameth]
62#[derive(thiserror::Error, Debug)]
63pub enum ToTlsClientError<E: std::error::Error> {
64 #[error("[{n}] {0}", n = self.name())]
65 ToRootCertStore(#[from] ToRootCertStoreError<E>),
66
67 #[error("[{n}] {0}", n = self.name())]
68 VerifierBuilderError(#[from] VerifierBuilderError),
69}
70
71#[nameth]
72struct CustomWebPkiServerVerifier<T> {
73 custom: T,
74 chain: Arc<WebPkiServerVerifier>,
75}
76
77impl<T: CustomServerCertificateVerifier> rustls::client::danger::ServerCertVerifier
78 for CustomWebPkiServerVerifier<T>
79{
80 fn verify_server_cert(
81 &self,
82 end_entity: &CertificateDer<'_>,
83 intermediates: &[CertificateDer<'_>],
84 server_name: &ServerName<'_>,
85 ocsp_response: &[u8],
86 now: UnixTime,
87 ) -> Result<ServerCertVerified, rustls::Error> {
88 let ServerCertVerified { .. } = self.custom.verify_server_certificate(
89 end_entity,
90 intermediates,
91 server_name,
92 ocsp_response,
93 now,
94 )?;
95 self.chain
96 .verify_server_cert(end_entity, intermediates, server_name, ocsp_response, now)
97 }
98
99 fn verify_tls12_signature(
100 &self,
101 message: &[u8],
102 cert: &CertificateDer<'_>,
103 dss: &rustls::DigitallySignedStruct,
104 ) -> Result<HandshakeSignatureValid, rustls::Error> {
105 self.chain.verify_tls12_signature(message, cert, dss)
106 }
107
108 fn verify_tls13_signature(
109 &self,
110 message: &[u8],
111 cert: &CertificateDer<'_>,
112 dss: &rustls::DigitallySignedStruct,
113 ) -> Result<HandshakeSignatureValid, rustls::Error> {
114 self.chain.verify_tls13_signature(message, cert, dss)
115 }
116
117 fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
118 self.chain.supported_verify_schemes()
119 }
120
121 fn requires_raw_public_keys(&self) -> bool {
122 self.chain.requires_raw_public_keys()
123 }
124
125 fn root_hint_subjects(&self) -> Option<&[rustls::DistinguishedName]> {
126 self.chain.root_hint_subjects()
127 }
128}
129
130impl<T> std::fmt::Debug for CustomWebPkiServerVerifier<T> {
131 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
132 f.debug_tuple(Self::type_name()).finish()
133 }
134}