rustls_connector/
lib.rs

1#![deny(missing_docs)]
2
3//! # Connector similar to openssl or native-tls for rustls
4//!
//! rustls-connector is a library that simplifies using rustls as
//! an alternative to openssl and native-tls.
7//!
8//! # Examples
9//!
10//! To connect to a remote server:
11//!
12//! ```rust, no_run
13//! use rustls_connector::RustlsConnector;
14//!
15//! use std::{
16//!     io::{Read, Write},
17//!     net::TcpStream,
18//! };
19//!
20//! let connector = RustlsConnector::new_with_native_certs().unwrap();
21//! let stream = TcpStream::connect("google.com:443").unwrap();
22//! let mut stream = connector.connect("google.com", stream).unwrap();
23//!
24//! stream.write_all(b"GET / HTTP/1.0\r\n\r\n").unwrap();
25//! let mut res = vec![];
26//! stream.read_to_end(&mut res).unwrap();
27//! println!("{}", String::from_utf8_lossy(&res));
28//! ```
29
30pub use rustls;
31#[cfg(feature = "native-certs")]
32pub use rustls_native_certs;
33pub use rustls_pki_types;
34pub use webpki;
35#[cfg(feature = "webpki-roots-certs")]
36pub use webpki_roots;
37
38#[cfg(feature = "futures")]
39use futures_io::{AsyncRead, AsyncWrite};
40use rustls::{ClientConfig, ClientConnection, RootCertStore, StreamOwned};
41use rustls_pki_types::{CertificateDer, PrivateKeyDer, ServerName};
42
43use std::{
44    error::Error,
45    fmt,
46    io::{self, Read, Write},
47    sync::Arc,
48};
49
50/// A TLS stream
51pub type TlsStream<S> = StreamOwned<ClientConnection, S>;
52
53#[cfg(feature = "futures")]
54/// An async TLS stream
55pub type AsyncTlsStream<S> = futures_rustls::client::TlsStream<S>;
56
57/// Configuration helper for [`RustlsConnector`]
58#[derive(Clone)]
59pub struct RustlsConnectorConfig(RootCertStore);
60
61impl RustlsConnectorConfig {
62    #[cfg(feature = "webpki-roots-certs")]
63    /// Create a new [`RustlsConnectorConfig`] using the webpki-roots certs (requires webpki-roots-certs feature enabled)
64    pub fn new_with_webpki_roots_certs() -> Self {
65        Self(RootCertStore {
66            roots: webpki_roots::TLS_SERVER_ROOTS.to_vec(),
67        })
68    }
69
70    #[cfg(feature = "native-certs")]
71    /// Create a new [`RustlsConnectorConfig`] using the system certs (requires native-certs feature enabled)
72    ///
73    /// # Errors
74    ///
75    /// Returns an error if we fail to load the native certs.
76    pub fn new_with_native_certs() -> io::Result<Self> {
77        let mut root_store = RootCertStore::empty();
78        let mut certs_result = rustls_native_certs::load_native_certs();
79        if let Some(err) = certs_result.errors.pop() {
80            return Err(io::Error::other(err));
81        }
82        for cert in certs_result.certs {
83            if let Err(err) = root_store.add(cert) {
84                log::warn!("Got error while importing some native certificates: {err:?}");
85            }
86        }
87        Ok(Self(root_store))
88    }
89
90    /// Parse the given DER-encoded certificates and add all that can be parsed in a best-effort fashion.
91    ///
92    /// This is because large collections of root certificates often include ancient or syntactically invalid certificates.
93    ///
94    /// Returns the number of certificates added, and the number that were ignored.
95    pub fn add_parsable_certificates(
96        &mut self,
97        der_certs: Vec<CertificateDer<'_>>,
98    ) -> (usize, usize) {
99        self.0.add_parsable_certificates(der_certs)
100    }
101
102    /// Create a new [`RustlsConnector`] from this config and no client certificate
103    pub fn connector_with_no_client_auth(self) -> RustlsConnector {
104        ClientConfig::builder()
105            .with_root_certificates(self.0)
106            .with_no_client_auth()
107            .into()
108    }
109
110    /// Create a new [`RustlsConnector`] from this config and the given client certificate
111    ///
112    /// cert_chain is a vector of DER-encoded certificates. key_der is a DER-encoded RSA, ECDSA, or
113    /// Ed25519 private key.
114    ///
115    /// This function fails if key_der is invalid.
116    pub fn connector_with_single_cert(
117        self,
118        cert_chain: Vec<CertificateDer<'static>>,
119        key_der: PrivateKeyDer<'static>,
120    ) -> io::Result<RustlsConnector> {
121        Ok(ClientConfig::builder()
122            .with_root_certificates(self.0)
123            .with_client_auth_cert(cert_chain, key_der)
124            .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?
125            .into())
126    }
127}
128
129impl Default for RustlsConnectorConfig {
130    fn default() -> Self {
131        Self(RootCertStore::empty())
132    }
133}
134
135/// The connector
136#[derive(Clone)]
137pub struct RustlsConnector(Arc<ClientConfig>);
138
139impl Default for RustlsConnector {
140    fn default() -> Self {
141        RustlsConnectorConfig::default().connector_with_no_client_auth()
142    }
143}
144
145impl From<ClientConfig> for RustlsConnector {
146    fn from(config: ClientConfig) -> Self {
147        Arc::new(config).into()
148    }
149}
150
151impl From<Arc<ClientConfig>> for RustlsConnector {
152    fn from(config: Arc<ClientConfig>) -> Self {
153        Self(config)
154    }
155}
156
157impl RustlsConnector {
158    #[cfg(feature = "webpki-roots-certs")]
159    /// Create a new RustlsConnector using the webpki-roots certs (requires webpki-roots-certs feature enabled)
160    pub fn new_with_webpki_roots_certs() -> Self {
161        RustlsConnectorConfig::new_with_webpki_roots_certs().connector_with_no_client_auth()
162    }
163
164    #[cfg(feature = "native-certs")]
165    /// Create a new [`RustlsConnector`] using the system certs (requires native-certs feature enabled)
166    ///
167    /// # Errors
168    ///
169    /// Returns an error if we fail to load the native certs.
170    pub fn new_with_native_certs() -> io::Result<Self> {
171        Ok(RustlsConnectorConfig::new_with_native_certs()?.connector_with_no_client_auth())
172    }
173
174    /// Connect to the given host
175    ///
176    /// # Errors
177    ///
178    /// Returns a [`HandshakeError`] containing either the current state of the handshake or the
179    /// failure when we couldn't complete the hanshake
180    pub fn connect<S: Read + Write + Send + 'static>(
181        &self,
182        domain: &str,
183        stream: S,
184    ) -> Result<TlsStream<S>, HandshakeError<S>> {
185        let session = ClientConnection::new(
186            self.0.clone(),
187            server_name(domain).map_err(HandshakeError::Failure)?,
188        )
189        .map_err(|err| io::Error::new(io::ErrorKind::ConnectionAborted, err))?;
190        MidHandshakeTlsStream { session, stream }.handshake()
191    }
192
193    #[cfg(feature = "futures")]
194    /// Connect to the given host asynchronously
195    ///
196    /// # Errors
197    ///
198    /// Returns a [`io::Error`] containing the failure when we couldn't complete the TLS hanshake
199    pub async fn connect_async<S: AsyncRead + AsyncWrite + Send + Unpin + 'static>(
200        &self,
201        domain: &str,
202        stream: S,
203    ) -> io::Result<AsyncTlsStream<S>> {
204        futures_rustls::TlsConnector::from(self.0.clone())
205            .connect(server_name(domain)?, stream)
206            .await
207    }
208}
209
210fn server_name(domain: &str) -> io::Result<ServerName<'static>> {
211    Ok(ServerName::try_from(domain)
212        .map_err(|err| {
213            io::Error::new(
214                io::ErrorKind::InvalidData,
215                format!("Invalid domain name ({err:?}): {domain}"),
216            )
217        })?
218        .to_owned())
219}
220
221/// A TLS stream which has been interrupted during the handshake
222#[derive(Debug)]
223pub struct MidHandshakeTlsStream<S: Read + Write> {
224    session: ClientConnection,
225    stream: S,
226}
227
228impl<S: Read + Send + Write + 'static> MidHandshakeTlsStream<S> {
229    /// Get a reference to the inner stream
230    pub fn get_ref(&self) -> &S {
231        &self.stream
232    }
233
234    /// Get a mutable reference to the inner stream
235    pub fn get_mut(&mut self) -> &mut S {
236        &mut self.stream
237    }
238
239    /// Retry the handshake
240    ///
241    /// # Errors
242    ///
243    /// Returns a [`HandshakeError`] containing either the current state of the handshake or the
244    /// failure when we couldn't complete the hanshake
245    pub fn handshake(mut self) -> Result<TlsStream<S>, HandshakeError<S>> {
246        if let Err(e) = self.session.complete_io(&mut self.stream) {
247            if e.kind() == io::ErrorKind::WouldBlock {
248                if self.session.is_handshaking() {
249                    return Err(HandshakeError::WouldBlock(Box::new(self)));
250                }
251            } else {
252                return Err(e.into());
253            }
254        }
255        Ok(TlsStream::new(self.session, self.stream))
256    }
257}
258
259impl<S: Read + Write> fmt::Display for MidHandshakeTlsStream<S> {
260    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
261        f.write_str("MidHandshakeTlsStream")
262    }
263}
264
265/// An error returned while performing the handshake
266pub enum HandshakeError<S: Read + Write + Send + 'static> {
267    /// We hit WouldBlock during handshake.
268    /// Note that this is not a critical failure, you should be able to call handshake again once the stream is ready to perform I/O.
269    WouldBlock(Box<MidHandshakeTlsStream<S>>),
270    /// We hit a critical failure.
271    Failure(io::Error),
272}
273
274impl<S: Read + Write + Send + 'static> fmt::Display for HandshakeError<S> {
275    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
276        match self {
277            HandshakeError::WouldBlock(_) => f.write_str("WouldBlock hit during handshake"),
278            HandshakeError::Failure(err) => f.write_fmt(format_args!("IO error: {err}")),
279        }
280    }
281}
282
283impl<S: Read + Write + Send + 'static> fmt::Debug for HandshakeError<S> {
284    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
285        let mut d = f.debug_tuple("HandshakeError");
286        match self {
287            HandshakeError::WouldBlock(_) => d.field(&"WouldBlock"),
288            HandshakeError::Failure(err) => d.field(&err),
289        }
290        .finish()
291    }
292}
293
294impl<S: Read + Write + Send + 'static> Error for HandshakeError<S> {
295    fn source(&self) -> Option<&(dyn Error + 'static)> {
296        match self {
297            HandshakeError::Failure(err) => Some(err),
298            _ => None,
299        }
300    }
301}
302
303impl<S: Read + Send + Write + 'static> From<io::Error> for HandshakeError<S> {
304    fn from(err: io::Error) -> Self {
305        HandshakeError::Failure(err)
306    }
307}