Skip to main content

rustls_connector/
lib.rs

1#![deny(missing_docs)]
2#![allow(clippy::result_large_err, clippy::large_enum_variant)]
3
4//! # Connector similar to openssl or native-tls for rustls
5//!
//! rustls-connector is a library that makes it simpler to use rustls as
//! an alternative to openssl and native-tls
8//!
9//! # Examples
10//!
11//! To connect to a remote server:
12//!
13//! ```rust, no_run
14//! use rustls_connector::RustlsConnector;
15//!
16//! use std::{
17//!     io::{Read, Write},
18//!     net::TcpStream,
19//! };
20//!
21//! let connector = RustlsConnector::new_with_platform_verifier().unwrap();
22//! let stream = TcpStream::connect("google.com:443").unwrap();
23//! let mut stream = connector.connect("google.com", stream).unwrap();
24//!
25//! stream.write_all(b"GET / HTTP/1.0\r\n\r\n").unwrap();
26//! let mut res = vec![];
27//! stream.read_to_end(&mut res).unwrap();
28//! println!("{}", String::from_utf8_lossy(&res));
29//! ```
30
31pub use rustls;
32#[cfg(feature = "native-certs")]
33pub use rustls_native_certs;
34pub use rustls_pki_types;
35#[cfg(feature = "platform-verifier")]
36pub use rustls_platform_verifier;
37pub use webpki;
38#[cfg(feature = "webpki-root-certs")]
39pub use webpki_root_certs;
40
41#[cfg(feature = "futures")]
42use futures_io::{AsyncRead, AsyncWrite};
43use rustls::{
44    ClientConfig, ClientConnection, ConfigBuilder, RootCertStore, StreamOwned,
45    client::WantsClientCert,
46};
47use rustls_pki_types::{CertificateDer, PrivateKeyDer, ServerName};
48
49use std::{
50    error::Error,
51    fmt,
52    io::{self, Read, Write},
53    sync::Arc,
54};
55
56/// A TLS stream
57pub type TlsStream<S> = StreamOwned<ClientConnection, S>;
58
59#[cfg(feature = "futures")]
60/// An async TLS stream
61pub type AsyncTlsStream<S> = futures_rustls::client::TlsStream<S>;
62
63/// Configuration helper for [`RustlsConnector`]
64#[derive(Clone)]
65pub struct RustlsConnectorConfig {
66    store: Vec<CertificateDer<'static>>,
67    platform_verifier: bool,
68}
69
70impl RustlsConnectorConfig {
71    #[cfg(feature = "webpki-root-certs")]
72    /// Create a new [`RustlsConnectorConfig`] using the webpki-root-certs (requires webpki-root-certs feature enabled)
73    pub fn new_with_webpki_root_certs() -> Self {
74        Self::default().with_webpki_root_certs()
75    }
76
77    #[cfg(feature = "platform-verifier")]
78    /// Create a new [`RustlsConnectorConfig`] using the rustls-platform-verifier mechanism (requires platform-verifier feature enabled)
79    pub fn new_with_platform_verifier() -> Self {
80        Self::default().with_platform_verifier()
81    }
82
83    #[cfg(feature = "native-certs")]
84    /// Create a new [`RustlsConnectorConfig`] using the system certs (requires native-certs feature enabled)
85    ///
86    /// # Errors
87    ///
88    /// Returns an error if we fail to load the native certs.
89    pub fn new_with_native_certs() -> io::Result<Self> {
90        Self::default().with_native_certs()
91    }
92
93    /// Parse the given DER-encoded certificates and add all that can be parsed in a best-effort fashion.
94    ///
95    /// This is because large collections of root certificates often include ancient or syntactically invalid certificates.
96    pub fn add_parsable_certificates<'a>(&mut self, mut der_certs: Vec<CertificateDer<'static>>) {
97        self.store.append(&mut der_certs)
98    }
99
100    #[cfg(feature = "webpki-root-certs")]
101    /// Add certs from webpki-root-certs (requires webpki-root-certs feature enabled)
102    pub fn with_webpki_root_certs(mut self) -> Self {
103        self.add_parsable_certificates(webpki_root_certs::TLS_SERVER_ROOT_CERTS.to_vec());
104        self
105    }
106
107    #[cfg(feature = "platform-verifier")]
108    /// Use the rustls-platform-verifier mechanism (requires platform-verifier feature enabled)
109    pub fn with_platform_verifier(mut self) -> Self {
110        self.platform_verifier = true;
111        self
112    }
113
114    #[cfg(feature = "native-certs")]
115    /// Add the system certs (requires native-certs feature enabled)
116    ///
117    /// # Errors
118    ///
119    /// Returns an error if we fail to load the native certs.
120    pub fn with_native_certs(mut self) -> io::Result<Self> {
121        let certs_result = rustls_native_certs::load_native_certs();
122        for err in certs_result.errors {
123            log::warn!("Got error while loading some native certificates: {err:?}");
124        }
125        if certs_result.certs.is_empty() {
126            return Err(io::Error::other(
127                "Could not load any valid native certificates",
128            ));
129        }
130        self.add_parsable_certificates(certs_result.certs);
131        Ok(self)
132    }
133
134    fn builder(self) -> io::Result<ConfigBuilder<ClientConfig, WantsClientCert>> {
135        let builder = ClientConfig::builder();
136        #[cfg(feature = "platform-verifier")]
137        {
138            if self.platform_verifier {
139                let verifier = rustls_platform_verifier::Verifier::new_with_extra_roots(
140                    self.store,
141                    builder.crypto_provider().clone(),
142                )
143                .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?;
144                return Ok(builder
145                    .dangerous()
146                    .with_custom_certificate_verifier(Arc::new(verifier)));
147            }
148        }
149        let mut store = RootCertStore::empty();
150        let (_, ignored) = store.add_parsable_certificates(self.store);
151        if ignored > 0 {
152            log::warn!("{ignored} platform CA root certificates were ignored due to errors");
153        }
154        if store.is_empty() {
155            return Err(io::Error::other("Could not load any valid certificates"));
156        }
157        Ok(builder.with_root_certificates(store))
158    }
159
160    /// Create a new [`RustlsConnector`] from this config and no client certificate
161    ///
162    /// # Errors
163    ///
164    /// Returns an error if we fail to init our verifier
165    pub fn connector_with_no_client_auth(self) -> io::Result<RustlsConnector> {
166        Ok(self.builder()?.with_no_client_auth().into())
167    }
168
169    /// Create a new [`RustlsConnector`] from this config and the given client certificate
170    ///
171    /// cert_chain is a vector of DER-encoded certificates. key_der is a DER-encoded RSA, ECDSA, or
172    /// Ed25519 private key.
173    ///
174    /// # Errors
175    ///
176    /// Returns an error if we fail to init our verifier or if key_der is invalid.
177    pub fn connector_with_single_cert(
178        self,
179        cert_chain: Vec<CertificateDer<'static>>,
180        key_der: PrivateKeyDer<'static>,
181    ) -> io::Result<RustlsConnector> {
182        Ok(self
183            .builder()?
184            .with_client_auth_cert(cert_chain, key_der)
185            .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?
186            .into())
187    }
188}
189
190impl Default for RustlsConnectorConfig {
191    fn default() -> Self {
192        Self {
193            store: Vec::new(),
194            platform_verifier: false,
195        }
196    }
197}
198
199/// The connector
200#[derive(Clone)]
201pub struct RustlsConnector(Arc<ClientConfig>);
202
203impl Default for RustlsConnector {
204    fn default() -> Self {
205        RustlsConnectorConfig::default()
206            .connector_with_no_client_auth()
207            .expect("no error codepath for default RustlsConnectorConfig")
208    }
209}
210
211impl From<ClientConfig> for RustlsConnector {
212    fn from(config: ClientConfig) -> Self {
213        Arc::new(config).into()
214    }
215}
216
217impl From<Arc<ClientConfig>> for RustlsConnector {
218    fn from(config: Arc<ClientConfig>) -> Self {
219        Self(config)
220    }
221}
222
223impl RustlsConnector {
224    #[cfg(feature = "webpki-root-certs")]
225    /// Create a new RustlsConnector using the webpki-root certs (requires webpki-root-certs feature enabled)
226    ///
227    /// # Errors
228    ///
229    /// Returns an error if we fail to init our verifier
230    pub fn new_with_webpki_root_certs() -> io::Result<Self> {
231        RustlsConnectorConfig::new_with_webpki_root_certs().connector_with_no_client_auth()
232    }
233
234    #[cfg(feature = "platform-verifier")]
235    /// Create a new [`RustlsConnector`] using the rustls-platform-verifier mechanism (requires platform-verifier feature enabled)
236    ///
237    /// # Errors
238    ///
239    /// Returns an error if we fail to init our verifier
240    pub fn new_with_platform_verifier() -> io::Result<Self> {
241        RustlsConnectorConfig::new_with_platform_verifier().connector_with_no_client_auth()
242    }
243
244    #[cfg(feature = "native-certs")]
245    /// Create a new [`RustlsConnector`] using the system certs (requires native-certs feature enabled)
246    ///
247    /// # Errors
248    ///
249    /// Returns an error if we fail to load the native certs.
250    pub fn new_with_native_certs() -> io::Result<Self> {
251        RustlsConnectorConfig::new_with_native_certs()?.connector_with_no_client_auth()
252    }
253
254    /// Connect to the given host
255    ///
256    /// # Errors
257    ///
258    /// Returns a [`HandshakeError`] containing either the current state of the handshake or the
259    /// failure when we couldn't complete the hanshake
260    pub fn connect<S: Read + Write + Send + 'static>(
261        &self,
262        domain: &str,
263        stream: S,
264    ) -> Result<TlsStream<S>, HandshakeError<S>> {
265        let session = ClientConnection::new(
266            self.0.clone(),
267            server_name(domain).map_err(HandshakeError::Failure)?,
268        )
269        .map_err(|err| io::Error::new(io::ErrorKind::ConnectionAborted, err))?;
270        MidHandshakeTlsStream { session, stream }.handshake()
271    }
272
273    #[cfg(feature = "futures")]
274    /// Connect to the given host asynchronously
275    ///
276    /// # Errors
277    ///
278    /// Returns a [`io::Error`] containing the failure when we couldn't complete the TLS hanshake
279    pub async fn connect_async<S: AsyncRead + AsyncWrite + Send + Unpin + 'static>(
280        &self,
281        domain: &str,
282        stream: S,
283    ) -> io::Result<AsyncTlsStream<S>> {
284        futures_rustls::TlsConnector::from(self.0.clone())
285            .connect(server_name(domain)?, stream)
286            .await
287    }
288}
289
290fn server_name(domain: &str) -> io::Result<ServerName<'static>> {
291    Ok(ServerName::try_from(domain)
292        .map_err(|err| {
293            io::Error::new(
294                io::ErrorKind::InvalidData,
295                format!("Invalid domain name ({err:?}): {domain}"),
296            )
297        })?
298        .to_owned())
299}
300
301/// A TLS stream which has been interrupted during the handshake
302#[derive(Debug)]
303pub struct MidHandshakeTlsStream<S: Read + Write> {
304    session: ClientConnection,
305    stream: S,
306}
307
308impl<S: Read + Send + Write + 'static> MidHandshakeTlsStream<S> {
309    /// Get a reference to the inner stream
310    pub fn get_ref(&self) -> &S {
311        &self.stream
312    }
313
314    /// Get a mutable reference to the inner stream
315    pub fn get_mut(&mut self) -> &mut S {
316        &mut self.stream
317    }
318
319    /// Retry the handshake
320    ///
321    /// # Errors
322    ///
323    /// Returns a [`HandshakeError`] containing either the current state of the handshake or the
324    /// failure when we couldn't complete the hanshake
325    pub fn handshake(mut self) -> Result<TlsStream<S>, HandshakeError<S>> {
326        if let Err(e) = self.session.complete_io(&mut self.stream) {
327            if e.kind() == io::ErrorKind::WouldBlock {
328                if self.session.is_handshaking() {
329                    return Err(HandshakeError::WouldBlock(self));
330                }
331            } else {
332                return Err(e.into());
333            }
334        }
335        Ok(TlsStream::new(self.session, self.stream))
336    }
337}
338
339impl<S: Read + Write> fmt::Display for MidHandshakeTlsStream<S> {
340    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
341        f.write_str("MidHandshakeTlsStream")
342    }
343}
344
345/// An error returned while performing the handshake
346pub enum HandshakeError<S: Read + Write + Send + 'static> {
347    /// We hit WouldBlock during handshake.
348    /// Note that this is not a critical failure, you should be able to call handshake again once the stream is ready to perform I/O.
349    WouldBlock(MidHandshakeTlsStream<S>),
350    /// We hit a critical failure.
351    Failure(io::Error),
352}
353
354impl<S: Read + Write + Send + 'static> fmt::Display for HandshakeError<S> {
355    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
356        match self {
357            HandshakeError::WouldBlock(_) => f.write_str("WouldBlock hit during handshake"),
358            HandshakeError::Failure(err) => f.write_fmt(format_args!("IO error: {err}")),
359        }
360    }
361}
362
363impl<S: Read + Write + Send + 'static> fmt::Debug for HandshakeError<S> {
364    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
365        let mut d = f.debug_tuple("HandshakeError");
366        match self {
367            HandshakeError::WouldBlock(_) => d.field(&"WouldBlock"),
368            HandshakeError::Failure(err) => d.field(&err),
369        }
370        .finish()
371    }
372}
373
374impl<S: Read + Write + Send + 'static> Error for HandshakeError<S> {
375    fn source(&self) -> Option<&(dyn Error + 'static)> {
376        match self {
377            HandshakeError::Failure(err) => Some(err),
378            _ => None,
379        }
380    }
381}
382
383impl<S: Read + Send + Write + 'static> From<io::Error> for HandshakeError<S> {
384    fn from(err: io::Error) -> Self {
385        HandshakeError::Failure(err)
386    }
387}