rustls_connector/
lib.rs

1#![deny(missing_docs)]
2
3//! # Connector similar to openssl or native-tls for rustls
4//!
//! rustls-connector is a library that makes it simpler to use rustls as
//! an alternative to openssl and native-tls
7//!
8//! # Examples
9//!
10//! To connect to a remote server:
11//!
12//! ```rust, no_run
13//! use rustls_connector::RustlsConnector;
14//!
15//! use std::{
16//!     io::{Read, Write},
17//!     net::TcpStream,
18//! };
19//!
20//! let connector = RustlsConnector::new_with_native_certs().unwrap();
21//! let stream = TcpStream::connect("google.com:443").unwrap();
22//! let mut stream = connector.connect("google.com", stream).unwrap();
23//!
24//! stream.write_all(b"GET / HTTP/1.0\r\n\r\n").unwrap();
25//! let mut res = vec![];
26//! stream.read_to_end(&mut res).unwrap();
27//! println!("{}", String::from_utf8_lossy(&res));
28//! ```
29
30pub use rustls;
31#[cfg(feature = "native-certs")]
32pub use rustls_native_certs;
33pub use rustls_pki_types;
34pub use webpki;
35#[cfg(feature = "webpki-roots-certs")]
36pub use webpki_roots;
37
38#[cfg(feature = "futures")]
39use futures_io::{AsyncRead, AsyncWrite};
40use rustls::{ClientConfig, ClientConnection, RootCertStore, StreamOwned};
41use rustls_pki_types::{CertificateDer, PrivateKeyDer, ServerName};
42
43use std::{
44    error::Error,
45    fmt::{self, Debug},
46    io::{self, Read, Write},
47    sync::Arc,
48};
49
/// A TLS stream
///
/// This is a rustls [`StreamOwned`] pairing a [`ClientConnection`] with the underlying
/// transport `S`.
pub type TlsStream<S> = StreamOwned<ClientConnection, S>;
52
#[cfg(feature = "futures")]
/// An async TLS stream (requires futures feature enabled)
///
/// This is the client-side stream type provided by `futures_rustls`, wrapping the underlying
/// transport `S`.
pub type AsyncTlsStream<S> = futures_rustls::client::TlsStream<S>;
56
/// Configuration helper for [`RustlsConnector`]
///
/// Wraps the [`RootCertStore`] holding the root certificates which are handed over to the
/// [`ClientConfig`] of the connector built from this config.
#[derive(Clone)]
pub struct RustlsConnectorConfig(RootCertStore);
60
61impl RustlsConnectorConfig {
62    #[cfg(feature = "webpki-roots-certs")]
63    /// Create a new [`RustlsConnectorConfig`] using the webpki-roots certs (requires webpki-roots-certs feature enabled)
64    pub fn new_with_webpki_roots_certs() -> Self {
65        Self(RootCertStore {
66            roots: webpki_roots::TLS_SERVER_ROOTS.to_vec(),
67        })
68    }
69
70    #[cfg(feature = "native-certs")]
71    /// Create a new [`RustlsConnectorConfig`] using the system certs (requires native-certs feature enabled)
72    ///
73    /// # Errors
74    ///
75    /// Returns an error if we fail to load the native certs.
76    pub fn new_with_native_certs() -> io::Result<Self> {
77        let mut root_store = RootCertStore::empty();
78        let mut certs_result = rustls_native_certs::load_native_certs();
79        if let Some(err) = certs_result.errors.pop() {
80            return Err(io::Error::other(err));
81        }
82        for cert in certs_result.certs {
83            if let Err(err) = root_store.add(cert) {
84                log::warn!("Got error while importing some native certificates: {err:?}");
85            }
86        }
87        Ok(Self(root_store))
88    }
89
90    /// Parse the given DER-encoded certificates and add all that can be parsed in a best-effort fashion.
91    ///
92    /// This is because large collections of root certificates often include ancient or syntactically invalid certificates.
93    ///
94    /// Returns the number of certificates added, and the number that were ignored.
95    pub fn add_parsable_certificates(
96        &mut self,
97        der_certs: Vec<CertificateDer<'_>>,
98    ) -> (usize, usize) {
99        self.0.add_parsable_certificates(der_certs)
100    }
101
102    /// Create a new [`RustlsConnector`] from this config and no client certificate
103    pub fn connector_with_no_client_auth(self) -> RustlsConnector {
104        ClientConfig::builder()
105            .with_root_certificates(self.0)
106            .with_no_client_auth()
107            .into()
108    }
109
110    /// Create a new [`RustlsConnector`] from this config and the given client certificate
111    ///
112    /// cert_chain is a vector of DER-encoded certificates. key_der is a DER-encoded RSA, ECDSA, or
113    /// Ed25519 private key.
114    ///
115    /// This function fails if key_der is invalid.
116    pub fn connector_with_single_cert(
117        self,
118        cert_chain: Vec<CertificateDer<'static>>,
119        key_der: PrivateKeyDer<'static>,
120    ) -> io::Result<RustlsConnector> {
121        Ok(ClientConfig::builder()
122            .with_root_certificates(self.0)
123            .with_client_auth_cert(cert_chain, key_der)
124            .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?
125            .into())
126    }
127}
128
129impl Default for RustlsConnectorConfig {
130    fn default() -> Self {
131        Self(RootCertStore::empty())
132    }
133}
134
/// The connector
///
/// Wraps a [`ClientConfig`] behind an [`Arc`], so cloning the connector only clones the
/// [`Arc`] and not the configuration itself.
#[derive(Clone)]
pub struct RustlsConnector(Arc<ClientConfig>);
138
139impl Default for RustlsConnector {
140    fn default() -> Self {
141        RustlsConnectorConfig::default().connector_with_no_client_auth()
142    }
143}
144
145impl From<ClientConfig> for RustlsConnector {
146    fn from(config: ClientConfig) -> Self {
147        Arc::new(config).into()
148    }
149}
150
151impl From<Arc<ClientConfig>> for RustlsConnector {
152    fn from(config: Arc<ClientConfig>) -> Self {
153        Self(config)
154    }
155}
156
157impl RustlsConnector {
158    #[cfg(feature = "webpki-roots-certs")]
159    /// Create a new RustlsConnector using the webpki-roots certs (requires webpki-roots-certs feature enabled)
160    pub fn new_with_webpki_roots_certs() -> Self {
161        RustlsConnectorConfig::new_with_webpki_roots_certs().connector_with_no_client_auth()
162    }
163
164    #[cfg(feature = "native-certs")]
165    /// Create a new [`RustlsConnector`] using the system certs (requires native-certs feature enabled)
166    ///
167    /// # Errors
168    ///
169    /// Returns an error if we fail to load the native certs.
170    pub fn new_with_native_certs() -> io::Result<Self> {
171        Ok(RustlsConnectorConfig::new_with_native_certs()?.connector_with_no_client_auth())
172    }
173
174    /// Connect to the given host
175    ///
176    /// # Errors
177    ///
178    /// Returns a [`HandshakeError`] containing either the current state of the handshake or the
179    /// failure when we couldn't complete the hanshake
180    pub fn connect<S: Debug + Read + Send + Sync + Write + 'static>(
181        &self,
182        domain: &str,
183        stream: S,
184    ) -> Result<TlsStream<S>, HandshakeError<S>> {
185        let session = ClientConnection::new(
186            self.0.clone(),
187            server_name(domain).map_err(HandshakeError::Failure)?,
188        )
189        .map_err(|err| io::Error::new(io::ErrorKind::ConnectionAborted, err))?;
190        MidHandshakeTlsStream { session, stream }.handshake()
191    }
192
193    #[cfg(feature = "futures")]
194    /// Connect to the given host asynchronously
195    ///
196    /// # Errors
197    ///
198    /// Returns a [`io::Error`] containing the failure when we couldn't complete the TLS hanshake
199    pub async fn connect_async<
200        S: Debug + AsyncRead + AsyncWrite + Send + Sync + Unpin + 'static,
201    >(
202        &self,
203        domain: &str,
204        stream: S,
205    ) -> io::Result<AsyncTlsStream<S>> {
206        futures_rustls::TlsConnector::from(self.0.clone())
207            .connect(server_name(domain)?, stream)
208            .await
209    }
210}
211
212fn server_name(domain: &str) -> io::Result<ServerName<'static>> {
213    Ok(ServerName::try_from(domain)
214        .map_err(|err| {
215            io::Error::new(
216                io::ErrorKind::InvalidData,
217                format!("Invalid domain name ({err:?}): {domain}"),
218            )
219        })?
220        .to_owned())
221}
222
/// A TLS stream which has been interrupted during the handshake
///
/// Call [`MidHandshakeTlsStream::handshake`] to resume the handshake.
#[derive(Debug)]
pub struct MidHandshakeTlsStream<S: Read + Write> {
    // The rustls client session driving the handshake
    session: ClientConnection,
    // The underlying transport the session performs its I/O on
    stream: S,
}
229
230impl<S: Debug + Read + Send + Sync + Write + 'static> MidHandshakeTlsStream<S> {
231    /// Get a reference to the inner stream
232    pub fn get_ref(&self) -> &S {
233        &self.stream
234    }
235
236    /// Get a mutable reference to the inner stream
237    pub fn get_mut(&mut self) -> &mut S {
238        &mut self.stream
239    }
240
241    /// Retry the handshake
242    ///
243    /// # Errors
244    ///
245    /// Returns a [`HandshakeError`] containing either the current state of the handshake or the
246    /// failure when we couldn't complete the hanshake
247    pub fn handshake(mut self) -> Result<TlsStream<S>, HandshakeError<S>> {
248        if let Err(e) = self.session.complete_io(&mut self.stream) {
249            if e.kind() == io::ErrorKind::WouldBlock {
250                if self.session.is_handshaking() {
251                    return Err(HandshakeError::WouldBlock(Box::new(self)));
252                }
253            } else {
254                return Err(e.into());
255            }
256        }
257        Ok(TlsStream::new(self.session, self.stream))
258    }
259}
260
261impl<S: Read + Write> fmt::Display for MidHandshakeTlsStream<S> {
262    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
263        f.write_str("MidHandshakeTlsStream")
264    }
265}
266
/// An error returned while performing the handshake
#[derive(Debug)]
pub enum HandshakeError<S: Read + Send + Sync + Write + 'static> {
    /// We hit WouldBlock during handshake.
    ///
    /// Note that this is not a critical failure, you should be able to call
    /// [`MidHandshakeTlsStream::handshake`] again once the stream is ready to perform I/O.
    WouldBlock(Box<MidHandshakeTlsStream<S>>),
    /// We hit a critical failure.
    Failure(io::Error),
}
276
277impl<S: Debug + Read + Send + Sync + Write + 'static> fmt::Display for HandshakeError<S> {
278    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
279        match self {
280            HandshakeError::WouldBlock(_) => f.write_str("WouldBlock hit during handshake"),
281            HandshakeError::Failure(err) => f.write_fmt(format_args!("IO error: {err}")),
282        }
283    }
284}
285
286impl<S: Debug + Read + Send + Sync + Write + 'static> Error for HandshakeError<S> {
287    fn source(&self) -> Option<&(dyn Error + 'static)> {
288        match self {
289            HandshakeError::Failure(err) => Some(err),
290            _ => None,
291        }
292    }
293}
294
295impl<S: Debug + Read + Send + Sync + Write + 'static> From<io::Error> for HandshakeError<S> {
296    fn from(err: io::Error) -> Self {
297        HandshakeError::Failure(err)
298    }
299}