1#![deny(missing_docs, missing_debug_implementations, unsafe_code)]
2#![warn(unreachable_pub, unused_qualifications, unused_lifetimes)]
3#![warn(
4 clippy::must_use_candidate,
5 clippy::unwrap_in_result,
6 clippy::panic_in_result_fn
7)]
8
9pub use rustls;
37#[cfg(feature = "native-certs")]
38pub use rustls_native_certs;
39pub use rustls_pki_types;
40#[cfg(feature = "platform-verifier")]
41pub use rustls_platform_verifier;
42pub use webpki;
43#[cfg(feature = "webpki-root-certs")]
44pub use webpki_root_certs;
45
46#[cfg(feature = "futures")]
47use futures_io::{AsyncRead, AsyncWrite};
48use rustls::{
49 ClientConfig, ClientConnection, ConfigBuilder, RootCertStore, StreamOwned,
50 client::WantsClientCert,
51};
52use rustls_pki_types::{CertificateDer, PrivateKeyDer, ServerName};
53
54use std::{
55 error::Error,
56 fmt,
57 io::{self, Read, Write},
58 sync::Arc,
59};
60
61pub type TlsStream<S> = StreamOwned<ClientConnection, S>;
63
/// An asynchronous client-side TLS stream layered over an underlying stream `S`.
///
/// This is an alias for [`futures_rustls::client::TlsStream`] and is only available
/// when the `futures` feature is enabled.
#[cfg(feature = "futures")]
pub type AsyncTlsStream<S> = futures_rustls::client::TlsStream<S>;
67
68#[derive(Clone, Default, Debug)]
70pub struct RustlsConnectorConfig {
71 store: Vec<CertificateDer<'static>>,
72 #[cfg(feature = "platform-verifier")]
73 platform_verifier: bool,
74}
75
76impl RustlsConnectorConfig {
77 #[cfg(feature = "webpki-root-certs")]
78 pub fn new_with_webpki_root_certs() -> Self {
80 Self::default().with_webpki_root_certs()
81 }
82
83 #[cfg(feature = "platform-verifier")]
84 pub fn new_with_platform_verifier() -> Self {
86 Self::default().with_platform_verifier()
87 }
88
89 #[cfg(feature = "native-certs")]
90 pub fn new_with_native_certs() -> io::Result<Self> {
96 Self::default().with_native_certs()
97 }
98
99 pub fn add_parsable_certificates(&mut self, mut der_certs: Vec<CertificateDer<'static>>) {
103 self.store.append(&mut der_certs)
104 }
105
106 pub fn with_parsable_certificates(mut self, der_certs: Vec<CertificateDer<'static>>) -> Self {
110 self.add_parsable_certificates(der_certs);
111 self
112 }
113
114 #[cfg(feature = "webpki-root-certs")]
115 pub fn with_webpki_root_certs(mut self) -> Self {
117 self.add_parsable_certificates(webpki_root_certs::TLS_SERVER_ROOT_CERTS.to_vec());
118 self
119 }
120
121 #[cfg(feature = "platform-verifier")]
122 pub fn with_platform_verifier(mut self) -> Self {
124 self.platform_verifier = true;
125 self
126 }
127
128 #[cfg(feature = "native-certs")]
129 pub fn with_native_certs(mut self) -> io::Result<Self> {
135 let certs_result = rustls_native_certs::load_native_certs();
136 for err in certs_result.errors {
137 log::warn!("Got error while loading some native certificates: {err:?}");
138 }
139 if certs_result.certs.is_empty() {
140 return Err(io::Error::other(
141 "Could not load any valid native certificates",
142 ));
143 }
144 self.add_parsable_certificates(certs_result.certs);
145 Ok(self)
146 }
147
148 fn builder(self) -> io::Result<ConfigBuilder<ClientConfig, WantsClientCert>> {
149 let builder = ClientConfig::builder();
150 #[cfg(feature = "platform-verifier")]
151 {
152 if self.platform_verifier {
153 let verifier = rustls_platform_verifier::Verifier::new_with_extra_roots(
154 self.store,
155 builder.crypto_provider().clone(),
156 )
157 .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?;
158 return Ok(builder
159 .dangerous()
160 .with_custom_certificate_verifier(Arc::new(verifier)));
161 }
162 }
163 let mut store = RootCertStore::empty();
164 let (_, ignored) = store.add_parsable_certificates(self.store);
165 if ignored > 0 {
166 log::warn!("{ignored} platform CA root certificates were ignored due to errors");
167 }
168 if store.is_empty() {
169 return Err(io::Error::other("Could not load any valid certificates"));
170 }
171 Ok(builder.with_root_certificates(store))
172 }
173
174 pub fn connector_with_no_client_auth(self) -> io::Result<RustlsConnector> {
180 Ok(self.builder()?.with_no_client_auth().into())
181 }
182
183 pub fn connector_with_single_cert(
192 self,
193 cert_chain: Vec<CertificateDer<'static>>,
194 key_der: PrivateKeyDer<'static>,
195 ) -> io::Result<RustlsConnector> {
196 Ok(self
197 .builder()?
198 .with_client_auth_cert(cert_chain, key_der)
199 .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?
200 .into())
201 }
202}
203
204#[derive(Clone, Debug)]
206pub struct RustlsConnector(Arc<ClientConfig>);
207
208impl From<ClientConfig> for RustlsConnector {
209 fn from(config: ClientConfig) -> Self {
210 Arc::new(config).into()
211 }
212}
213
214impl From<Arc<ClientConfig>> for RustlsConnector {
215 fn from(config: Arc<ClientConfig>) -> Self {
216 Self(config)
217 }
218}
219
220impl RustlsConnector {
221 #[cfg(feature = "webpki-root-certs")]
222 pub fn new_with_webpki_root_certs() -> io::Result<Self> {
228 RustlsConnectorConfig::new_with_webpki_root_certs().connector_with_no_client_auth()
229 }
230
231 #[cfg(feature = "platform-verifier")]
232 pub fn new_with_platform_verifier() -> io::Result<Self> {
238 RustlsConnectorConfig::new_with_platform_verifier().connector_with_no_client_auth()
239 }
240
241 #[cfg(feature = "native-certs")]
242 pub fn new_with_native_certs() -> io::Result<Self> {
248 RustlsConnectorConfig::new_with_native_certs()?.connector_with_no_client_auth()
249 }
250
251 #[allow(clippy::result_large_err)]
258 pub fn connect<S: Read + Write + Send + 'static>(
259 &self,
260 domain: &str,
261 stream: S,
262 ) -> Result<TlsStream<S>, HandshakeError<S>> {
263 let session = ClientConnection::new(
264 self.0.clone(),
265 server_name(domain).map_err(HandshakeError::Failure)?,
266 )
267 .map_err(|err| io::Error::new(io::ErrorKind::ConnectionAborted, err))?;
268 MidHandshakeTlsStream { session, stream }.handshake()
269 }
270
271 #[cfg(feature = "futures")]
272 pub async fn connect_async<S: AsyncRead + AsyncWrite + Send + Unpin + 'static>(
278 &self,
279 domain: &str,
280 stream: S,
281 ) -> io::Result<AsyncTlsStream<S>> {
282 futures_rustls::TlsConnector::from(self.0.clone())
283 .connect(server_name(domain)?, stream)
284 .await
285 }
286}
287
288fn server_name(domain: &str) -> io::Result<ServerName<'static>> {
289 Ok(ServerName::try_from(domain)
290 .map_err(|err| {
291 io::Error::new(
292 io::ErrorKind::InvalidData,
293 format!("Invalid domain name ({err:?}): {domain}"),
294 )
295 })?
296 .to_owned())
297}
298
299#[derive(Debug)]
301pub struct MidHandshakeTlsStream<S: Read + Write> {
302 session: ClientConnection,
303 stream: S,
304}
305
306impl<S: Read + Send + Write + 'static> MidHandshakeTlsStream<S> {
307 pub fn get_ref(&self) -> &S {
309 &self.stream
310 }
311
312 pub fn get_mut(&mut self) -> &mut S {
314 &mut self.stream
315 }
316
317 #[allow(clippy::result_large_err)]
324 pub fn handshake(mut self) -> Result<TlsStream<S>, HandshakeError<S>> {
325 if let Err(e) = self.session.complete_io(&mut self.stream) {
326 if e.kind() == io::ErrorKind::WouldBlock {
327 if self.session.is_handshaking() {
328 return Err(HandshakeError::WouldBlock(self));
329 }
330 } else {
331 return Err(e.into());
332 }
333 }
334 Ok(TlsStream::new(self.session, self.stream))
335 }
336}
337
338impl<S: Read + Write> fmt::Display for MidHandshakeTlsStream<S> {
339 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
340 f.write_str("MidHandshakeTlsStream")
341 }
342}
343
344#[allow(clippy::large_enum_variant)]
346pub enum HandshakeError<S: Read + Write + Send + 'static> {
347 WouldBlock(MidHandshakeTlsStream<S>),
350 Failure(io::Error),
352}
353
354impl<S: Read + Write + Send + 'static> fmt::Display for HandshakeError<S> {
355 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
356 match self {
357 HandshakeError::WouldBlock(_) => f.write_str("WouldBlock hit during handshake"),
358 HandshakeError::Failure(err) => f.write_fmt(format_args!("IO error: {err}")),
359 }
360 }
361}
362
363impl<S: Read + Write + Send + 'static> fmt::Debug for HandshakeError<S> {
364 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
365 let mut d = f.debug_tuple("HandshakeError");
366 match self {
367 HandshakeError::WouldBlock(_) => d.field(&"WouldBlock"),
368 HandshakeError::Failure(err) => d.field(&err),
369 }
370 .finish()
371 }
372}
373
374impl<S: Read + Write + Send + 'static> Error for HandshakeError<S> {
375 fn source(&self) -> Option<&(dyn Error + 'static)> {
376 match self {
377 HandshakeError::Failure(err) => Some(err),
378 _ => None,
379 }
380 }
381}
382
383impl<S: Read + Send + Write + 'static> From<io::Error> for HandshakeError<S> {
384 fn from(err: io::Error) -> Self {
385 HandshakeError::Failure(err)
386 }
387}
388
#[cfg(test)]
mod tests {
    use super::*;

    /// A configuration holding no certificate must not produce a connector.
    #[test]
    fn empty_config_fails() {
        let outcome = RustlsConnectorConfig::default().connector_with_no_client_auth();
        assert!(outcome.is_err());
    }

    /// The webpki root certificates are enough to build a connector.
    #[test]
    #[cfg(feature = "webpki-root-certs")]
    fn webpki_root_certs_connector_builds() {
        let _connector = RustlsConnector::new_with_webpki_root_certs().unwrap();
    }

    /// The platform verifier alone is enough to build a connector.
    #[test]
    #[cfg(feature = "platform-verifier")]
    fn platform_verifier_connector_builds() {
        let _connector = RustlsConnector::new_with_platform_verifier().unwrap();
    }

    /// `Failure` surfaces the wrapped error through `Display`, `Debug` and `source`.
    #[test]
    fn handshake_error_failure_display() {
        let err = HandshakeError::<std::net::TcpStream>::Failure(io::Error::other("test error"));
        let displayed = err.to_string();
        let debugged = format!("{err:?}");
        assert!(displayed.contains("test error"));
        assert!(debugged.contains("test error"));
        assert!(err.source().is_some());
    }
}
422}