1#![deny(missing_docs, missing_debug_implementations, unsafe_code)]
2#![warn(unreachable_pub, unused_qualifications, unused_lifetimes)]
3#![warn(
4 clippy::must_use_candidate,
5 clippy::unwrap_in_result,
6 clippy::panic_in_result_fn
7)]
8
9pub use rustls;
37#[cfg(feature = "native-certs")]
38pub use rustls_native_certs;
39pub use rustls_pki_types;
40#[cfg(feature = "platform-verifier")]
41pub use rustls_platform_verifier;
42pub use webpki;
43#[cfg(feature = "webpki-root-certs")]
44pub use webpki_root_certs;
45
46#[cfg(feature = "futures")]
47use futures_io::{AsyncRead, AsyncWrite};
48use rustls::{
49 ClientConfig, ClientConnection, ConfigBuilder, RootCertStore, StreamOwned,
50 client::WantsClientCert,
51};
52use rustls_pki_types::{CertificateDer, PrivateKeyDer, ServerName};
53
54use std::{
55 error::Error,
56 fmt,
57 io::{self, Read, Write},
58 sync::Arc,
59};
60
/// A synchronous TLS client stream: a rustls [`ClientConnection`] that owns the
/// underlying transport `S`.
pub type TlsStream<S> = StreamOwned<ClientConnection, S>;
63
/// An asynchronous TLS client stream over the transport `S`, as provided by
/// `futures_rustls` (only available with the `futures` feature).
#[cfg(feature = "futures")]
pub type AsyncTlsStream<S> = futures_rustls::client::TlsStream<S>;
67
/// Configuration helper used to build a [`RustlsConnector`].
///
/// It accumulates the DER-encoded certificates to trust and, with the
/// `platform-verifier` feature, whether verification should be delegated to
/// the platform verifier.
#[derive(Clone, Default, Debug)]
pub struct RustlsConnectorConfig {
    /// Certificates collected so far; they are only turned into trust roots
    /// when a connector is built.
    store: Vec<CertificateDer<'static>>,
    /// When `true`, the connector is built around `rustls_platform_verifier`,
    /// with `store` handed over as extra roots.
    #[cfg(feature = "platform-verifier")]
    platform_verifier: bool,
}
75
76impl RustlsConnectorConfig {
77 #[cfg(feature = "webpki-root-certs")]
78 pub fn new_with_webpki_root_certs() -> Self {
80 Self::default().with_webpki_root_certs()
81 }
82
83 #[cfg(feature = "platform-verifier")]
84 pub fn new_with_platform_verifier() -> Self {
86 Self::default().with_platform_verifier()
87 }
88
89 #[cfg(feature = "native-certs")]
90 pub fn new_with_native_certs() -> io::Result<Self> {
96 Self::default().with_native_certs()
97 }
98
99 pub fn add_parsable_certificates(&mut self, mut der_certs: Vec<CertificateDer<'static>>) {
103 self.store.append(&mut der_certs)
104 }
105
106 pub fn with_parsable_certificates(mut self, der_certs: Vec<CertificateDer<'static>>) -> Self {
110 self.add_parsable_certificates(der_certs);
111 self
112 }
113
114 #[cfg(feature = "webpki-root-certs")]
115 pub fn with_webpki_root_certs(mut self) -> Self {
117 self.add_parsable_certificates(webpki_root_certs::TLS_SERVER_ROOT_CERTS.to_vec());
118 self
119 }
120
121 #[cfg(feature = "platform-verifier")]
122 pub fn with_platform_verifier(mut self) -> Self {
124 self.platform_verifier = true;
125 self
126 }
127
128 #[cfg(feature = "native-certs")]
129 pub fn with_native_certs(mut self) -> io::Result<Self> {
135 let certs_result = rustls_native_certs::load_native_certs();
136 for err in certs_result.errors {
137 log::warn!("Got error while loading some native certificates: {err:?}");
138 }
139 if certs_result.certs.is_empty() {
140 return Err(io::Error::other(
141 "Could not load any valid native certificates",
142 ));
143 }
144 self.add_parsable_certificates(certs_result.certs);
145 Ok(self)
146 }
147
148 fn builder(self) -> io::Result<ConfigBuilder<ClientConfig, WantsClientCert>> {
149 let builder = ClientConfig::builder();
150 #[cfg(feature = "platform-verifier")]
151 {
152 if self.platform_verifier {
153 let verifier = rustls_platform_verifier::Verifier::new_with_extra_roots(
154 self.store,
155 builder.crypto_provider().clone(),
156 )
157 .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?;
158 return Ok(builder
161 .dangerous()
162 .with_custom_certificate_verifier(Arc::new(verifier)));
163 }
164 }
165 let mut store = RootCertStore::empty();
166 let (_, ignored) = store.add_parsable_certificates(self.store);
167 if ignored > 0 {
168 log::warn!("{ignored} platform CA root certificates were ignored due to errors");
169 }
170 if store.is_empty() {
171 return Err(io::Error::other("Could not load any valid certificates"));
172 }
173 Ok(builder.with_root_certificates(store))
174 }
175
176 pub fn connector_with_no_client_auth(self) -> io::Result<RustlsConnector> {
182 Ok(self.builder()?.with_no_client_auth().into())
183 }
184
185 pub fn connector_with_single_cert(
194 self,
195 cert_chain: Vec<CertificateDer<'static>>,
196 key_der: PrivateKeyDer<'static>,
197 ) -> io::Result<RustlsConnector> {
198 Ok(self
199 .builder()?
200 .with_client_auth_cert(cert_chain, key_der)
201 .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?
202 .into())
203 }
204}
205
/// The connector: a handle around a shared rustls [`ClientConfig`].
///
/// Cloning it only clones the inner [`Arc`].
#[derive(Clone, Debug)]
pub struct RustlsConnector(Arc<ClientConfig>);
209
210impl From<ClientConfig> for RustlsConnector {
211 fn from(config: ClientConfig) -> Self {
212 Arc::new(config).into()
213 }
214}
215
216impl From<Arc<ClientConfig>> for RustlsConnector {
217 fn from(config: Arc<ClientConfig>) -> Self {
218 Self(config)
219 }
220}
221
222impl RustlsConnector {
223 #[cfg(feature = "webpki-root-certs")]
224 pub fn new_with_webpki_root_certs() -> io::Result<Self> {
230 RustlsConnectorConfig::new_with_webpki_root_certs().connector_with_no_client_auth()
231 }
232
233 #[cfg(feature = "platform-verifier")]
234 pub fn new_with_platform_verifier() -> io::Result<Self> {
240 RustlsConnectorConfig::new_with_platform_verifier().connector_with_no_client_auth()
241 }
242
243 #[cfg(feature = "native-certs")]
244 pub fn new_with_native_certs() -> io::Result<Self> {
250 RustlsConnectorConfig::new_with_native_certs()?.connector_with_no_client_auth()
251 }
252
253 #[allow(clippy::result_large_err)]
260 pub fn connect<S: Read + Write + Send + 'static>(
261 &self,
262 domain: &str,
263 stream: S,
264 ) -> Result<TlsStream<S>, HandshakeError<S>> {
265 let session = ClientConnection::new(
266 self.0.clone(),
267 server_name(domain).map_err(HandshakeError::Failure)?,
268 )
269 .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?;
270 MidHandshakeTlsStream { session, stream }.handshake()
271 }
272
273 #[cfg(feature = "futures")]
274 pub async fn connect_async<S: AsyncRead + AsyncWrite + Send + Unpin + 'static>(
280 &self,
281 domain: &str,
282 stream: S,
283 ) -> io::Result<AsyncTlsStream<S>> {
284 futures_rustls::TlsConnector::from(self.0.clone())
285 .connect(server_name(domain)?, stream)
286 .await
287 }
288}
289
290fn server_name(domain: &str) -> io::Result<ServerName<'static>> {
291 Ok(ServerName::try_from(domain)
292 .map_err(|err| {
293 io::Error::new(
294 io::ErrorKind::InvalidData,
295 format!("Invalid domain name: {err:?}"),
296 )
297 })?
298 .to_owned())
299}
300
/// A TLS stream whose handshake has not completed yet.
///
/// Call [`MidHandshakeTlsStream::handshake`] to (re)try completing it.
#[derive(Debug)]
pub struct MidHandshakeTlsStream<S: Read + Write> {
    /// The rustls client session driving the handshake.
    session: ClientConnection,
    /// The underlying transport the handshake is performed over.
    stream: S,
}
307
308impl<S: Read + Write + Send + 'static> MidHandshakeTlsStream<S> {
309 pub fn get_ref(&self) -> &S {
311 &self.stream
312 }
313
314 pub fn get_mut(&mut self) -> &mut S {
316 &mut self.stream
317 }
318
319 #[allow(clippy::result_large_err)]
326 pub fn handshake(mut self) -> Result<TlsStream<S>, HandshakeError<S>> {
327 if let Err(e) = self.session.complete_io(&mut self.stream) {
328 if e.kind() == io::ErrorKind::WouldBlock {
329 if self.session.is_handshaking() {
330 return Err(HandshakeError::WouldBlock(self));
331 }
332 } else {
333 return Err(e.into());
334 }
335 }
336 Ok(TlsStream::new(self.session, self.stream))
337 }
338}
339
340impl<S: Read + Write> fmt::Display for MidHandshakeTlsStream<S> {
341 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
342 f.write_str("MidHandshakeTlsStream")
343 }
344}
345
/// An error returned while performing the TLS handshake.
// `WouldBlock` carries the whole in-progress session by value, which makes it
// much larger than `Failure`; the lint is silenced rather than boxing it.
#[allow(clippy::large_enum_variant)]
pub enum HandshakeError<S: Read + Write + Send + 'static> {
    /// The stream would block before the handshake completed; the handshake
    /// can be resumed through [`MidHandshakeTlsStream::handshake`].
    WouldBlock(MidHandshakeTlsStream<S>),
    /// The handshake failed with the wrapped I/O error.
    Failure(io::Error),
}
355
356impl<S: Read + Write + Send + 'static> fmt::Display for HandshakeError<S> {
357 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
358 match self {
359 HandshakeError::WouldBlock(_) => f.write_str("WouldBlock hit during handshake"),
360 HandshakeError::Failure(err) => f.write_fmt(format_args!("IO error: {err}")),
361 }
362 }
363}
364
365impl<S: Read + Write + Send + 'static> fmt::Debug for HandshakeError<S> {
366 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
367 let mut d = f.debug_tuple("HandshakeError");
368 match self {
369 HandshakeError::WouldBlock(_) => d.field(&"WouldBlock"),
370 HandshakeError::Failure(err) => d.field(&err),
371 }
372 .finish()
373 }
374}
375
376impl<S: Read + Write + Send + 'static> Error for HandshakeError<S> {
377 fn source(&self) -> Option<&(dyn Error + 'static)> {
378 match self {
379 HandshakeError::Failure(err) => Some(err),
380 _ => None,
381 }
382 }
383}
384
385impl<S: Read + Send + Write + 'static> From<io::Error> for HandshakeError<S> {
386 fn from(err: io::Error) -> Self {
387 HandshakeError::Failure(err)
388 }
389}
390
#[cfg(test)]
mod tests {
    use super::*;

