Skip to main content

rustls_connector/
lib.rs

1#![deny(missing_docs, missing_debug_implementations, unsafe_code)]
2#![warn(unreachable_pub, unused_qualifications, unused_lifetimes)]
3#![warn(
4    clippy::must_use_candidate,
5    clippy::unwrap_in_result,
6    clippy::panic_in_result_fn
7)]
8
9//! # Connector similar to openssl or native-tls for rustls
10//!
11//! rustls-connector is a library aiming at simplifying using rustls as
12//! an alternative to openssl and native-tls
13//!
14//! # Examples
15//!
16//! To connect to a remote server:
17//!
18//! ```rust, no_run
19//! use rustls_connector::RustlsConnector;
20//!
21//! use std::{
22//!     io::{Read, Write},
23//!     net::TcpStream,
24//! };
25//!
26//! let connector = RustlsConnector::new_with_platform_verifier().unwrap();
27//! let stream = TcpStream::connect("google.com:443").unwrap();
28//! let mut stream = connector.connect("google.com", stream).unwrap();
29//!
30//! stream.write_all(b"GET / HTTP/1.0\r\n\r\n").unwrap();
31//! let mut res = vec![];
32//! stream.read_to_end(&mut res).unwrap();
33//! println!("{}", String::from_utf8_lossy(&res));
34//! ```
35
36pub use rustls;
37#[cfg(feature = "native-certs")]
38pub use rustls_native_certs;
39pub use rustls_pki_types;
40#[cfg(feature = "platform-verifier")]
41pub use rustls_platform_verifier;
42pub use webpki;
43#[cfg(feature = "webpki-root-certs")]
44pub use webpki_root_certs;
45
46#[cfg(feature = "futures")]
47use futures_io::{AsyncRead, AsyncWrite};
48use rustls::{
49    ClientConfig, ClientConnection, ConfigBuilder, RootCertStore, StreamOwned,
50    client::WantsClientCert,
51};
52use rustls_pki_types::{CertificateDer, PrivateKeyDer, ServerName};
53
54use std::{
55    error::Error,
56    fmt,
57    io::{self, Read, Write},
58    sync::Arc,
59};
60
61/// A TLS stream
62pub type TlsStream<S> = StreamOwned<ClientConnection, S>;
63
/// An asynchronous TLS stream over the transport `S` (requires futures feature enabled)
#[cfg(feature = "futures")]
pub type AsyncTlsStream<S> = futures_rustls::client::TlsStream<S>;
67
68/// Configuration helper for [`RustlsConnector`]
69#[derive(Clone, Default, Debug)]
70pub struct RustlsConnectorConfig {
71    store: Vec<CertificateDer<'static>>,
72    #[cfg(feature = "platform-verifier")]
73    platform_verifier: bool,
74}
75
76impl RustlsConnectorConfig {
77    #[cfg(feature = "webpki-root-certs")]
78    /// Create a new [`RustlsConnectorConfig`] using the webpki-root-certs (requires webpki-root-certs feature enabled)
79    pub fn new_with_webpki_root_certs() -> Self {
80        Self::default().with_webpki_root_certs()
81    }
82
83    #[cfg(feature = "platform-verifier")]
84    /// Create a new [`RustlsConnectorConfig`] using the rustls-platform-verifier mechanism (requires platform-verifier feature enabled)
85    pub fn new_with_platform_verifier() -> Self {
86        Self::default().with_platform_verifier()
87    }
88
89    #[cfg(feature = "native-certs")]
90    /// Create a new [`RustlsConnectorConfig`] using the system certs (requires native-certs feature enabled)
91    ///
92    /// # Errors
93    ///
94    /// Returns an error if we fail to load the native certs.
95    pub fn new_with_native_certs() -> io::Result<Self> {
96        Self::default().with_native_certs()
97    }
98
99    /// Parse the given DER-encoded certificates and add all that can be parsed in a best-effort fashion.
100    ///
101    /// This is because large collections of root certificates often include ancient or syntactically invalid certificates.
102    pub fn add_parsable_certificates(&mut self, mut der_certs: Vec<CertificateDer<'static>>) {
103        self.store.append(&mut der_certs)
104    }
105
106    /// Parse the given DER-encoded certificates and add all that can be parsed in a best-effort fashion.
107    ///
108    /// This is because large collections of root certificates often include ancient or syntactically invalid certificates.
109    pub fn with_parsable_certificates(mut self, der_certs: Vec<CertificateDer<'static>>) -> Self {
110        self.add_parsable_certificates(der_certs);
111        self
112    }
113
114    #[cfg(feature = "webpki-root-certs")]
115    /// Add certs from webpki-root-certs (requires webpki-root-certs feature enabled)
116    pub fn with_webpki_root_certs(mut self) -> Self {
117        self.add_parsable_certificates(webpki_root_certs::TLS_SERVER_ROOT_CERTS.to_vec());
118        self
119    }
120
121    #[cfg(feature = "platform-verifier")]
122    /// Use the rustls-platform-verifier mechanism (requires platform-verifier feature enabled)
123    pub fn with_platform_verifier(mut self) -> Self {
124        self.platform_verifier = true;
125        self
126    }
127
128    #[cfg(feature = "native-certs")]
129    /// Add the system certs (requires native-certs feature enabled)
130    ///
131    /// # Errors
132    ///
133    /// Returns an error if we fail to load the native certs.
134    pub fn with_native_certs(mut self) -> io::Result<Self> {
135        let certs_result = rustls_native_certs::load_native_certs();
136        for err in certs_result.errors {
137            log::warn!("Got error while loading some native certificates: {err:?}");
138        }
139        if certs_result.certs.is_empty() {
140            return Err(io::Error::other(
141                "Could not load any valid native certificates",
142            ));
143        }
144        self.add_parsable_certificates(certs_result.certs);
145        Ok(self)
146    }
147
148    fn builder(self) -> io::Result<ConfigBuilder<ClientConfig, WantsClientCert>> {
149        let builder = ClientConfig::builder();
150        #[cfg(feature = "platform-verifier")]
151        {
152            if self.platform_verifier {
153                let verifier = rustls_platform_verifier::Verifier::new_with_extra_roots(
154                    self.store,
155                    builder.crypto_provider().clone(),
156                )
157                .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?;
158                return Ok(builder
159                    .dangerous()
160                    .with_custom_certificate_verifier(Arc::new(verifier)));
161            }
162        }
163        let mut store = RootCertStore::empty();
164        let (_, ignored) = store.add_parsable_certificates(self.store);
165        if ignored > 0 {
166            log::warn!("{ignored} platform CA root certificates were ignored due to errors");
167        }
168        if store.is_empty() {
169            return Err(io::Error::other("Could not load any valid certificates"));
170        }
171        Ok(builder.with_root_certificates(store))
172    }
173
174    /// Create a new [`RustlsConnector`] from this config and no client certificate
175    ///
176    /// # Errors
177    ///
178    /// Returns an error if we fail to init our verifier
179    pub fn connector_with_no_client_auth(self) -> io::Result<RustlsConnector> {
180        Ok(self.builder()?.with_no_client_auth().into())
181    }
182
183    /// Create a new [`RustlsConnector`] from this config and the given client certificate
184    ///
185    /// cert_chain is a vector of DER-encoded certificates. key_der is a DER-encoded RSA, ECDSA, or
186    /// Ed25519 private key.
187    ///
188    /// # Errors
189    ///
190    /// Returns an error if we fail to init our verifier or if key_der is invalid.
191    pub fn connector_with_single_cert(
192        self,
193        cert_chain: Vec<CertificateDer<'static>>,
194        key_der: PrivateKeyDer<'static>,
195    ) -> io::Result<RustlsConnector> {
196        Ok(self
197            .builder()?
198            .with_client_auth_cert(cert_chain, key_der)
199            .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?
200            .into())
201    }
202}
203
204/// The connector
205#[derive(Clone, Debug)]
206pub struct RustlsConnector(Arc<ClientConfig>);
207
208impl From<ClientConfig> for RustlsConnector {
209    fn from(config: ClientConfig) -> Self {
210        Arc::new(config).into()
211    }
212}
213
214impl From<Arc<ClientConfig>> for RustlsConnector {
215    fn from(config: Arc<ClientConfig>) -> Self {
216        Self(config)
217    }
218}
219
220impl RustlsConnector {
221    #[cfg(feature = "webpki-root-certs")]
222    /// Create a new RustlsConnector using the webpki-root certs (requires webpki-root-certs feature enabled)
223    ///
224    /// # Errors
225    ///
226    /// Returns an error if we fail to init our verifier
227    pub fn new_with_webpki_root_certs() -> io::Result<Self> {
228        RustlsConnectorConfig::new_with_webpki_root_certs().connector_with_no_client_auth()
229    }
230
231    #[cfg(feature = "platform-verifier")]
232    /// Create a new [`RustlsConnector`] using the rustls-platform-verifier mechanism (requires platform-verifier feature enabled)
233    ///
234    /// # Errors
235    ///
236    /// Returns an error if we fail to init our verifier
237    pub fn new_with_platform_verifier() -> io::Result<Self> {
238        RustlsConnectorConfig::new_with_platform_verifier().connector_with_no_client_auth()
239    }
240
241    #[cfg(feature = "native-certs")]
242    /// Create a new [`RustlsConnector`] using the system certs (requires native-certs feature enabled)
243    ///
244    /// # Errors
245    ///
246    /// Returns an error if we fail to load the native certs.
247    pub fn new_with_native_certs() -> io::Result<Self> {
248        RustlsConnectorConfig::new_with_native_certs()?.connector_with_no_client_auth()
249    }
250
251    /// Connect to the given host
252    ///
253    /// # Errors
254    ///
255    /// Returns a [`HandshakeError`] containing either the current state of the handshake or the
256    /// failure when we couldn't complete the hanshake
257    #[allow(clippy::result_large_err)]
258    pub fn connect<S: Read + Write + Send + 'static>(
259        &self,
260        domain: &str,
261        stream: S,
262    ) -> Result<TlsStream<S>, HandshakeError<S>> {
263        let session = ClientConnection::new(
264            self.0.clone(),
265            server_name(domain).map_err(HandshakeError::Failure)?,
266        )
267        .map_err(|err| io::Error::new(io::ErrorKind::ConnectionAborted, err))?;
268        MidHandshakeTlsStream { session, stream }.handshake()
269    }
270
271    #[cfg(feature = "futures")]
272    /// Connect to the given host asynchronously
273    ///
274    /// # Errors
275    ///
276    /// Returns a [`io::Error`] containing the failure when we couldn't complete the TLS hanshake
277    pub async fn connect_async<S: AsyncRead + AsyncWrite + Send + Unpin + 'static>(
278        &self,
279        domain: &str,
280        stream: S,
281    ) -> io::Result<AsyncTlsStream<S>> {
282        futures_rustls::TlsConnector::from(self.0.clone())
283            .connect(server_name(domain)?, stream)
284            .await
285    }
286}
287
288fn server_name(domain: &str) -> io::Result<ServerName<'static>> {
289    Ok(ServerName::try_from(domain)
290        .map_err(|err| {
291            io::Error::new(
292                io::ErrorKind::InvalidData,
293                format!("Invalid domain name ({err:?}): {domain}"),
294            )
295        })?
296        .to_owned())
297}
298
299/// A TLS stream which has been interrupted during the handshake
300#[derive(Debug)]
301pub struct MidHandshakeTlsStream<S: Read + Write> {
302    session: ClientConnection,
303    stream: S,
304}
305
306impl<S: Read + Send + Write + 'static> MidHandshakeTlsStream<S> {
307    /// Get a reference to the inner stream
308    pub fn get_ref(&self) -> &S {
309        &self.stream
310    }
311
312    /// Get a mutable reference to the inner stream
313    pub fn get_mut(&mut self) -> &mut S {
314        &mut self.stream
315    }
316
317    /// Retry the handshake
318    ///
319    /// # Errors
320    ///
321    /// Returns a [`HandshakeError`] containing either the current state of the handshake or the
322    /// failure when we couldn't complete the hanshake
323    #[allow(clippy::result_large_err)]
324    pub fn handshake(mut self) -> Result<TlsStream<S>, HandshakeError<S>> {
325        if let Err(e) = self.session.complete_io(&mut self.stream) {
326            if e.kind() == io::ErrorKind::WouldBlock {
327                if self.session.is_handshaking() {
328                    return Err(HandshakeError::WouldBlock(self));
329                }
330            } else {
331                return Err(e.into());
332            }
333        }
334        Ok(TlsStream::new(self.session, self.stream))
335    }
336}
337
338impl<S: Read + Write> fmt::Display for MidHandshakeTlsStream<S> {
339    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
340        f.write_str("MidHandshakeTlsStream")
341    }
342}
343
344/// An error returned while performing the handshake
345#[allow(clippy::large_enum_variant)]
346pub enum HandshakeError<S: Read + Write + Send + 'static> {
347    /// We hit WouldBlock during handshake.
348    /// Note that this is not a critical failure, you should be able to call handshake again once the stream is ready to perform I/O.
349    WouldBlock(MidHandshakeTlsStream<S>),
350    /// We hit a critical failure.
351    Failure(io::Error),
352}
353
354impl<S: Read + Write + Send + 'static> fmt::Display for HandshakeError<S> {
355    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
356        match self {
357            HandshakeError::WouldBlock(_) => f.write_str("WouldBlock hit during handshake"),
358            HandshakeError::Failure(err) => f.write_fmt(format_args!("IO error: {err}")),
359        }
360    }
361}
362
363impl<S: Read + Write + Send + 'static> fmt::Debug for HandshakeError<S> {
364    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
365        let mut d = f.debug_tuple("HandshakeError");
366        match self {
367            HandshakeError::WouldBlock(_) => d.field(&"WouldBlock"),
368            HandshakeError::Failure(err) => d.field(&err),
369        }
370        .finish()
371    }
372}
373
374impl<S: Read + Write + Send + 'static> Error for HandshakeError<S> {
375    fn source(&self) -> Option<&(dyn Error + 'static)> {
376        match self {
377            HandshakeError::Failure(err) => Some(err),
378            _ => None,
379        }
380    }
381}
382
383impl<S: Read + Send + Write + 'static> From<io::Error> for HandshakeError<S> {
384    fn from(err: io::Error) -> Self {
385        HandshakeError::Failure(err)
386    }
387}
388
#[cfg(test)]
mod tests {
    use super::*;

    // A default config holds no certificate and does not select the platform verifier,
    // so no connector can be built from it.
    #[test]
    fn empty_config_fails() {
        let result = RustlsConnectorConfig::default().connector_with_no_client_auth();
        assert!(result.is_err());
    }

    // The bundled webpki roots must be enough to build a connector.
    #[cfg(feature = "webpki-root-certs")]
    #[test]
    fn webpki_root_certs_connector_builds() {
        let _connector = RustlsConnector::new_with_webpki_root_certs().unwrap();
    }

    // The platform verifier must be enough to build a connector.
    #[cfg(feature = "platform-verifier")]
    #[test]
    fn platform_verifier_connector_builds() {
        let _connector = RustlsConnector::new_with_platform_verifier().unwrap();
    }

    // A failure exposes the inner I/O error through Display, Debug and source().
    #[test]
    fn handshake_error_failure_display() {
        let inner = io::Error::other("test error");
        let err: HandshakeError<std::net::TcpStream> = HandshakeError::Failure(inner);
        let displayed = err.to_string();
        let debugged = format!("{err:?}");
        assert!(displayed.contains("test error"));
        assert!(debugged.contains("test error"));
        assert!(err.source().is_some());
    }
}