rustls_connector/
lib.rs

1#![deny(missing_docs)]
2
3//! # Connector similar to openssl or native-tls for rustls
4//!
5//! rustls-connector is a library aiming at simplifying using rustls as
6//! an alternative to openssl and native-tls
7//!
8//! # Examples
9//!
10//! To connect to a remote server:
11//!
12//! ```rust, no_run
13//! use rustls_connector::RustlsConnector;
14//!
15//! use std::{
16//!     io::{Read, Write},
17//!     net::TcpStream,
18//! };
19//!
20//! let connector = RustlsConnector::new_with_native_certs().unwrap();
21//! let stream = TcpStream::connect("google.com:443").unwrap();
22//! let mut stream = connector.connect("google.com", stream).unwrap();
23//!
24//! stream.write_all(b"GET / HTTP/1.0\r\n\r\n").unwrap();
25//! let mut res = vec![];
26//! stream.read_to_end(&mut res).unwrap();
27//! println!("{}", String::from_utf8_lossy(&res));
28//! ```
29
30pub use rustls;
31#[cfg(feature = "native-certs")]
32pub use rustls_native_certs;
33pub use rustls_pki_types;
34pub use webpki;
35#[cfg(feature = "webpki-roots-certs")]
36pub use webpki_roots;
37
38use rustls::{ClientConfig, ClientConnection, RootCertStore, StreamOwned};
39use rustls_pki_types::{CertificateDer, PrivateKeyDer, ServerName};
40
41use std::{
42    error::Error,
43    fmt::{self, Debug},
44    io::{self, Read, Write},
45    sync::Arc,
46};
47
/// A TLS stream
///
/// This is an alias for rustls' [`StreamOwned`] specialised to a [`ClientConnection`]
/// and a caller-chosen underlying transport `S`.
pub type TlsStream<S> = StreamOwned<ClientConnection, S>;
50
/// Configuration helper for [`RustlsConnector`]
///
/// Wraps the [`RootCertStore`] holding the trusted root certificates that the resulting
/// connector will be built with.
#[derive(Clone)]
pub struct RustlsConnectorConfig(RootCertStore);
54
55impl RustlsConnectorConfig {
56    #[cfg(feature = "webpki-roots-certs")]
57    /// Create a new [`RustlsConnectorConfig`] using the webpki-roots certs (requires webpki-roots-certs feature enabled)
58    pub fn new_with_webpki_roots_certs() -> Self {
59        Self(RootCertStore {
60            roots: webpki_roots::TLS_SERVER_ROOTS.iter().cloned().collect(),
61        })
62    }
63
64    #[cfg(feature = "native-certs")]
65    /// Create a new [`RustlsConnectorConfig`] using the system certs (requires native-certs feature enabled)
66    ///
67    /// # Errors
68    ///
69    /// Returns an error if we fail to load the native certs.
70    pub fn new_with_native_certs() -> io::Result<Self> {
71        let mut root_store = RootCertStore::empty();
72        for cert in rustls_native_certs::load_native_certs().expect("could not load platform certs")
73        {
74            if let Err(err) = root_store.add(cert) {
75                log::warn!(
76                    "Got error while importing some native certificates: {:?}",
77                    err
78                );
79            }
80        }
81        Ok(Self(root_store))
82    }
83
84    /// Parse the given DER-encoded certificates and add all that can be parsed in a best-effort fashion.
85    ///
86    /// This is because large collections of root certificates often include ancient or syntactically invalid certificates.
87    ///
88    /// Returns the number of certificates added, and the number that were ignored.
89    pub fn add_parsable_certificates(
90        &mut self,
91        der_certs: Vec<CertificateDer<'_>>,
92    ) -> (usize, usize) {
93        self.0.add_parsable_certificates(der_certs)
94    }
95
96    /// Create a new [`RustlsConnector`] from this config and no client certificate
97    pub fn connector_with_no_client_auth(self) -> RustlsConnector {
98        ClientConfig::builder()
99            .with_root_certificates(self.0)
100            .with_no_client_auth()
101            .into()
102    }
103
104    /// Create a new [`RustlsConnector`] from this config and the given client certificate
105    ///
106    /// cert_chain is a vector of DER-encoded certificates. key_der is a DER-encoded RSA, ECDSA, or
107    /// Ed25519 private key.
108    ///
109    /// This function fails if key_der is invalid.
110    pub fn connector_with_single_cert(
111        self,
112        cert_chain: Vec<CertificateDer<'static>>,
113        key_der: PrivateKeyDer<'static>,
114    ) -> io::Result<RustlsConnector> {
115        Ok(ClientConfig::builder()
116            .with_root_certificates(self.0)
117            .with_client_auth_cert(cert_chain, key_der)
118            .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?
119            .into())
120    }
121}
122
123impl Default for RustlsConnectorConfig {
124    fn default() -> Self {
125        Self(RootCertStore::empty())
126    }
127}
128
/// The connector
///
/// Holds the rustls [`ClientConfig`] behind an [`Arc`]; cloning a connector clones the
/// [`Arc`], so clones refer to the same configuration.
#[derive(Clone)]
pub struct RustlsConnector(Arc<ClientConfig>);
132
133impl Default for RustlsConnector {
134    fn default() -> Self {
135        RustlsConnectorConfig::default().connector_with_no_client_auth()
136    }
137}
138
139impl From<ClientConfig> for RustlsConnector {
140    fn from(config: ClientConfig) -> Self {
141        Arc::new(config).into()
142    }
143}
144
145impl From<Arc<ClientConfig>> for RustlsConnector {
146    fn from(config: Arc<ClientConfig>) -> Self {
147        Self(config)
148    }
149}
150
151impl RustlsConnector {
152    #[cfg(feature = "webpki-roots-certs")]
153    /// Create a new RustlsConnector using the webpki-roots certs (requires webpki-roots-certs feature enabled)
154    pub fn new_with_webpki_roots_certs() -> Self {
155        RustlsConnectorConfig::new_with_webpki_roots_certs().connector_with_no_client_auth()
156    }
157
158    #[cfg(feature = "native-certs")]
159    /// Create a new [`RustlsConnector`] using the system certs (requires native-certs feature enabled)
160    ///
161    /// # Errors
162    ///
163    /// Returns an error if we fail to load the native certs.
164    pub fn new_with_native_certs() -> io::Result<Self> {
165        Ok(RustlsConnectorConfig::new_with_native_certs()?.connector_with_no_client_auth())
166    }
167
168    /// Connect to the given host
169    ///
170    /// # Errors
171    ///
172    /// Returns a [`HandshakeError`] containing either the current state of the handshake or the
173    /// failure when we couldn't complete the hanshake
174    pub fn connect<S: Debug + Read + Send + Sync + Write + 'static>(
175        &self,
176        domain: &str,
177        stream: S,
178    ) -> Result<TlsStream<S>, HandshakeError<S>> {
179        let session = ClientConnection::new(
180            self.0.clone(),
181            ServerName::try_from(domain)
182                .map_err(|err| {
183                    HandshakeError::Failure(io::Error::new(
184                        io::ErrorKind::InvalidData,
185                        format!("Invalid domain name ({:?}): {}", err, domain),
186                    ))
187                })?
188                .to_owned(),
189        )
190        .map_err(|err| io::Error::new(io::ErrorKind::ConnectionAborted, err))?;
191        MidHandshakeTlsStream { session, stream }.handshake()
192    }
193}
194
/// A TLS stream which has been interrupted during the handshake
///
/// It keeps both halves needed to resume the handshake later on.
#[derive(Debug)]
pub struct MidHandshakeTlsStream<S: Read + Write> {
    // rustls client connection whose handshake is being driven
    session: ClientConnection,
    // underlying transport the handshake I/O is performed on
    stream: S,
}
201
202impl<S: Debug + Read + Send + Sync + Write + 'static> MidHandshakeTlsStream<S> {
203    /// Get a reference to the inner stream
204    pub fn get_ref(&self) -> &S {
205        &self.stream
206    }
207
208    /// Get a mutable reference to the inner stream
209    pub fn get_mut(&mut self) -> &mut S {
210        &mut self.stream
211    }
212
213    /// Retry the handshake
214    ///
215    /// # Errors
216    ///
217    /// Returns a [`HandshakeError`] containing either the current state of the handshake or the
218    /// failure when we couldn't complete the hanshake
219    pub fn handshake(mut self) -> Result<TlsStream<S>, HandshakeError<S>> {
220        if let Err(e) = self.session.complete_io(&mut self.stream) {
221            if e.kind() == io::ErrorKind::WouldBlock {
222                if self.session.is_handshaking() {
223                    return Err(HandshakeError::WouldBlock(Box::new(self)));
224                }
225            } else {
226                return Err(e.into());
227            }
228        }
229        Ok(TlsStream::new(self.session, self.stream))
230    }
231}
232
233impl<S: Read + Write> fmt::Display for MidHandshakeTlsStream<S> {
234    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
235        f.write_str("MidHandshakeTlsStream")
236    }
237}
238
/// An error returned while performing the handshake
///
/// It either carries the interrupted handshake, so that it can be resumed, or the
/// I/O error which made it fail.
#[derive(Debug)]
pub enum HandshakeError<S: Read + Send + Sync + Write + 'static> {
    /// We hit WouldBlock during handshake.
    /// Note that this is not a critical failure, you should be able to call handshake again once the stream is ready to perform I/O.
    /// The boxed [`MidHandshakeTlsStream`] is the handshake to resume.
    WouldBlock(Box<MidHandshakeTlsStream<S>>),
    /// We hit a critical failure.
    /// The wrapped [`io::Error`] describes what went wrong.
    Failure(io::Error),
}
248
249impl<S: Debug + Read + Send + Sync + Write + 'static> fmt::Display for HandshakeError<S> {
250    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
251        match self {
252            HandshakeError::WouldBlock(_) => f.write_str("WouldBlock hit during handshake"),
253            HandshakeError::Failure(err) => f.write_fmt(format_args!("IO error: {}", err)),
254        }
255    }
256}
257
258impl<S: Debug + Read + Send + Sync + Write + 'static> Error for HandshakeError<S> {
259    fn source(&self) -> Option<&(dyn Error + 'static)> {
260        match self {
261            HandshakeError::Failure(err) => Some(err),
262            _ => None,
263        }
264    }
265}
266
267impl<S: Debug + Read + Send + Sync + Write + 'static> From<io::Error> for HandshakeError<S> {
268    fn from(err: io::Error) -> Self {
269        HandshakeError::Failure(err)
270    }
271}