Skip to main content

rustls_connector/
lib.rs

1#![deny(missing_docs, missing_debug_implementations, unsafe_code)]
2#![warn(unreachable_pub, unused_qualifications, unused_lifetimes)]
3#![warn(
4    clippy::must_use_candidate,
5    clippy::unwrap_in_result,
6    clippy::panic_in_result_fn
7)]
8
9//! A TLS connector for rustls modelled after the `openssl` and `native-tls` APIs.
10//!
11//! Wraps [`rustls`] with a high-level [`RustlsConnector`] type that mirrors the
12//! ergonomics of `native_tls::TlsConnector`, making it straightforward to swap
13//! TLS backends in existing code.
14//!
15//! # Feature flags
16//!
17//! ## Certificate store (pick at least one)
18//!
19//! | Flag | Notes |
20//! |------|-------|
21//! | `platform-verifier` *(default)* | Platform trust store via rustls-platform-verifier |
22//! | `native-certs` | Native root certificates via rustls-native-certs |
23//! | `webpki-root-certs` | Bundled Mozilla root certificate set |
24//!
25//! ## Rustls crypto provider (at least one must be enabled)
26//!
27//! | Flag | Notes |
28//! |------|-------|
29//! | `rustls--aws_lc_rs` *(default)* | Uses aws-lc-rs |
30//! | `rustls--ring` | Uses ring (more portable) |
31//!
32//! ## Miscellaneous
33//!
34//! | Flag | Notes |
35//! |------|-------|
36//! | `futures` | Async connect via `futures-rustls` |
37//! | `logging` | Enable rustls TLS logging |
38//!
39//! # Example
40//!
41//! ```rust, no_run
42//! use rustls_connector::RustlsConnector;
43//!
44//! use std::{
45//!     io::{Read, Write},
46//!     net::TcpStream,
47//! };
48//!
49//! let connector = RustlsConnector::new_with_platform_verifier().unwrap();
50//! let stream = TcpStream::connect("google.com:443").unwrap();
51//! let mut stream = connector.connect("google.com", stream).unwrap();
52//!
53//! stream.write_all(b"GET / HTTP/1.0\r\n\r\n").unwrap();
54//! let mut res = vec![];
55//! stream.read_to_end(&mut res).unwrap();
56//! println!("{}", String::from_utf8_lossy(&res));
57//! ```
58
59/// Reexport of the [`rustls`](https://docs.rs/rustls) crate.
60pub use rustls;
61#[cfg(feature = "native-certs")]
62/// Reexport of the [`rustls_native_certs`](https://docs.rs/rustls-native-certs) crate.
63pub use rustls_native_certs;
64/// Reexport of the [`rustls_pki_types`](https://docs.rs/rustls-pki-types) crate.
65pub use rustls_pki_types;
66#[cfg(feature = "platform-verifier")]
67/// Reexport of the [`rustls_platform_verifier`](https://docs.rs/rustls-platform-verifier) crate.
68pub use rustls_platform_verifier;
69/// Reexport of the [`webpki`](https://docs.rs/webpki) crate.
70pub use webpki;
71#[cfg(feature = "webpki-root-certs")]
72/// Reexport of the [`webpki_root_certs`](https://docs.rs/webpki-root-certs) crate.
73pub use webpki_root_certs;
74
75#[cfg(feature = "futures")]
76use futures_io::{AsyncRead, AsyncWrite};
77use rustls::{
78    ClientConfig, ClientConnection, ConfigBuilder, RootCertStore, StreamOwned,
79    client::WantsClientCert,
80};
81use rustls_pki_types::{CertificateDer, PrivateKeyDer, ServerName};
82
83use std::{
84    error::Error,
85    fmt,
86    io::{self, Read, Write},
87    sync::Arc,
88};
89
90/// A rustls client TLS stream wrapping an underlying synchronous I/O stream `S`.
91pub type TlsStream<S> = StreamOwned<ClientConnection, S>;
92
93#[cfg(feature = "futures")]
94/// A rustls client TLS stream wrapping an underlying async I/O stream `S`.
95pub type AsyncTlsStream<S> = futures_rustls::client::TlsStream<S>;
96
97/// Configuration helper for [`RustlsConnector`]
98#[derive(Clone, Default, Debug)]
99pub struct RustlsConnectorConfig {
100    store: Vec<CertificateDer<'static>>,
101    #[cfg(feature = "platform-verifier")]
102    platform_verifier: bool,
103}
104
105impl RustlsConnectorConfig {
106    #[cfg(feature = "webpki-root-certs")]
107    /// Create a new [`RustlsConnectorConfig`] using the webpki-root-certs (requires webpki-root-certs feature enabled)
108    pub fn new_with_webpki_root_certs() -> Self {
109        Self::default().with_webpki_root_certs()
110    }
111
112    #[cfg(feature = "platform-verifier")]
113    /// Create a new [`RustlsConnectorConfig`] using the rustls-platform-verifier mechanism (requires platform-verifier feature enabled)
114    pub fn new_with_platform_verifier() -> Self {
115        Self::default().with_platform_verifier()
116    }
117
118    #[cfg(feature = "native-certs")]
119    /// Create a new [`RustlsConnectorConfig`] using the system certs (requires native-certs feature enabled)
120    ///
121    /// # Errors
122    ///
123    /// Returns an error if we fail to load the native certs.
124    pub fn new_with_native_certs() -> io::Result<Self> {
125        Self::default().with_native_certs()
126    }
127
128    /// Parse the given DER-encoded certificates and add all that can be parsed in a best-effort fashion.
129    ///
130    /// This is because large collections of root certificates often include ancient or syntactically invalid certificates.
131    pub fn add_parsable_certificates(&mut self, mut der_certs: Vec<CertificateDer<'static>>) {
132        self.store.append(&mut der_certs)
133    }
134
135    /// Parse the given DER-encoded certificates and add all that can be parsed in a best-effort fashion.
136    ///
137    /// This is because large collections of root certificates often include ancient or syntactically invalid certificates.
138    pub fn with_parsable_certificates(mut self, der_certs: Vec<CertificateDer<'static>>) -> Self {
139        self.add_parsable_certificates(der_certs);
140        self
141    }
142
143    #[cfg(feature = "webpki-root-certs")]
144    /// Add certs from webpki-root-certs (requires webpki-root-certs feature enabled)
145    pub fn with_webpki_root_certs(mut self) -> Self {
146        self.add_parsable_certificates(webpki_root_certs::TLS_SERVER_ROOT_CERTS.to_vec());
147        self
148    }
149
150    #[cfg(feature = "platform-verifier")]
151    /// Use the rustls-platform-verifier mechanism (requires platform-verifier feature enabled)
152    pub fn with_platform_verifier(mut self) -> Self {
153        self.platform_verifier = true;
154        self
155    }
156
157    #[cfg(feature = "native-certs")]
158    /// Add the system certs (requires native-certs feature enabled)
159    ///
160    /// # Errors
161    ///
162    /// Returns an error if we fail to load the native certs.
163    pub fn with_native_certs(mut self) -> io::Result<Self> {
164        let certs_result = rustls_native_certs::load_native_certs();
165        for err in certs_result.errors {
166            log::warn!("Got error while loading some native certificates: {err:?}");
167        }
168        if certs_result.certs.is_empty() {
169            return Err(io::Error::other(
170                "Could not load any valid native certificates",
171            ));
172        }
173        self.add_parsable_certificates(certs_result.certs);
174        Ok(self)
175    }
176
177    fn builder(self) -> io::Result<ConfigBuilder<ClientConfig, WantsClientCert>> {
178        let builder = ClientConfig::builder();
179        #[cfg(feature = "platform-verifier")]
180        {
181            if self.platform_verifier {
182                let verifier = rustls_platform_verifier::Verifier::new_with_extra_roots(
183                    self.store,
184                    builder.crypto_provider().clone(),
185                )
186                .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?;
187                // `.dangerous()` is the rustls API for supplying a custom verifier;
188                // it does not bypass verification — `Verifier` delegates to the OS store.
189                return Ok(builder
190                    .dangerous()
191                    .with_custom_certificate_verifier(Arc::new(verifier)));
192            }
193        }
194        let mut store = RootCertStore::empty();
195        let (_, ignored) = store.add_parsable_certificates(self.store);
196        if ignored > 0 {
197            log::warn!("{ignored} platform CA root certificates were ignored due to errors");
198        }
199        if store.is_empty() {
200            return Err(io::Error::other("Could not load any valid certificates"));
201        }
202        Ok(builder.with_root_certificates(store))
203    }
204
205    /// Create a new [`RustlsConnector`] from this config and no client certificate
206    ///
207    /// # Errors
208    ///
209    /// Returns an error if we fail to init our verifier
210    pub fn connector_with_no_client_auth(self) -> io::Result<RustlsConnector> {
211        Ok(self.builder()?.with_no_client_auth().into())
212    }
213
214    /// Create a new [`RustlsConnector`] from this config and the given client certificate
215    ///
216    /// cert_chain is a vector of DER-encoded certificates. key_der is a DER-encoded RSA, ECDSA, or
217    /// Ed25519 private key.
218    ///
219    /// # Errors
220    ///
221    /// Returns an error if we fail to init our verifier or if key_der is invalid.
222    pub fn connector_with_single_cert(
223        self,
224        cert_chain: Vec<CertificateDer<'static>>,
225        key_der: PrivateKeyDer<'static>,
226    ) -> io::Result<RustlsConnector> {
227        Ok(self
228            .builder()?
229            .with_client_auth_cert(cert_chain, key_der)
230            .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?
231            .into())
232    }
233}
234
235/// A rustls TLS connector ready to perform TLS handshakes.
236///
237/// Wraps an [`Arc<ClientConfig>`] and can be built from a [`RustlsConnectorConfig`] via
238/// [`connector_with_no_client_auth`](RustlsConnectorConfig::connector_with_no_client_auth) or
239/// [`connector_with_single_cert`](RustlsConnectorConfig::connector_with_single_cert), or
240/// directly from a `ClientConfig` via the [`From`] impl.
241#[derive(Clone, Debug)]
242pub struct RustlsConnector(Arc<ClientConfig>);
243
244impl From<ClientConfig> for RustlsConnector {
245    fn from(config: ClientConfig) -> Self {
246        Arc::new(config).into()
247    }
248}
249
250impl From<Arc<ClientConfig>> for RustlsConnector {
251    fn from(config: Arc<ClientConfig>) -> Self {
252        Self(config)
253    }
254}
255
256impl RustlsConnector {
257    #[cfg(feature = "webpki-root-certs")]
258    /// Create a new RustlsConnector using the webpki-root certs (requires webpki-root-certs feature enabled)
259    ///
260    /// # Errors
261    ///
262    /// Returns an error if we fail to init our verifier
263    pub fn new_with_webpki_root_certs() -> io::Result<Self> {
264        RustlsConnectorConfig::new_with_webpki_root_certs().connector_with_no_client_auth()
265    }
266
267    #[cfg(feature = "platform-verifier")]
268    /// Create a new [`RustlsConnector`] using the rustls-platform-verifier mechanism (requires platform-verifier feature enabled)
269    ///
270    /// # Errors
271    ///
272    /// Returns an error if we fail to init our verifier
273    pub fn new_with_platform_verifier() -> io::Result<Self> {
274        RustlsConnectorConfig::new_with_platform_verifier().connector_with_no_client_auth()
275    }
276
277    #[cfg(feature = "native-certs")]
278    /// Create a new [`RustlsConnector`] using the system certs (requires native-certs feature enabled)
279    ///
280    /// # Errors
281    ///
282    /// Returns an error if we fail to load the native certs.
283    pub fn new_with_native_certs() -> io::Result<Self> {
284        RustlsConnectorConfig::new_with_native_certs()?.connector_with_no_client_auth()
285    }
286
287    /// Connect to the given host
288    ///
289    /// # Errors
290    ///
291    /// Returns a [`HandshakeError`] containing either the current state of the handshake or the
292    /// failure when we couldn't complete the handshake
293    #[allow(clippy::result_large_err)]
294    pub fn connect<S: Read + Write + Send + 'static>(
295        &self,
296        domain: &str,
297        stream: S,
298    ) -> Result<TlsStream<S>, HandshakeError<S>> {
299        let session = ClientConnection::new(
300            self.0.clone(),
301            server_name(domain).map_err(HandshakeError::Failure)?,
302        )
303        .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?;
304        MidHandshakeTlsStream { session, stream }.handshake()
305    }
306
307    #[cfg(feature = "futures")]
308    /// Connect to the given host asynchronously
309    ///
310    /// # Errors
311    ///
312    /// Returns a [`io::Error`] containing the failure when we couldn't complete the TLS handshake
313    pub async fn connect_async<S: AsyncRead + AsyncWrite + Send + Unpin + 'static>(
314        &self,
315        domain: &str,
316        stream: S,
317    ) -> io::Result<AsyncTlsStream<S>> {
318        futures_rustls::TlsConnector::from(self.0.clone())
319            .connect(server_name(domain)?, stream)
320            .await
321    }
322}
323
324fn server_name(domain: &str) -> io::Result<ServerName<'static>> {
325    Ok(ServerName::try_from(domain)
326        .map_err(|err| {
327            io::Error::new(
328                io::ErrorKind::InvalidData,
329                format!("Invalid domain name: {err:?}"),
330            )
331        })?
332        .to_owned())
333}
334
335/// A TLS stream which has been interrupted during the handshake
336#[derive(Debug)]
337pub struct MidHandshakeTlsStream<S: Read + Write> {
338    session: ClientConnection,
339    stream: S,
340}
341
342impl<S: Read + Write + Send + 'static> MidHandshakeTlsStream<S> {
343    /// Get a reference to the inner stream
344    pub fn get_ref(&self) -> &S {
345        &self.stream
346    }
347
348    /// Get a mutable reference to the inner stream
349    pub fn get_mut(&mut self) -> &mut S {
350        &mut self.stream
351    }
352
353    /// Retry the handshake
354    ///
355    /// # Errors
356    ///
357    /// Returns a [`HandshakeError`] containing either the current state of the handshake or the
358    /// failure when we couldn't complete the handshake
359    #[allow(clippy::result_large_err)]
360    pub fn handshake(mut self) -> Result<TlsStream<S>, HandshakeError<S>> {
361        if let Err(e) = self.session.complete_io(&mut self.stream) {
362            if e.kind() == io::ErrorKind::WouldBlock {
363                if self.session.is_handshaking() {
364                    return Err(HandshakeError::WouldBlock(self));
365                }
366            } else {
367                return Err(e.into());
368            }
369        }
370        Ok(TlsStream::new(self.session, self.stream))
371    }
372}
373
374impl<S: Read + Write> fmt::Display for MidHandshakeTlsStream<S> {
375    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
376        f.write_str("MidHandshakeTlsStream")
377    }
378}
379
380/// An error returned while performing the handshake
381#[allow(clippy::large_enum_variant)]
382pub enum HandshakeError<S: Read + Write + Send + 'static> {
383    /// We hit WouldBlock during handshake.
384    /// Note that this is not a critical failure, you should be able to call handshake again once the stream is ready to perform I/O.
385    WouldBlock(MidHandshakeTlsStream<S>),
386    /// We hit a critical failure.
387    Failure(io::Error),
388}
389
390impl<S: Read + Write + Send + 'static> fmt::Display for HandshakeError<S> {
391    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
392        match self {
393            HandshakeError::WouldBlock(_) => f.write_str("WouldBlock hit during handshake"),
394            HandshakeError::Failure(err) => f.write_fmt(format_args!("IO error: {err}")),
395        }
396    }
397}
398
399impl<S: Read + Write + Send + 'static> fmt::Debug for HandshakeError<S> {
400    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
401        let mut d = f.debug_tuple("HandshakeError");
402        match self {
403            HandshakeError::WouldBlock(_) => d.field(&"WouldBlock"),
404            HandshakeError::Failure(err) => d.field(&err),
405        }
406        .finish()
407    }
408}
409
410impl<S: Read + Write + Send + 'static> Error for HandshakeError<S> {
411    fn source(&self) -> Option<&(dyn Error + 'static)> {
412        match self {
413            HandshakeError::Failure(err) => Some(err),
414            _ => None,
415        }
416    }
417}
418
419impl<S: Read + Send + Write + 'static> From<io::Error> for HandshakeError<S> {
420    fn from(err: io::Error) -> Self {
421        HandshakeError::Failure(err)
422    }
423}
424
#[cfg(test)]
mod tests {
    use super::*;

