Skip to main content

rustls_connector/
lib.rs

1#![deny(missing_docs, missing_debug_implementations, unsafe_code)]
2#![warn(unreachable_pub, unused_qualifications, unused_lifetimes)]
3#![warn(
4    clippy::must_use_candidate,
5    clippy::unwrap_in_result,
6    clippy::panic_in_result_fn
7)]
8
9//! # Connector similar to openssl or native-tls for rustls
10//!
//! rustls-connector is a library that makes it simple to use rustls as
//! an alternative to openssl and native-tls.
13//!
14//! # Examples
15//!
16//! To connect to a remote server:
17//!
18//! ```rust, no_run
19//! use rustls_connector::RustlsConnector;
20//!
21//! use std::{
22//!     io::{Read, Write},
23//!     net::TcpStream,
24//! };
25//!
26//! let connector = RustlsConnector::new_with_platform_verifier().unwrap();
27//! let stream = TcpStream::connect("google.com:443").unwrap();
28//! let mut stream = connector.connect("google.com", stream).unwrap();
29//!
30//! stream.write_all(b"GET / HTTP/1.0\r\n\r\n").unwrap();
31//! let mut res = vec![];
32//! stream.read_to_end(&mut res).unwrap();
33//! println!("{}", String::from_utf8_lossy(&res));
34//! ```
35
36pub use rustls;
37#[cfg(feature = "native-certs")]
38pub use rustls_native_certs;
39pub use rustls_pki_types;
40#[cfg(feature = "platform-verifier")]
41pub use rustls_platform_verifier;
42pub use webpki;
43#[cfg(feature = "webpki-root-certs")]
44pub use webpki_root_certs;
45
46#[cfg(feature = "futures")]
47use futures_io::{AsyncRead, AsyncWrite};
48use rustls::{
49    ClientConfig, ClientConnection, ConfigBuilder, RootCertStore, StreamOwned,
50    client::WantsClientCert,
51};
52use rustls_pki_types::{CertificateDer, PrivateKeyDer, ServerName};
53
54use std::{
55    error::Error,
56    fmt,
57    io::{self, Read, Write},
58    sync::Arc,
59};
60
61/// A TLS stream
62pub type TlsStream<S> = StreamOwned<ClientConnection, S>;
63
64#[cfg(feature = "futures")]
65/// An async TLS stream
66pub type AsyncTlsStream<S> = futures_rustls::client::TlsStream<S>;
67
68/// Configuration helper for [`RustlsConnector`]
69#[derive(Clone, Default, Debug)]
70pub struct RustlsConnectorConfig {
71    store: Vec<CertificateDer<'static>>,
72    #[cfg(feature = "platform-verifier")]
73    platform_verifier: bool,
74}
75
76impl RustlsConnectorConfig {
77    #[cfg(feature = "webpki-root-certs")]
78    /// Create a new [`RustlsConnectorConfig`] using the webpki-root-certs (requires webpki-root-certs feature enabled)
79    pub fn new_with_webpki_root_certs() -> Self {
80        Self::default().with_webpki_root_certs()
81    }
82
83    #[cfg(feature = "platform-verifier")]
84    /// Create a new [`RustlsConnectorConfig`] using the rustls-platform-verifier mechanism (requires platform-verifier feature enabled)
85    pub fn new_with_platform_verifier() -> Self {
86        Self::default().with_platform_verifier()
87    }
88
89    #[cfg(feature = "native-certs")]
90    /// Create a new [`RustlsConnectorConfig`] using the system certs (requires native-certs feature enabled)
91    ///
92    /// # Errors
93    ///
94    /// Returns an error if we fail to load the native certs.
95    pub fn new_with_native_certs() -> io::Result<Self> {
96        Self::default().with_native_certs()
97    }
98
99    /// Parse the given DER-encoded certificates and add all that can be parsed in a best-effort fashion.
100    ///
101    /// This is because large collections of root certificates often include ancient or syntactically invalid certificates.
102    pub fn add_parsable_certificates(&mut self, mut der_certs: Vec<CertificateDer<'static>>) {
103        self.store.append(&mut der_certs)
104    }
105
106    /// Parse the given DER-encoded certificates and add all that can be parsed in a best-effort fashion.
107    ///
108    /// This is because large collections of root certificates often include ancient or syntactically invalid certificates.
109    pub fn with_parsable_certificates(mut self, der_certs: Vec<CertificateDer<'static>>) -> Self {
110        self.add_parsable_certificates(der_certs);
111        self
112    }
113
114    #[cfg(feature = "webpki-root-certs")]
115    /// Add certs from webpki-root-certs (requires webpki-root-certs feature enabled)
116    pub fn with_webpki_root_certs(mut self) -> Self {
117        self.add_parsable_certificates(webpki_root_certs::TLS_SERVER_ROOT_CERTS.to_vec());
118        self
119    }
120
121    #[cfg(feature = "platform-verifier")]
122    /// Use the rustls-platform-verifier mechanism (requires platform-verifier feature enabled)
123    pub fn with_platform_verifier(mut self) -> Self {
124        self.platform_verifier = true;
125        self
126    }
127
128    #[cfg(feature = "native-certs")]
129    /// Add the system certs (requires native-certs feature enabled)
130    ///
131    /// # Errors
132    ///
133    /// Returns an error if we fail to load the native certs.
134    pub fn with_native_certs(mut self) -> io::Result<Self> {
135        let certs_result = rustls_native_certs::load_native_certs();
136        for err in certs_result.errors {
137            log::warn!("Got error while loading some native certificates: {err:?}");
138        }
139        if certs_result.certs.is_empty() {
140            return Err(io::Error::other(
141                "Could not load any valid native certificates",
142            ));
143        }
144        self.add_parsable_certificates(certs_result.certs);
145        Ok(self)
146    }
147
148    fn builder(self) -> io::Result<ConfigBuilder<ClientConfig, WantsClientCert>> {
149        let builder = ClientConfig::builder();
150        #[cfg(feature = "platform-verifier")]
151        {
152            if self.platform_verifier {
153                let verifier = rustls_platform_verifier::Verifier::new_with_extra_roots(
154                    self.store,
155                    builder.crypto_provider().clone(),
156                )
157                .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?;
158                // `.dangerous()` is the rustls API for supplying a custom verifier;
159                // it does not bypass verification — `Verifier` delegates to the OS store.
160                return Ok(builder
161                    .dangerous()
162                    .with_custom_certificate_verifier(Arc::new(verifier)));
163            }
164        }
165        let mut store = RootCertStore::empty();
166        let (_, ignored) = store.add_parsable_certificates(self.store);
167        if ignored > 0 {
168            log::warn!("{ignored} platform CA root certificates were ignored due to errors");
169        }
170        if store.is_empty() {
171            return Err(io::Error::other("Could not load any valid certificates"));
172        }
173        Ok(builder.with_root_certificates(store))
174    }
175
176    /// Create a new [`RustlsConnector`] from this config and no client certificate
177    ///
178    /// # Errors
179    ///
180    /// Returns an error if we fail to init our verifier
181    pub fn connector_with_no_client_auth(self) -> io::Result<RustlsConnector> {
182        Ok(self.builder()?.with_no_client_auth().into())
183    }
184
185    /// Create a new [`RustlsConnector`] from this config and the given client certificate
186    ///
187    /// cert_chain is a vector of DER-encoded certificates. key_der is a DER-encoded RSA, ECDSA, or
188    /// Ed25519 private key.
189    ///
190    /// # Errors
191    ///
192    /// Returns an error if we fail to init our verifier or if key_der is invalid.
193    pub fn connector_with_single_cert(
194        self,
195        cert_chain: Vec<CertificateDer<'static>>,
196        key_der: PrivateKeyDer<'static>,
197    ) -> io::Result<RustlsConnector> {
198        Ok(self
199            .builder()?
200            .with_client_auth_cert(cert_chain, key_der)
201            .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?
202            .into())
203    }
204}
205
206/// The connector
207#[derive(Clone, Debug)]
208pub struct RustlsConnector(Arc<ClientConfig>);
209
210impl From<ClientConfig> for RustlsConnector {
211    fn from(config: ClientConfig) -> Self {
212        Arc::new(config).into()
213    }
214}
215
216impl From<Arc<ClientConfig>> for RustlsConnector {
217    fn from(config: Arc<ClientConfig>) -> Self {
218        Self(config)
219    }
220}
221
222impl RustlsConnector {
223    #[cfg(feature = "webpki-root-certs")]
224    /// Create a new RustlsConnector using the webpki-root certs (requires webpki-root-certs feature enabled)
225    ///
226    /// # Errors
227    ///
228    /// Returns an error if we fail to init our verifier
229    pub fn new_with_webpki_root_certs() -> io::Result<Self> {
230        RustlsConnectorConfig::new_with_webpki_root_certs().connector_with_no_client_auth()
231    }
232
233    #[cfg(feature = "platform-verifier")]
234    /// Create a new [`RustlsConnector`] using the rustls-platform-verifier mechanism (requires platform-verifier feature enabled)
235    ///
236    /// # Errors
237    ///
238    /// Returns an error if we fail to init our verifier
239    pub fn new_with_platform_verifier() -> io::Result<Self> {
240        RustlsConnectorConfig::new_with_platform_verifier().connector_with_no_client_auth()
241    }
242
243    #[cfg(feature = "native-certs")]
244    /// Create a new [`RustlsConnector`] using the system certs (requires native-certs feature enabled)
245    ///
246    /// # Errors
247    ///
248    /// Returns an error if we fail to load the native certs.
249    pub fn new_with_native_certs() -> io::Result<Self> {
250        RustlsConnectorConfig::new_with_native_certs()?.connector_with_no_client_auth()
251    }
252
253    /// Connect to the given host
254    ///
255    /// # Errors
256    ///
257    /// Returns a [`HandshakeError`] containing either the current state of the handshake or the
258    /// failure when we couldn't complete the handshake
259    #[allow(clippy::result_large_err)]
260    pub fn connect<S: Read + Write + Send + 'static>(
261        &self,
262        domain: &str,
263        stream: S,
264    ) -> Result<TlsStream<S>, HandshakeError<S>> {
265        let session = ClientConnection::new(
266            self.0.clone(),
267            server_name(domain).map_err(HandshakeError::Failure)?,
268        )
269        .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?;
270        MidHandshakeTlsStream { session, stream }.handshake()
271    }
272
273    #[cfg(feature = "futures")]
274    /// Connect to the given host asynchronously
275    ///
276    /// # Errors
277    ///
278    /// Returns a [`io::Error`] containing the failure when we couldn't complete the TLS handshake
279    pub async fn connect_async<S: AsyncRead + AsyncWrite + Send + Unpin + 'static>(
280        &self,
281        domain: &str,
282        stream: S,
283    ) -> io::Result<AsyncTlsStream<S>> {
284        futures_rustls::TlsConnector::from(self.0.clone())
285            .connect(server_name(domain)?, stream)
286            .await
287    }
288}
289
290fn server_name(domain: &str) -> io::Result<ServerName<'static>> {
291    Ok(ServerName::try_from(domain)
292        .map_err(|err| {
293            io::Error::new(
294                io::ErrorKind::InvalidData,
295                format!("Invalid domain name: {err:?}"),
296            )
297        })?
298        .to_owned())
299}
300
301/// A TLS stream which has been interrupted during the handshake
302#[derive(Debug)]
303pub struct MidHandshakeTlsStream<S: Read + Write> {
304    session: ClientConnection,
305    stream: S,
306}
307
308impl<S: Read + Write + Send + 'static> MidHandshakeTlsStream<S> {
309    /// Get a reference to the inner stream
310    pub fn get_ref(&self) -> &S {
311        &self.stream
312    }
313
314    /// Get a mutable reference to the inner stream
315    pub fn get_mut(&mut self) -> &mut S {
316        &mut self.stream
317    }
318
319    /// Retry the handshake
320    ///
321    /// # Errors
322    ///
323    /// Returns a [`HandshakeError`] containing either the current state of the handshake or the
324    /// failure when we couldn't complete the handshake
325    #[allow(clippy::result_large_err)]
326    pub fn handshake(mut self) -> Result<TlsStream<S>, HandshakeError<S>> {
327        if let Err(e) = self.session.complete_io(&mut self.stream) {
328            if e.kind() == io::ErrorKind::WouldBlock {
329                if self.session.is_handshaking() {
330                    return Err(HandshakeError::WouldBlock(self));
331                }
332            } else {
333                return Err(e.into());
334            }
335        }
336        Ok(TlsStream::new(self.session, self.stream))
337    }
338}
339
340impl<S: Read + Write> fmt::Display for MidHandshakeTlsStream<S> {
341    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
342        f.write_str("MidHandshakeTlsStream")
343    }
344}
345
346/// An error returned while performing the handshake
347#[allow(clippy::large_enum_variant)]
348pub enum HandshakeError<S: Read + Write + Send + 'static> {
349    /// We hit WouldBlock during handshake.
350    /// Note that this is not a critical failure, you should be able to call handshake again once the stream is ready to perform I/O.
351    WouldBlock(MidHandshakeTlsStream<S>),
352    /// We hit a critical failure.
353    Failure(io::Error),
354}
355
356impl<S: Read + Write + Send + 'static> fmt::Display for HandshakeError<S> {
357    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
358        match self {
359            HandshakeError::WouldBlock(_) => f.write_str("WouldBlock hit during handshake"),
360            HandshakeError::Failure(err) => f.write_fmt(format_args!("IO error: {err}")),
361        }
362    }
363}
364
365impl<S: Read + Write + Send + 'static> fmt::Debug for HandshakeError<S> {
366    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
367        let mut d = f.debug_tuple("HandshakeError");
368        match self {
369            HandshakeError::WouldBlock(_) => d.field(&"WouldBlock"),
370            HandshakeError::Failure(err) => d.field(&err),
371        }
372        .finish()
373    }
374}
375
376impl<S: Read + Write + Send + 'static> Error for HandshakeError<S> {
377    fn source(&self) -> Option<&(dyn Error + 'static)> {
378        match self {
379            HandshakeError::Failure(err) => Some(err),
380            _ => None,
381        }
382    }
383}
384
385impl<S: Read + Send + Write + 'static> From<io::Error> for HandshakeError<S> {
386    fn from(err: io::Error) -> Self {
387        HandshakeError::Failure(err)
388    }
389}
390
#[cfg(test)]
mod tests {
    use super::*;

