rustls_connector/
lib.rs

1#![deny(missing_docs)]
2#![allow(clippy::result_large_err, clippy::large_enum_variant)]
3
4//! # Connector similar to openssl or native-tls for rustls
5//!
//! rustls-connector is a library that aims to simplify the use of rustls as
//! an alternative to openssl and native-tls
8//!
9//! # Examples
10//!
11//! To connect to a remote server:
12//!
13//! ```rust, no_run
14//! use rustls_connector::RustlsConnector;
15//!
16//! use std::{
17//!     io::{Read, Write},
18//!     net::TcpStream,
19//! };
20//!
21//! let connector = RustlsConnector::new_with_native_certs().unwrap();
22//! let stream = TcpStream::connect("google.com:443").unwrap();
23//! let mut stream = connector.connect("google.com", stream).unwrap();
24//!
25//! stream.write_all(b"GET / HTTP/1.0\r\n\r\n").unwrap();
26//! let mut res = vec![];
27//! stream.read_to_end(&mut res).unwrap();
28//! println!("{}", String::from_utf8_lossy(&res));
29//! ```
30
31pub use rustls;
32#[cfg(feature = "native-certs")]
33pub use rustls_native_certs;
34pub use rustls_pki_types;
35pub use webpki;
36#[cfg(feature = "webpki-roots-certs")]
37pub use webpki_roots;
38
39#[cfg(feature = "futures")]
40use futures_io::{AsyncRead, AsyncWrite};
41use rustls::{ClientConfig, ClientConnection, RootCertStore, StreamOwned};
42use rustls_pki_types::{CertificateDer, PrivateKeyDer, ServerName};
43
44use std::{
45    error::Error,
46    fmt,
47    io::{self, Read, Write},
48    sync::Arc,
49};
50
/// A TLS stream
///
/// Alias for rustls' [`StreamOwned`] driven by a client-side [`ClientConnection`] over the
/// transport `S`. This is the type returned by [`RustlsConnector::connect`] on success.
pub type TlsStream<S> = StreamOwned<ClientConnection, S>;
53
#[cfg(feature = "futures")]
/// An async TLS stream (requires futures feature enabled)
///
/// Alias for the client-side stream type of `futures_rustls` over the transport `S`. This is
/// the type returned by [`RustlsConnector::connect_async`] on success.
pub type AsyncTlsStream<S> = futures_rustls::client::TlsStream<S>;
57
58/// Configuration helper for [`RustlsConnector`]
59#[derive(Clone)]
60pub struct RustlsConnectorConfig(RootCertStore);
61
62impl RustlsConnectorConfig {
63    #[cfg(feature = "webpki-roots-certs")]
64    /// Create a new [`RustlsConnectorConfig`] using the webpki-roots certs (requires webpki-roots-certs feature enabled)
65    pub fn new_with_webpki_roots_certs() -> Self {
66        Self(RootCertStore {
67            roots: webpki_roots::TLS_SERVER_ROOTS.to_vec(),
68        })
69    }
70
71    #[cfg(feature = "native-certs")]
72    /// Create a new [`RustlsConnectorConfig`] using the system certs (requires native-certs feature enabled)
73    ///
74    /// # Errors
75    ///
76    /// Returns an error if we fail to load the native certs.
77    pub fn new_with_native_certs() -> io::Result<Self> {
78        let mut root_store = RootCertStore::empty();
79        let mut certs_result = rustls_native_certs::load_native_certs();
80        if let Some(err) = certs_result.errors.pop() {
81            return Err(io::Error::other(err));
82        }
83        for cert in certs_result.certs {
84            if let Err(err) = root_store.add(cert) {
85                log::warn!("Got error while importing some native certificates: {err:?}");
86            }
87        }
88        Ok(Self(root_store))
89    }
90
91    /// Parse the given DER-encoded certificates and add all that can be parsed in a best-effort fashion.
92    ///
93    /// This is because large collections of root certificates often include ancient or syntactically invalid certificates.
94    ///
95    /// Returns the number of certificates added, and the number that were ignored.
96    pub fn add_parsable_certificates(
97        &mut self,
98        der_certs: Vec<CertificateDer<'_>>,
99    ) -> (usize, usize) {
100        self.0.add_parsable_certificates(der_certs)
101    }
102
103    /// Create a new [`RustlsConnector`] from this config and no client certificate
104    pub fn connector_with_no_client_auth(self) -> RustlsConnector {
105        ClientConfig::builder()
106            .with_root_certificates(self.0)
107            .with_no_client_auth()
108            .into()
109    }
110
111    /// Create a new [`RustlsConnector`] from this config and the given client certificate
112    ///
113    /// cert_chain is a vector of DER-encoded certificates. key_der is a DER-encoded RSA, ECDSA, or
114    /// Ed25519 private key.
115    ///
116    /// This function fails if key_der is invalid.
117    pub fn connector_with_single_cert(
118        self,
119        cert_chain: Vec<CertificateDer<'static>>,
120        key_der: PrivateKeyDer<'static>,
121    ) -> io::Result<RustlsConnector> {
122        Ok(ClientConfig::builder()
123            .with_root_certificates(self.0)
124            .with_client_auth_cert(cert_chain, key_der)
125            .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?
126            .into())
127    }
128}
129
130impl Default for RustlsConnectorConfig {
131    fn default() -> Self {
132        Self(RootCertStore::empty())
133    }
134}
135
136/// The connector
137#[derive(Clone)]
138pub struct RustlsConnector(Arc<ClientConfig>);
139
140impl Default for RustlsConnector {
141    fn default() -> Self {
142        RustlsConnectorConfig::default().connector_with_no_client_auth()
143    }
144}
145
146impl From<ClientConfig> for RustlsConnector {
147    fn from(config: ClientConfig) -> Self {
148        Arc::new(config).into()
149    }
150}
151
152impl From<Arc<ClientConfig>> for RustlsConnector {
153    fn from(config: Arc<ClientConfig>) -> Self {
154        Self(config)
155    }
156}
157
158impl RustlsConnector {
159    #[cfg(feature = "webpki-roots-certs")]
160    /// Create a new RustlsConnector using the webpki-roots certs (requires webpki-roots-certs feature enabled)
161    pub fn new_with_webpki_roots_certs() -> Self {
162        RustlsConnectorConfig::new_with_webpki_roots_certs().connector_with_no_client_auth()
163    }
164
165    #[cfg(feature = "native-certs")]
166    /// Create a new [`RustlsConnector`] using the system certs (requires native-certs feature enabled)
167    ///
168    /// # Errors
169    ///
170    /// Returns an error if we fail to load the native certs.
171    pub fn new_with_native_certs() -> io::Result<Self> {
172        Ok(RustlsConnectorConfig::new_with_native_certs()?.connector_with_no_client_auth())
173    }
174
175    /// Connect to the given host
176    ///
177    /// # Errors
178    ///
179    /// Returns a [`HandshakeError`] containing either the current state of the handshake or the
180    /// failure when we couldn't complete the hanshake
181    pub fn connect<S: Read + Write + Send + 'static>(
182        &self,
183        domain: &str,
184        stream: S,
185    ) -> Result<TlsStream<S>, HandshakeError<S>> {
186        let session = ClientConnection::new(
187            self.0.clone(),
188            server_name(domain).map_err(HandshakeError::Failure)?,
189        )
190        .map_err(|err| io::Error::new(io::ErrorKind::ConnectionAborted, err))?;
191        MidHandshakeTlsStream { session, stream }.handshake()
192    }
193
194    #[cfg(feature = "futures")]
195    /// Connect to the given host asynchronously
196    ///
197    /// # Errors
198    ///
199    /// Returns a [`io::Error`] containing the failure when we couldn't complete the TLS hanshake
200    pub async fn connect_async<S: AsyncRead + AsyncWrite + Send + Unpin + 'static>(
201        &self,
202        domain: &str,
203        stream: S,
204    ) -> io::Result<AsyncTlsStream<S>> {
205        futures_rustls::TlsConnector::from(self.0.clone())
206            .connect(server_name(domain)?, stream)
207            .await
208    }
209}
210
211fn server_name(domain: &str) -> io::Result<ServerName<'static>> {
212    Ok(ServerName::try_from(domain)
213        .map_err(|err| {
214            io::Error::new(
215                io::ErrorKind::InvalidData,
216                format!("Invalid domain name ({err:?}): {domain}"),
217            )
218        })?
219        .to_owned())
220}
221
222/// A TLS stream which has been interrupted during the handshake
223#[derive(Debug)]
224pub struct MidHandshakeTlsStream<S: Read + Write> {
225    session: ClientConnection,
226    stream: S,
227}
228
229impl<S: Read + Send + Write + 'static> MidHandshakeTlsStream<S> {
230    /// Get a reference to the inner stream
231    pub fn get_ref(&self) -> &S {
232        &self.stream
233    }
234
235    /// Get a mutable reference to the inner stream
236    pub fn get_mut(&mut self) -> &mut S {
237        &mut self.stream
238    }
239
240    /// Retry the handshake
241    ///
242    /// # Errors
243    ///
244    /// Returns a [`HandshakeError`] containing either the current state of the handshake or the
245    /// failure when we couldn't complete the hanshake
246    pub fn handshake(mut self) -> Result<TlsStream<S>, HandshakeError<S>> {
247        if let Err(e) = self.session.complete_io(&mut self.stream) {
248            if e.kind() == io::ErrorKind::WouldBlock {
249                if self.session.is_handshaking() {
250                    return Err(HandshakeError::WouldBlock(self));
251                }
252            } else {
253                return Err(e.into());
254            }
255        }
256        Ok(TlsStream::new(self.session, self.stream))
257    }
258}
259
260impl<S: Read + Write> fmt::Display for MidHandshakeTlsStream<S> {
261    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
262        f.write_str("MidHandshakeTlsStream")
263    }
264}
265
266/// An error returned while performing the handshake
267pub enum HandshakeError<S: Read + Write + Send + 'static> {
268    /// We hit WouldBlock during handshake.
269    /// Note that this is not a critical failure, you should be able to call handshake again once the stream is ready to perform I/O.
270    WouldBlock(MidHandshakeTlsStream<S>),
271    /// We hit a critical failure.
272    Failure(io::Error),
273}
274
275impl<S: Read + Write + Send + 'static> fmt::Display for HandshakeError<S> {
276    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
277        match self {
278            HandshakeError::WouldBlock(_) => f.write_str("WouldBlock hit during handshake"),
279            HandshakeError::Failure(err) => f.write_fmt(format_args!("IO error: {err}")),
280        }
281    }
282}
283
284impl<S: Read + Write + Send + 'static> fmt::Debug for HandshakeError<S> {
285    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
286        let mut d = f.debug_tuple("HandshakeError");
287        match self {
288            HandshakeError::WouldBlock(_) => d.field(&"WouldBlock"),
289            HandshakeError::Failure(err) => d.field(&err),
290        }
291        .finish()
292    }
293}
294
295impl<S: Read + Write + Send + 'static> Error for HandshakeError<S> {
296    fn source(&self) -> Option<&(dyn Error + 'static)> {
297        match self {
298            HandshakeError::Failure(err) => Some(err),
299            _ => None,
300        }
301    }
302}
303
304impl<S: Read + Send + Write + 'static> From<io::Error> for HandshakeError<S> {
305    fn from(err: io::Error) -> Self {
306        HandshakeError::Failure(err)
307    }
308}