    /// A configuration without any certificate must refuse to build a
    /// connector.
    #[test]
    fn empty_config_fails() {
        let built = RustlsConnectorConfig::default().connector_with_no_client_auth();
        assert!(built.is_err());
    }

    /// The bundled webpki roots are enough to build a connector.
    #[test]
    #[cfg(feature = "webpki-root-certs")]
    fn webpki_root_certs_connector_builds() {
        let _connector = RustlsConnector::new_with_webpki_root_certs()
            .expect("connector should build from the webpki root certs");
    }

    /// The platform verifier is enough to build a connector.
    #[test]
    #[cfg(feature = "platform-verifier")]
    fn platform_verifier_connector_builds() {
        let _connector = RustlsConnector::new_with_platform_verifier()
            .expect("connector should build with the platform verifier");
    }

    /// `Failure` exposes the wrapped error through `Display`, `Debug` and
    /// `Error::source`.
    #[test]
    fn handshake_error_failure_display() {
        let inner = io::Error::other("test error");
        let err = HandshakeError::<std::net::TcpStream>::Failure(inner);

        let displayed = err.to_string();
        let debugged = format!("{err:?}");

        assert!(displayed.contains("test error"));
        assert!(debugged.contains("test error"));
        assert!(err.source().is_some());
    }
}