    /// A config with no certificates (and no platform verifier) cannot build a connector.
    #[test]
    fn empty_config_fails() {
        assert!(
            RustlsConnectorConfig::default()
                .connector_with_no_client_auth()
                .is_err()
        );
    }

    /// Unparsable certificates are skipped during the build; if nothing valid remains, the
    /// build fails instead of producing a connector that trusts nothing.
    #[test]
    fn unparsable_certificates_are_ignored() {
        let garbage = CertificateDer::from(vec![0u8; 4]);
        assert!(
            RustlsConnectorConfig::default()
                .with_parsable_certificates(vec![garbage])
                .connector_with_no_client_auth()
                .is_err()
        );
    }

    #[test]
    #[cfg(feature = "webpki-root-certs")]
    fn webpki_root_certs_connector_builds() {
        RustlsConnector::new_with_webpki_root_certs().unwrap();
    }

    #[test]
    #[cfg(feature = "platform-verifier")]
    fn platform_verifier_connector_builds() {
        RustlsConnector::new_with_platform_verifier().unwrap();
    }

    #[test]
    fn handshake_error_failure_display() {
        let err: HandshakeError<std::net::TcpStream> =
            HandshakeError::Failure(io::Error::other("test error"));
        assert!(err.to_string().contains("test error"));
        assert!(format!("{err:?}").contains("test error"));
        assert!(err.source().is_some());
    }

    #[test]
    fn handshake_error_from_io_error_is_failure() {
        let err: HandshakeError<std::net::TcpStream> = io::Error::other("boom").into();
        assert!(matches!(err, HandshakeError::Failure(ref e) if e.to_string() == "boom"));
    }

    #[test]
    fn server_name_accepts_dns_names_and_ip_addresses() {
        assert!(server_name("example.com").is_ok());
        assert!(server_name("127.0.0.1").is_ok());
    }

    #[test]
    fn server_name_rejects_invalid_names() {
        let err = server_name("not a valid name!").unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
        assert!(server_name("").is_err());
    }
}