    /// A config with no certificates (and no platform verifier) must not build a connector.
    #[test]
    fn empty_config_fails() {
        assert!(
            RustlsConnectorConfig::default()
                .connector_with_no_client_auth()
                .is_err()
        );
    }

    #[test]
    #[cfg(feature = "webpki-root-certs")]
    fn webpki_root_certs_connector_builds() {
        RustlsConnector::new_with_webpki_root_certs().unwrap();
    }

    #[test]
    #[cfg(feature = "platform-verifier")]
    fn platform_verifier_connector_builds() {
        RustlsConnector::new_with_platform_verifier().unwrap();
    }

    /// Both DNS names and literal IP addresses are valid TLS server names.
    #[test]
    fn server_name_accepts_dns_names_and_ip_addresses() {
        assert!(server_name("example.com").is_ok());
        assert!(server_name("127.0.0.1").is_ok());
    }

    /// Garbage input is rejected with `InvalidData` rather than panicking.
    #[test]
    fn server_name_rejects_invalid_input() {
        let err = server_name("not a valid name!").unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
    }

    #[test]
    fn handshake_error_failure_display() {
        let err: HandshakeError<std::net::TcpStream> =
            HandshakeError::Failure(io::Error::other("test error"));
        assert!(err.to_string().contains("test error"));
        assert!(format!("{err:?}").contains("test error"));
        assert!(err.source().is_some());
    }

    /// `?` on an `io::Error` inside handshake code must yield a `Failure`, never `WouldBlock`.
    #[test]
    fn io_error_converts_into_handshake_failure() {
        let err: HandshakeError<std::net::TcpStream> = io::Error::other("boom").into();
        assert!(matches!(
            err,
            HandshakeError::Failure(ref inner) if inner.kind() == io::ErrorKind::Other
        ));
    }
}