Skip to main content

rustls_connector/
lib.rs

1#![deny(missing_docs)]
2
3//! # Connector similar to openssl or native-tls for rustls
4//!
//! rustls-connector is a library that makes it simple to use rustls as an alternative to
//! openssl and native-tls.
7//!
8//! # Examples
9//!
10//! To connect to a remote server:
11//!
12//! ```rust, no_run
13//! use rustls_connector::RustlsConnector;
14//!
15//! use std::{
16//!     io::{Read, Write},
17//!     net::TcpStream,
18//! };
19//!
20//! let connector = RustlsConnector::new_with_platform_verifier().unwrap();
21//! let stream = TcpStream::connect("google.com:443").unwrap();
22//! let mut stream = connector.connect("google.com", stream).unwrap();
23//!
24//! stream.write_all(b"GET / HTTP/1.0\r\n\r\n").unwrap();
25//! let mut res = vec![];
26//! stream.read_to_end(&mut res).unwrap();
27//! println!("{}", String::from_utf8_lossy(&res));
28//! ```
29
30pub use rustls;
31#[cfg(feature = "native-certs")]
32pub use rustls_native_certs;
33pub use rustls_pki_types;
34#[cfg(feature = "platform-verifier")]
35pub use rustls_platform_verifier;
36pub use webpki;
37#[cfg(feature = "webpki-root-certs")]
38pub use webpki_root_certs;
39
40#[cfg(feature = "futures")]
41use futures_io::{AsyncRead, AsyncWrite};
42use rustls::{
43    ClientConfig, ClientConnection, ConfigBuilder, RootCertStore, StreamOwned,
44    client::WantsClientCert,
45};
46use rustls_pki_types::{CertificateDer, PrivateKeyDer, ServerName};
47
48use std::{
49    error::Error,
50    fmt,
51    io::{self, Read, Write},
52    sync::Arc,
53};
54
55/// A TLS stream
56pub type TlsStream<S> = StreamOwned<ClientConnection, S>;
57
58#[cfg(feature = "futures")]
59/// An async TLS stream
60pub type AsyncTlsStream<S> = futures_rustls::client::TlsStream<S>;
61
62/// Configuration helper for [`RustlsConnector`]
63#[derive(Clone, Default)]
64pub struct RustlsConnectorConfig {
65    store: Vec<CertificateDer<'static>>,
66    #[cfg(feature = "platform-verifier")]
67    platform_verifier: bool,
68}
69
70impl RustlsConnectorConfig {
71    #[cfg(feature = "webpki-root-certs")]
72    /// Create a new [`RustlsConnectorConfig`] using the webpki-root-certs (requires webpki-root-certs feature enabled)
73    pub fn new_with_webpki_root_certs() -> Self {
74        Self::default().with_webpki_root_certs()
75    }
76
77    #[cfg(feature = "platform-verifier")]
78    /// Create a new [`RustlsConnectorConfig`] using the rustls-platform-verifier mechanism (requires platform-verifier feature enabled)
79    pub fn new_with_platform_verifier() -> Self {
80        Self::default().with_platform_verifier()
81    }
82
83    #[cfg(feature = "native-certs")]
84    /// Create a new [`RustlsConnectorConfig`] using the system certs (requires native-certs feature enabled)
85    ///
86    /// # Errors
87    ///
88    /// Returns an error if we fail to load the native certs.
89    pub fn new_with_native_certs() -> io::Result<Self> {
90        Self::default().with_native_certs()
91    }
92
93    /// Parse the given DER-encoded certificates and add all that can be parsed in a best-effort fashion.
94    ///
95    /// This is because large collections of root certificates often include ancient or syntactically invalid certificates.
96    pub fn add_parsable_certificates(&mut self, mut der_certs: Vec<CertificateDer<'static>>) {
97        self.store.append(&mut der_certs)
98    }
99
100    /// Parse the given DER-encoded certificates and add all that can be parsed in a best-effort fashion.
101    ///
102    /// This is because large collections of root certificates often include ancient or syntactically invalid certificates.
103    pub fn with_parsable_certificates(mut self, der_certs: Vec<CertificateDer<'static>>) -> Self {
104        self.add_parsable_certificates(der_certs);
105        self
106    }
107
108    #[cfg(feature = "webpki-root-certs")]
109    /// Add certs from webpki-root-certs (requires webpki-root-certs feature enabled)
110    pub fn with_webpki_root_certs(mut self) -> Self {
111        self.add_parsable_certificates(webpki_root_certs::TLS_SERVER_ROOT_CERTS.to_vec());
112        self
113    }
114
115    #[cfg(feature = "platform-verifier")]
116    /// Use the rustls-platform-verifier mechanism (requires platform-verifier feature enabled)
117    pub fn with_platform_verifier(mut self) -> Self {
118        self.platform_verifier = true;
119        self
120    }
121
122    #[cfg(feature = "native-certs")]
123    /// Add the system certs (requires native-certs feature enabled)
124    ///
125    /// # Errors
126    ///
127    /// Returns an error if we fail to load the native certs.
128    pub fn with_native_certs(mut self) -> io::Result<Self> {
129        let certs_result = rustls_native_certs::load_native_certs();
130        for err in certs_result.errors {
131            log::warn!("Got error while loading some native certificates: {err:?}");
132        }
133        if certs_result.certs.is_empty() {
134            return Err(io::Error::other(
135                "Could not load any valid native certificates",
136            ));
137        }
138        self.add_parsable_certificates(certs_result.certs);
139        Ok(self)
140    }
141
142    fn builder(self) -> io::Result<ConfigBuilder<ClientConfig, WantsClientCert>> {
143        let builder = ClientConfig::builder();
144        #[cfg(feature = "platform-verifier")]
145        {
146            if self.platform_verifier {
147                let verifier = rustls_platform_verifier::Verifier::new_with_extra_roots(
148                    self.store,
149                    builder.crypto_provider().clone(),
150                )
151                .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?;
152                return Ok(builder
153                    .dangerous()
154                    .with_custom_certificate_verifier(Arc::new(verifier)));
155            }
156        }
157        let mut store = RootCertStore::empty();
158        let (_, ignored) = store.add_parsable_certificates(self.store);
159        if ignored > 0 {
160            log::warn!("{ignored} platform CA root certificates were ignored due to errors");
161        }
162        if store.is_empty() {
163            return Err(io::Error::other("Could not load any valid certificates"));
164        }
165        Ok(builder.with_root_certificates(store))
166    }
167
168    /// Create a new [`RustlsConnector`] from this config and no client certificate
169    ///
170    /// # Errors
171    ///
172    /// Returns an error if we fail to init our verifier
173    pub fn connector_with_no_client_auth(self) -> io::Result<RustlsConnector> {
174        Ok(self.builder()?.with_no_client_auth().into())
175    }
176
177    /// Create a new [`RustlsConnector`] from this config and the given client certificate
178    ///
179    /// cert_chain is a vector of DER-encoded certificates. key_der is a DER-encoded RSA, ECDSA, or
180    /// Ed25519 private key.
181    ///
182    /// # Errors
183    ///
184    /// Returns an error if we fail to init our verifier or if key_der is invalid.
185    pub fn connector_with_single_cert(
186        self,
187        cert_chain: Vec<CertificateDer<'static>>,
188        key_der: PrivateKeyDer<'static>,
189    ) -> io::Result<RustlsConnector> {
190        Ok(self
191            .builder()?
192            .with_client_auth_cert(cert_chain, key_der)
193            .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?
194            .into())
195    }
196}
197
198
199/// The connector
200#[derive(Clone)]
201pub struct RustlsConnector(Arc<ClientConfig>);
202
203impl From<ClientConfig> for RustlsConnector {
204    fn from(config: ClientConfig) -> Self {
205        Arc::new(config).into()
206    }
207}
208
209impl From<Arc<ClientConfig>> for RustlsConnector {
210    fn from(config: Arc<ClientConfig>) -> Self {
211        Self(config)
212    }
213}
214
215impl RustlsConnector {
216    #[cfg(feature = "webpki-root-certs")]
217    /// Create a new RustlsConnector using the webpki-root certs (requires webpki-root-certs feature enabled)
218    ///
219    /// # Errors
220    ///
221    /// Returns an error if we fail to init our verifier
222    pub fn new_with_webpki_root_certs() -> io::Result<Self> {
223        RustlsConnectorConfig::new_with_webpki_root_certs().connector_with_no_client_auth()
224    }
225
226    #[cfg(feature = "platform-verifier")]
227    /// Create a new [`RustlsConnector`] using the rustls-platform-verifier mechanism (requires platform-verifier feature enabled)
228    ///
229    /// # Errors
230    ///
231    /// Returns an error if we fail to init our verifier
232    pub fn new_with_platform_verifier() -> io::Result<Self> {
233        RustlsConnectorConfig::new_with_platform_verifier().connector_with_no_client_auth()
234    }
235
236    #[cfg(feature = "native-certs")]
237    /// Create a new [`RustlsConnector`] using the system certs (requires native-certs feature enabled)
238    ///
239    /// # Errors
240    ///
241    /// Returns an error if we fail to load the native certs.
242    pub fn new_with_native_certs() -> io::Result<Self> {
243        RustlsConnectorConfig::new_with_native_certs()?.connector_with_no_client_auth()
244    }
245
246    /// Connect to the given host
247    ///
248    /// # Errors
249    ///
250    /// Returns a [`HandshakeError`] containing either the current state of the handshake or the
251    /// failure when we couldn't complete the hanshake
252    #[allow(clippy::result_large_err)]
253    pub fn connect<S: Read + Write + Send + 'static>(
254        &self,
255        domain: &str,
256        stream: S,
257    ) -> Result<TlsStream<S>, HandshakeError<S>> {
258        let session = ClientConnection::new(
259            self.0.clone(),
260            server_name(domain).map_err(HandshakeError::Failure)?,
261        )
262        .map_err(|err| io::Error::new(io::ErrorKind::ConnectionAborted, err))?;
263        MidHandshakeTlsStream { session, stream }.handshake()
264    }
265
266    #[cfg(feature = "futures")]
267    /// Connect to the given host asynchronously
268    ///
269    /// # Errors
270    ///
271    /// Returns a [`io::Error`] containing the failure when we couldn't complete the TLS hanshake
272    pub async fn connect_async<S: AsyncRead + AsyncWrite + Send + Unpin + 'static>(
273        &self,
274        domain: &str,
275        stream: S,
276    ) -> io::Result<AsyncTlsStream<S>> {
277        futures_rustls::TlsConnector::from(self.0.clone())
278            .connect(server_name(domain)?, stream)
279            .await
280    }
281}
282
283fn server_name(domain: &str) -> io::Result<ServerName<'static>> {
284    Ok(ServerName::try_from(domain)
285        .map_err(|err| {
286            io::Error::new(
287                io::ErrorKind::InvalidData,
288                format!("Invalid domain name ({err:?}): {domain}"),
289            )
290        })?
291        .to_owned())
292}
293
294/// A TLS stream which has been interrupted during the handshake
295#[derive(Debug)]
296pub struct MidHandshakeTlsStream<S: Read + Write> {
297    session: ClientConnection,
298    stream: S,
299}
300
301impl<S: Read + Send + Write + 'static> MidHandshakeTlsStream<S> {
302    /// Get a reference to the inner stream
303    pub fn get_ref(&self) -> &S {
304        &self.stream
305    }
306
307    /// Get a mutable reference to the inner stream
308    pub fn get_mut(&mut self) -> &mut S {
309        &mut self.stream
310    }
311
312    /// Retry the handshake
313    ///
314    /// # Errors
315    ///
316    /// Returns a [`HandshakeError`] containing either the current state of the handshake or the
317    /// failure when we couldn't complete the hanshake
318    #[allow(clippy::result_large_err)]
319    pub fn handshake(mut self) -> Result<TlsStream<S>, HandshakeError<S>> {
320        if let Err(e) = self.session.complete_io(&mut self.stream) {
321            if e.kind() == io::ErrorKind::WouldBlock {
322                if self.session.is_handshaking() {
323                    return Err(HandshakeError::WouldBlock(self));
324                }
325            } else {
326                return Err(e.into());
327            }
328        }
329        Ok(TlsStream::new(self.session, self.stream))
330    }
331}
332
333impl<S: Read + Write> fmt::Display for MidHandshakeTlsStream<S> {
334    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
335        f.write_str("MidHandshakeTlsStream")
336    }
337}
338
339/// An error returned while performing the handshake
340#[allow(clippy::large_enum_variant)]
341pub enum HandshakeError<S: Read + Write + Send + 'static> {
342    /// We hit WouldBlock during handshake.
343    /// Note that this is not a critical failure, you should be able to call handshake again once the stream is ready to perform I/O.
344    WouldBlock(MidHandshakeTlsStream<S>),
345    /// We hit a critical failure.
346    Failure(io::Error),
347}
348
349impl<S: Read + Write + Send + 'static> fmt::Display for HandshakeError<S> {
350    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
351        match self {
352            HandshakeError::WouldBlock(_) => f.write_str("WouldBlock hit during handshake"),
353            HandshakeError::Failure(err) => f.write_fmt(format_args!("IO error: {err}")),
354        }
355    }
356}
357
358impl<S: Read + Write + Send + 'static> fmt::Debug for HandshakeError<S> {
359    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
360        let mut d = f.debug_tuple("HandshakeError");
361        match self {
362            HandshakeError::WouldBlock(_) => d.field(&"WouldBlock"),
363            HandshakeError::Failure(err) => d.field(&err),
364        }
365        .finish()
366    }
367}
368
369impl<S: Read + Write + Send + 'static> Error for HandshakeError<S> {
370    fn source(&self) -> Option<&(dyn Error + 'static)> {
371        match self {
372            HandshakeError::Failure(err) => Some(err),
373            _ => None,
374        }
375    }
376}
377
378impl<S: Read + Send + Write + 'static> From<io::Error> for HandshakeError<S> {
379    fn from(err: io::Error) -> Self {
380        HandshakeError::Failure(err)
381    }
382}
383
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn empty_config_fails() {
        assert!(RustlsConnectorConfig::default()
            .connector_with_no_client_auth()
            .is_err());
    }

    #[test]
    fn unparsable_certificates_only_fails() {
        // Four zero bytes are not a valid X.509 certificate, so the root store stays empty.
        let garbage = CertificateDer::from(vec![0u8; 4]);
        assert!(RustlsConnectorConfig::default()
            .with_parsable_certificates(vec![garbage])
            .connector_with_no_client_auth()
            .is_err());
    }

    #[test]
    fn invalid_domain_is_rejected() {
        let err = server_name("not a valid domain").unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
        assert!(err.to_string().contains("not a valid domain"));
    }

    #[test]
    fn valid_domain_is_accepted() {
        assert!(server_name("example.com").is_ok());
    }

    #[test]
    #[cfg(feature = "webpki-root-certs")]
    fn webpki_root_certs_connector_builds() {
        RustlsConnector::new_with_webpki_root_certs().unwrap();
    }

    #[test]
    #[cfg(feature = "platform-verifier")]
    fn platform_verifier_connector_builds() {
        RustlsConnector::new_with_platform_verifier().unwrap();
    }

    #[test]
    fn handshake_error_failure_display() {
        let err: HandshakeError<std::net::TcpStream> =
            HandshakeError::Failure(io::Error::other("test error"));
        assert!(err.to_string().contains("test error"));
        assert!(format!("{err:?}").contains("test error"));
        assert!(err.source().is_some());
    }

    #[test]
    fn io_error_converts_into_failure() {
        let err: HandshakeError<std::net::TcpStream> = io::Error::other("boom").into();
        assert!(matches!(err, HandshakeError::Failure(_)));
        assert_eq!(err.to_string(), "IO error: boom");
    }
}