Skip to main content

rustls_connector/
lib.rs

1#![deny(missing_docs)]
2#![allow(clippy::result_large_err, clippy::large_enum_variant)]
3
4//! # Connector similar to openssl or native-tls for rustls
5//!
//! rustls-connector is a library that aims to simplify using rustls as
//! an alternative to openssl and native-tls.
8//!
9//! # Examples
10//!
11//! To connect to a remote server:
12//!
13//! ```rust, no_run
14//! use rustls_connector::RustlsConnector;
15//!
16//! use std::{
17//!     io::{Read, Write},
18//!     net::TcpStream,
19//! };
20//!
21//! let connector = RustlsConnector::new_with_native_certs().unwrap();
22//! let stream = TcpStream::connect("google.com:443").unwrap();
23//! let mut stream = connector.connect("google.com", stream).unwrap();
24//!
25//! stream.write_all(b"GET / HTTP/1.0\r\n\r\n").unwrap();
26//! let mut res = vec![];
27//! stream.read_to_end(&mut res).unwrap();
28//! println!("{}", String::from_utf8_lossy(&res));
29//! ```
30
31pub use rustls;
32#[cfg(feature = "native-certs")]
33pub use rustls_native_certs;
34pub use rustls_pki_types;
35pub use webpki;
36#[cfg(feature = "webpki-roots-certs")]
37pub use webpki_root_certs;
38
39#[cfg(feature = "futures")]
40use futures_io::{AsyncRead, AsyncWrite};
41use rustls::{
42    ClientConfig, ClientConnection, ConfigBuilder, RootCertStore, StreamOwned,
43    client::WantsClientCert,
44};
45use rustls_pki_types::{CertificateDer, PrivateKeyDer, ServerName};
46
47use std::{
48    error::Error,
49    fmt,
50    io::{self, Read, Write},
51    sync::Arc,
52};
53
/// A TLS stream
///
/// Synchronous client stream returned by [`RustlsConnector::connect`] and
/// [`MidHandshakeTlsStream::handshake`]. It is rustls' [`StreamOwned`] around a
/// [`ClientConnection`], so it owns both the TLS session and the underlying transport `S`.
pub type TlsStream<S> = StreamOwned<ClientConnection, S>;
56
#[cfg(feature = "futures")]
/// An async TLS stream
///
/// Returned by [`RustlsConnector::connect_async`] (requires the `futures` feature); this is an
/// alias for `futures_rustls::client::TlsStream` over the transport `S`.
pub type AsyncTlsStream<S> = futures_rustls::client::TlsStream<S>;
60
61/// Configuration helper for [`RustlsConnector`]
62#[derive(Clone)]
63pub struct RustlsConnectorConfig(RootCertStore);
64
65impl RustlsConnectorConfig {
66    #[cfg(feature = "webpki-roots-certs")]
67    /// Create a new [`RustlsConnectorConfig`] using the webpki-roots certs (requires webpki-roots-certs feature enabled)
68    pub fn new_with_webpki_roots_certs() -> Self {
69        Self::default().with_webpki_root_certs()
70    }
71
72    #[cfg(feature = "native-certs")]
73    /// Create a new [`RustlsConnectorConfig`] using the system certs (requires native-certs feature enabled)
74    ///
75    /// # Errors
76    ///
77    /// Returns an error if we fail to load the native certs.
78    pub fn new_with_native_certs() -> io::Result<Self> {
79        let mut config = Self::default();
80        let (_, ignored) = config.register_native_certs()?;
81        if ignored > 0 {
82            log::warn!("{ignored} platform CA root certificates were ignored due to errors");
83        }
84        Ok(config)
85    }
86
87    /// Parse the given DER-encoded certificates and add all that can be parsed in a best-effort fashion.
88    ///
89    /// This is because large collections of root certificates often include ancient or syntactically invalid certificates.
90    ///
91    /// Returns the number of certificates added, and the number that were ignored.
92    pub fn add_parsable_certificates<'a>(
93        &mut self,
94        der_certs: impl IntoIterator<Item = CertificateDer<'a>>,
95    ) -> (usize, usize) {
96        self.0.add_parsable_certificates(der_certs)
97    }
98
99    #[cfg(feature = "webpki-roots-certs")]
100    /// Add certs from webpki-root-certs (requires webpki-roots-certs feature enabled)
101    pub fn with_webpki_root_certs(mut self) -> Self {
102        self.add_parsable_certificates(webpki_root_certs::TLS_SERVER_ROOT_CERTS.iter().cloned());
103        self
104    }
105
106    #[cfg(feature = "native-certs")]
107    /// Add the system certs (requires native-certs feature enabled)
108    ///
109    /// # Errors
110    ///
111    /// Returns an error if we fail to load the native certs.
112    pub fn register_native_certs(&mut self) -> io::Result<(usize, usize)> {
113        let certs_result = rustls_native_certs::load_native_certs();
114        for err in certs_result.errors {
115            log::warn!("Got error while loading some native certificates: {err:?}");
116        }
117        let (added, ignored) = self.add_parsable_certificates(certs_result.certs);
118        if self.0.is_empty() {
119            return Err(io::Error::other(
120                "Could not load any valid native certificates",
121            ));
122        }
123        Ok((added, ignored))
124    }
125
126    fn builder(self) -> ConfigBuilder<ClientConfig, WantsClientCert> {
127        ClientConfig::builder().with_root_certificates(self.0)
128    }
129
130    /// Create a new [`RustlsConnector`] from this config and no client certificate
131    pub fn connector_with_no_client_auth(self) -> RustlsConnector {
132        self.builder().with_no_client_auth().into()
133    }
134
135    /// Create a new [`RustlsConnector`] from this config and the given client certificate
136    ///
137    /// cert_chain is a vector of DER-encoded certificates. key_der is a DER-encoded RSA, ECDSA, or
138    /// Ed25519 private key.
139    ///
140    /// This function fails if key_der is invalid.
141    pub fn connector_with_single_cert(
142        self,
143        cert_chain: Vec<CertificateDer<'static>>,
144        key_der: PrivateKeyDer<'static>,
145    ) -> io::Result<RustlsConnector> {
146        Ok(self
147            .builder()
148            .with_client_auth_cert(cert_chain, key_der)
149            .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?
150            .into())
151    }
152}
153
154impl Default for RustlsConnectorConfig {
155    fn default() -> Self {
156        Self(RootCertStore::empty())
157    }
158}
159
160/// The connector
161#[derive(Clone)]
162pub struct RustlsConnector(Arc<ClientConfig>);
163
164impl Default for RustlsConnector {
165    fn default() -> Self {
166        RustlsConnectorConfig::default().connector_with_no_client_auth()
167    }
168}
169
170impl From<ClientConfig> for RustlsConnector {
171    fn from(config: ClientConfig) -> Self {
172        Arc::new(config).into()
173    }
174}
175
176impl From<Arc<ClientConfig>> for RustlsConnector {
177    fn from(config: Arc<ClientConfig>) -> Self {
178        Self(config)
179    }
180}
181
182impl RustlsConnector {
183    #[cfg(feature = "webpki-roots-certs")]
184    /// Create a new RustlsConnector using the webpki-roots certs (requires webpki-roots-certs feature enabled)
185    pub fn new_with_webpki_roots_certs() -> Self {
186        RustlsConnectorConfig::new_with_webpki_roots_certs().connector_with_no_client_auth()
187    }
188
189    #[cfg(feature = "native-certs")]
190    /// Create a new [`RustlsConnector`] using the system certs (requires native-certs feature enabled)
191    ///
192    /// # Errors
193    ///
194    /// Returns an error if we fail to load the native certs.
195    pub fn new_with_native_certs() -> io::Result<Self> {
196        Ok(RustlsConnectorConfig::new_with_native_certs()?.connector_with_no_client_auth())
197    }
198
199    /// Connect to the given host
200    ///
201    /// # Errors
202    ///
203    /// Returns a [`HandshakeError`] containing either the current state of the handshake or the
204    /// failure when we couldn't complete the hanshake
205    pub fn connect<S: Read + Write + Send + 'static>(
206        &self,
207        domain: &str,
208        stream: S,
209    ) -> Result<TlsStream<S>, HandshakeError<S>> {
210        let session = ClientConnection::new(
211            self.0.clone(),
212            server_name(domain).map_err(HandshakeError::Failure)?,
213        )
214        .map_err(|err| io::Error::new(io::ErrorKind::ConnectionAborted, err))?;
215        MidHandshakeTlsStream { session, stream }.handshake()
216    }
217
218    #[cfg(feature = "futures")]
219    /// Connect to the given host asynchronously
220    ///
221    /// # Errors
222    ///
223    /// Returns a [`io::Error`] containing the failure when we couldn't complete the TLS hanshake
224    pub async fn connect_async<S: AsyncRead + AsyncWrite + Send + Unpin + 'static>(
225        &self,
226        domain: &str,
227        stream: S,
228    ) -> io::Result<AsyncTlsStream<S>> {
229        futures_rustls::TlsConnector::from(self.0.clone())
230            .connect(server_name(domain)?, stream)
231            .await
232    }
233}
234
235fn server_name(domain: &str) -> io::Result<ServerName<'static>> {
236    Ok(ServerName::try_from(domain)
237        .map_err(|err| {
238            io::Error::new(
239                io::ErrorKind::InvalidData,
240                format!("Invalid domain name ({err:?}): {domain}"),
241            )
242        })?
243        .to_owned())
244}
245
/// A TLS stream which has been interrupted during the handshake
///
/// Handed back inside [`HandshakeError::WouldBlock`] when the underlying stream could not make
/// progress. Call [`MidHandshakeTlsStream::handshake`] again once the stream is ready for I/O.
#[derive(Debug)]
pub struct MidHandshakeTlsStream<S: Read + Write> {
    // rustls client session whose handshake is being driven.
    session: ClientConnection,
    // Underlying transport that the session reads from and writes to.
    stream: S,
}
252
253impl<S: Read + Send + Write + 'static> MidHandshakeTlsStream<S> {
254    /// Get a reference to the inner stream
255    pub fn get_ref(&self) -> &S {
256        &self.stream
257    }
258
259    /// Get a mutable reference to the inner stream
260    pub fn get_mut(&mut self) -> &mut S {
261        &mut self.stream
262    }
263
264    /// Retry the handshake
265    ///
266    /// # Errors
267    ///
268    /// Returns a [`HandshakeError`] containing either the current state of the handshake or the
269    /// failure when we couldn't complete the hanshake
270    pub fn handshake(mut self) -> Result<TlsStream<S>, HandshakeError<S>> {
271        if let Err(e) = self.session.complete_io(&mut self.stream) {
272            if e.kind() == io::ErrorKind::WouldBlock {
273                if self.session.is_handshaking() {
274                    return Err(HandshakeError::WouldBlock(self));
275                }
276            } else {
277                return Err(e.into());
278            }
279        }
280        Ok(TlsStream::new(self.session, self.stream))
281    }
282}
283
284impl<S: Read + Write> fmt::Display for MidHandshakeTlsStream<S> {
285    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
286        f.write_str("MidHandshakeTlsStream")
287    }
288}
289
/// An error returned while performing the handshake
// NOTE: `WouldBlock` is far larger than `Failure` because it owns the whole session. The
// crate-level `clippy::large_enum_variant` allow covers this, since boxing it would change
// the public API.
pub enum HandshakeError<S: Read + Write + Send + 'static> {
    /// We hit WouldBlock during handshake.
    /// Note that this is not a critical failure, you should be able to call handshake again once the stream is ready to perform I/O.
    ///
    /// Retry with [`MidHandshakeTlsStream::handshake`] on the contained value.
    WouldBlock(MidHandshakeTlsStream<S>),
    /// We hit a critical failure.
    ///
    /// This carries the [`io::Error`] reported while driving the handshake. Invalid domain
    /// names and session-creation failures in [`RustlsConnector::connect`] are reported here too.
    Failure(io::Error),
}
298
299impl<S: Read + Write + Send + 'static> fmt::Display for HandshakeError<S> {
300    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
301        match self {
302            HandshakeError::WouldBlock(_) => f.write_str("WouldBlock hit during handshake"),
303            HandshakeError::Failure(err) => f.write_fmt(format_args!("IO error: {err}")),
304        }
305    }
306}
307
308impl<S: Read + Write + Send + 'static> fmt::Debug for HandshakeError<S> {
309    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
310        let mut d = f.debug_tuple("HandshakeError");
311        match self {
312            HandshakeError::WouldBlock(_) => d.field(&"WouldBlock"),
313            HandshakeError::Failure(err) => d.field(&err),
314        }
315        .finish()
316    }
317}
318
319impl<S: Read + Write + Send + 'static> Error for HandshakeError<S> {
320    fn source(&self) -> Option<&(dyn Error + 'static)> {
321        match self {
322            HandshakeError::Failure(err) => Some(err),
323            _ => None,
324        }
325    }
326}
327
328impl<S: Read + Send + Write + 'static> From<io::Error> for HandshakeError<S> {
329    fn from(err: io::Error) -> Self {
330        HandshakeError::Failure(err)
331    }
332}