tls_api_rustls/
connector.rs1use std::convert::TryFrom;
2use std::sync::Arc;
3
4use rustls::crypto::verify_tls12_signature;
5use rustls::crypto::verify_tls13_signature;
6use rustls::crypto::WebPkiSupportedAlgorithms;
7use rustls::StreamOwned;
8
9use tls_api::async_as_sync::AsyncIoAsSyncIo;
10use tls_api::spi_connector_common;
11use tls_api::AsyncSocket;
12use tls_api::AsyncSocketBox;
13use tls_api::BoxFuture;
14use tls_api::ImplInfo;
15
16use crate::handshake::HandshakeFuture;
17use crate::RustlsStream;
18use std::future::Future;
19
20pub struct TlsConnectorBuilder {
21 pub config: rustls::ClientConfig,
22 pub verify_hostname: bool,
23 pub root_store: rustls::RootCertStore,
24}
25pub struct TlsConnector {
26 pub config: Arc<rustls::ClientConfig>,
27}
28
29impl tls_api::TlsConnectorBuilder for TlsConnectorBuilder {
30 type Connector = TlsConnector;
31
32 type Underlying = rustls::ClientConfig;
33
34 fn underlying_mut(&mut self) -> &mut rustls::ClientConfig {
35 &mut self.config
36 }
37
38 fn set_alpn_protocols(&mut self, protocols: &[&[u8]]) -> anyhow::Result<()> {
39 self.config.alpn_protocols = protocols.iter().map(|p: &&[u8]| p.to_vec()).collect();
40 Ok(())
41 }
42
43 fn set_verify_hostname(&mut self, verify: bool) -> anyhow::Result<()> {
44 if !verify {
45 #[derive(Debug)]
46 struct NoCertificateServerVerifier {
47 supported: WebPkiSupportedAlgorithms,
48 }
49
50 impl rustls::client::danger::ServerCertVerifier for NoCertificateServerVerifier {
51 fn verify_server_cert(
52 &self,
53 _end_entity: &rustls::pki_types::CertificateDer<'_>,
54 _intermediates: &[rustls::pki_types::CertificateDer<'_>],
55 _server_name: &rustls::pki_types::ServerName<'_>,
56 _ocsp_response: &[u8],
57 _now: rustls::pki_types::UnixTime,
58 ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error>
59 {
60 Ok(rustls::client::danger::ServerCertVerified::assertion())
61 }
62
63 fn verify_tls12_signature(
64 &self,
65 message: &[u8],
66 cert: &rustls::pki_types::CertificateDer<'_>,
67 dss: &rustls::DigitallySignedStruct,
68 ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error>
69 {
70 verify_tls12_signature(message, cert, dss, &self.supported)
71 }
72
73 fn verify_tls13_signature(
74 &self,
75 message: &[u8],
76 cert: &rustls::pki_types::CertificateDer<'_>,
77 dss: &rustls::DigitallySignedStruct,
78 ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error>
79 {
80 verify_tls13_signature(message, cert, dss, &self.supported)
81 }
82
83 fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
84 self.supported.supported_schemes()
85 }
86 }
87
88 let no_cert_verifier = NoCertificateServerVerifier {
89 supported: rustls::crypto::CryptoProvider::get_default()
90 .unwrap()
91 .signature_verification_algorithms,
92 };
93
94 self.config
95 .dangerous()
96 .set_certificate_verifier(Arc::new(no_cert_verifier));
97 self.verify_hostname = false;
98 } else if !self.verify_hostname {
99 return Err(crate::Error::VerifyHostnameTrue.into());
100 }
101
102 Ok(())
103 }
104
105 fn add_root_certificate(&mut self, cert: &[u8]) -> anyhow::Result<()> {
106 let cert = rustls::pki_types::CertificateDer::from(cert);
107 self.root_store.add(cert).map_err(anyhow::Error::new)?;
108 Ok(())
109 }
110
111 fn build(self) -> anyhow::Result<TlsConnector> {
112 let mut config = self.config;
113 if !self.root_store.is_empty() {
114 let mut new_config = rustls::ClientConfig::builder()
115 .with_root_certificates(self.root_store)
116 .with_no_client_auth();
117 new_config.alpn_protocols = config.alpn_protocols;
118 new_config.resumption = config.resumption;
119 new_config.max_fragment_size = config.max_fragment_size;
120 new_config.client_auth_cert_resolver = config.client_auth_cert_resolver;
121 new_config.enable_sni = config.enable_sni;
122 new_config.key_log = config.key_log;
123 new_config.enable_early_data = config.enable_early_data;
124 config = new_config;
125 }
126 Ok(TlsConnector {
127 config: Arc::new(config),
128 })
129 }
130}
131
132impl TlsConnector {
133 pub fn connect_impl<'a, S>(
134 &'a self,
135 domain: &'a str,
136 stream: S,
137 ) -> impl Future<Output = anyhow::Result<crate::TlsStream<S>>> + 'a
138 where
139 S: AsyncSocket,
140 {
141 let dns_name = rustls::pki_types::ServerName::try_from(domain);
142 let dns_name = match dns_name {
143 Ok(dns_name) => dns_name.to_owned(),
144 Err(e) => return BoxFuture::new(async { Err(anyhow::anyhow!(e)) }),
145 };
146 let conn = rustls::ClientConnection::new(self.config.clone(), dns_name);
147 let conn = match conn.map_err(anyhow::Error::new) {
148 Ok(conn) => conn,
149 Err(e) => return BoxFuture::new(async { Err(e) }),
150 };
151 let tls_stream: crate::TlsStream<S> =
152 crate::TlsStream::new(RustlsStream::Client(StreamOwned {
153 conn,
154 sock: AsyncIoAsSyncIo::new(stream),
155 }));
156
157 BoxFuture::new(HandshakeFuture::MidHandshake(tls_stream))
158 }
159}
160
161impl tls_api::TlsConnector for TlsConnector {
162 type Builder = TlsConnectorBuilder;
163
164 type Underlying = Arc<rustls::ClientConfig>;
165 type TlsStream = crate::TlsStream<AsyncSocketBox>;
166
167 fn underlying_mut(&mut self) -> &mut Self::Underlying {
168 &mut self.config
169 }
170
171 const IMPLEMENTED: bool = true;
172 const SUPPORTS_ALPN: bool = true;
173
174 fn info() -> ImplInfo {
175 crate::info()
176 }
177
178 fn builder() -> anyhow::Result<TlsConnectorBuilder> {
179 let mut roots = rustls::RootCertStore::empty();
180 roots.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
181 let config = rustls::ClientConfig::builder()
182 .with_root_certificates(roots)
183 .with_no_client_auth();
184 Ok(TlsConnectorBuilder {
185 config,
186 verify_hostname: true,
187 root_store: rustls::RootCertStore::empty(),
188 })
189 }
190
191 spi_connector_common!(crate::TlsStream<S>);
192}