tls_api_rustls/
connector.rs

1use std::convert::TryFrom;
2use std::sync::Arc;
3
4use rustls::crypto::verify_tls12_signature;
5use rustls::crypto::verify_tls13_signature;
6use rustls::crypto::WebPkiSupportedAlgorithms;
7use rustls::StreamOwned;
8
9use tls_api::async_as_sync::AsyncIoAsSyncIo;
10use tls_api::spi_connector_common;
11use tls_api::AsyncSocket;
12use tls_api::AsyncSocketBox;
13use tls_api::BoxFuture;
14use tls_api::ImplInfo;
15
16use crate::handshake::HandshakeFuture;
17use crate::RustlsStream;
18use std::future::Future;
19
20pub struct TlsConnectorBuilder {
21    pub config: rustls::ClientConfig,
22    pub verify_hostname: bool,
23    pub root_store: rustls::RootCertStore,
24}
25pub struct TlsConnector {
26    pub config: Arc<rustls::ClientConfig>,
27}
28
29impl tls_api::TlsConnectorBuilder for TlsConnectorBuilder {
30    type Connector = TlsConnector;
31
32    type Underlying = rustls::ClientConfig;
33
34    fn underlying_mut(&mut self) -> &mut rustls::ClientConfig {
35        &mut self.config
36    }
37
38    fn set_alpn_protocols(&mut self, protocols: &[&[u8]]) -> anyhow::Result<()> {
39        self.config.alpn_protocols = protocols.iter().map(|p: &&[u8]| p.to_vec()).collect();
40        Ok(())
41    }
42
43    fn set_verify_hostname(&mut self, verify: bool) -> anyhow::Result<()> {
44        if !verify {
45            #[derive(Debug)]
46            struct NoCertificateServerVerifier {
47                supported: WebPkiSupportedAlgorithms,
48            }
49
50            impl rustls::client::danger::ServerCertVerifier for NoCertificateServerVerifier {
51                fn verify_server_cert(
52                    &self,
53                    _end_entity: &rustls::pki_types::CertificateDer<'_>,
54                    _intermediates: &[rustls::pki_types::CertificateDer<'_>],
55                    _server_name: &rustls::pki_types::ServerName<'_>,
56                    _ocsp_response: &[u8],
57                    _now: rustls::pki_types::UnixTime,
58                ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error>
59                {
60                    Ok(rustls::client::danger::ServerCertVerified::assertion())
61                }
62
63                fn verify_tls12_signature(
64                    &self,
65                    message: &[u8],
66                    cert: &rustls::pki_types::CertificateDer<'_>,
67                    dss: &rustls::DigitallySignedStruct,
68                ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error>
69                {
70                    verify_tls12_signature(message, cert, dss, &self.supported)
71                }
72
73                fn verify_tls13_signature(
74                    &self,
75                    message: &[u8],
76                    cert: &rustls::pki_types::CertificateDer<'_>,
77                    dss: &rustls::DigitallySignedStruct,
78                ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error>
79                {
80                    verify_tls13_signature(message, cert, dss, &self.supported)
81                }
82
83                fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
84                    self.supported.supported_schemes()
85                }
86            }
87
88            let no_cert_verifier = NoCertificateServerVerifier {
89                supported: rustls::crypto::CryptoProvider::get_default()
90                    .unwrap()
91                    .signature_verification_algorithms,
92            };
93
94            self.config
95                .dangerous()
96                .set_certificate_verifier(Arc::new(no_cert_verifier));
97            self.verify_hostname = false;
98        } else if !self.verify_hostname {
99            return Err(crate::Error::VerifyHostnameTrue.into());
100        }
101
102        Ok(())
103    }
104
105    fn add_root_certificate(&mut self, cert: &[u8]) -> anyhow::Result<()> {
106        let cert = rustls::pki_types::CertificateDer::from(cert);
107        self.root_store.add(cert).map_err(anyhow::Error::new)?;
108        Ok(())
109    }
110
111    fn build(self) -> anyhow::Result<TlsConnector> {
112        let mut config = self.config;
113        if !self.root_store.is_empty() {
114            let mut new_config = rustls::ClientConfig::builder()
115                .with_root_certificates(self.root_store)
116                .with_no_client_auth();
117            new_config.alpn_protocols = config.alpn_protocols;
118            new_config.resumption = config.resumption;
119            new_config.max_fragment_size = config.max_fragment_size;
120            new_config.client_auth_cert_resolver = config.client_auth_cert_resolver;
121            new_config.enable_sni = config.enable_sni;
122            new_config.key_log = config.key_log;
123            new_config.enable_early_data = config.enable_early_data;
124            config = new_config;
125        }
126        Ok(TlsConnector {
127            config: Arc::new(config),
128        })
129    }
130}
131
132impl TlsConnector {
133    pub fn connect_impl<'a, S>(
134        &'a self,
135        domain: &'a str,
136        stream: S,
137    ) -> impl Future<Output = anyhow::Result<crate::TlsStream<S>>> + 'a
138    where
139        S: AsyncSocket,
140    {
141        let dns_name = rustls::pki_types::ServerName::try_from(domain);
142        let dns_name = match dns_name {
143            Ok(dns_name) => dns_name.to_owned(),
144            Err(e) => return BoxFuture::new(async { Err(anyhow::anyhow!(e)) }),
145        };
146        let conn = rustls::ClientConnection::new(self.config.clone(), dns_name);
147        let conn = match conn.map_err(anyhow::Error::new) {
148            Ok(conn) => conn,
149            Err(e) => return BoxFuture::new(async { Err(e) }),
150        };
151        let tls_stream: crate::TlsStream<S> =
152            crate::TlsStream::new(RustlsStream::Client(StreamOwned {
153                conn,
154                sock: AsyncIoAsSyncIo::new(stream),
155            }));
156
157        BoxFuture::new(HandshakeFuture::MidHandshake(tls_stream))
158    }
159}
160
161impl tls_api::TlsConnector for TlsConnector {
162    type Builder = TlsConnectorBuilder;
163
164    type Underlying = Arc<rustls::ClientConfig>;
165    type TlsStream = crate::TlsStream<AsyncSocketBox>;
166
167    fn underlying_mut(&mut self) -> &mut Self::Underlying {
168        &mut self.config
169    }
170
171    const IMPLEMENTED: bool = true;
172    const SUPPORTS_ALPN: bool = true;
173
174    fn info() -> ImplInfo {
175        crate::info()
176    }
177
178    fn builder() -> anyhow::Result<TlsConnectorBuilder> {
179        let mut roots = rustls::RootCertStore::empty();
180        roots.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
181        let config = rustls::ClientConfig::builder()
182            .with_root_certificates(roots)
183            .with_no_client_auth();
184        Ok(TlsConnectorBuilder {
185            config,
186            verify_hostname: true,
187            root_store: rustls::RootCertStore::empty(),
188        })
189    }
190
191    spi_connector_common!(crate::TlsStream<S>);
192}