1#![deny(missing_docs, missing_debug_implementations, unsafe_code)]
2#![warn(unreachable_pub, unused_qualifications, unused_lifetimes)]
3#![warn(
4 clippy::must_use_candidate,
5 clippy::unwrap_in_result,
6 clippy::panic_in_result_fn
7)]
8
9pub use rustls;
61#[cfg(feature = "native-certs")]
62pub use rustls_native_certs;
64pub use rustls_pki_types;
66#[cfg(feature = "platform-verifier")]
67pub use rustls_platform_verifier;
69pub use webpki;
71#[cfg(feature = "webpki-root-certs")]
72pub use webpki_root_certs;
74
75#[cfg(feature = "futures")]
76use futures_io::{AsyncRead, AsyncWrite};
77use rustls::{
78 ClientConfig, ClientConnection, ConfigBuilder, RootCertStore, StreamOwned,
79 client::WantsClientCert,
80};
81use rustls_pki_types::{CertificateDer, PrivateKeyDer, ServerName};
82
83use std::{
84 error::Error,
85 fmt,
86 io::{self, Read, Write},
87 sync::Arc,
88};
89
/// A blocking TLS client stream: a rustls [`ClientConnection`] that owns the
/// underlying transport `S`.
pub type TlsStream<S> = StreamOwned<ClientConnection, S>;

/// An asynchronous TLS client stream, as returned by
/// [`RustlsConnector::connect_async`].
#[cfg(feature = "futures")]
pub type AsyncTlsStream<S> = futures_rustls::client::TlsStream<S>;
96
97#[derive(Clone, Default, Debug)]
99pub struct RustlsConnectorConfig {
100 store: Vec<CertificateDer<'static>>,
101 #[cfg(feature = "platform-verifier")]
102 platform_verifier: bool,
103}
104
105impl RustlsConnectorConfig {
106 #[cfg(feature = "webpki-root-certs")]
107 pub fn new_with_webpki_root_certs() -> Self {
109 Self::default().with_webpki_root_certs()
110 }
111
112 #[cfg(feature = "platform-verifier")]
113 pub fn new_with_platform_verifier() -> Self {
115 Self::default().with_platform_verifier()
116 }
117
118 #[cfg(feature = "native-certs")]
119 pub fn new_with_native_certs() -> io::Result<Self> {
125 Self::default().with_native_certs()
126 }
127
128 pub fn add_parsable_certificates(&mut self, mut der_certs: Vec<CertificateDer<'static>>) {
132 self.store.append(&mut der_certs)
133 }
134
135 pub fn with_parsable_certificates(mut self, der_certs: Vec<CertificateDer<'static>>) -> Self {
139 self.add_parsable_certificates(der_certs);
140 self
141 }
142
143 #[cfg(feature = "webpki-root-certs")]
144 pub fn with_webpki_root_certs(mut self) -> Self {
146 self.add_parsable_certificates(webpki_root_certs::TLS_SERVER_ROOT_CERTS.to_vec());
147 self
148 }
149
150 #[cfg(feature = "platform-verifier")]
151 pub fn with_platform_verifier(mut self) -> Self {
153 self.platform_verifier = true;
154 self
155 }
156
157 #[cfg(feature = "native-certs")]
158 pub fn with_native_certs(mut self) -> io::Result<Self> {
164 let certs_result = rustls_native_certs::load_native_certs();
165 for err in certs_result.errors {
166 log::warn!("Got error while loading some native certificates: {err:?}");
167 }
168 if certs_result.certs.is_empty() {
169 return Err(io::Error::other(
170 "Could not load any valid native certificates",
171 ));
172 }
173 self.add_parsable_certificates(certs_result.certs);
174 Ok(self)
175 }
176
177 fn builder(self) -> io::Result<ConfigBuilder<ClientConfig, WantsClientCert>> {
178 let builder = ClientConfig::builder();
179 #[cfg(feature = "platform-verifier")]
180 {
181 if self.platform_verifier {
182 let verifier = rustls_platform_verifier::Verifier::new_with_extra_roots(
183 self.store,
184 builder.crypto_provider().clone(),
185 )
186 .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?;
187 return Ok(builder
190 .dangerous()
191 .with_custom_certificate_verifier(Arc::new(verifier)));
192 }
193 }
194 let mut store = RootCertStore::empty();
195 let (_, ignored) = store.add_parsable_certificates(self.store);
196 if ignored > 0 {
197 log::warn!("{ignored} platform CA root certificates were ignored due to errors");
198 }
199 if store.is_empty() {
200 return Err(io::Error::other("Could not load any valid certificates"));
201 }
202 Ok(builder.with_root_certificates(store))
203 }
204
205 pub fn connector_with_no_client_auth(self) -> io::Result<RustlsConnector> {
211 Ok(self.builder()?.with_no_client_auth().into())
212 }
213
214 pub fn connector_with_single_cert(
223 self,
224 cert_chain: Vec<CertificateDer<'static>>,
225 key_der: PrivateKeyDer<'static>,
226 ) -> io::Result<RustlsConnector> {
227 Ok(self
228 .builder()?
229 .with_client_auth_cert(cert_chain, key_der)
230 .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?
231 .into())
232 }
233}
234
235#[derive(Clone, Debug)]
242pub struct RustlsConnector(Arc<ClientConfig>);
243
244impl From<ClientConfig> for RustlsConnector {
245 fn from(config: ClientConfig) -> Self {
246 Arc::new(config).into()
247 }
248}
249
250impl From<Arc<ClientConfig>> for RustlsConnector {
251 fn from(config: Arc<ClientConfig>) -> Self {
252 Self(config)
253 }
254}
255
256impl RustlsConnector {
257 #[cfg(feature = "webpki-root-certs")]
258 pub fn new_with_webpki_root_certs() -> io::Result<Self> {
264 RustlsConnectorConfig::new_with_webpki_root_certs().connector_with_no_client_auth()
265 }
266
267 #[cfg(feature = "platform-verifier")]
268 pub fn new_with_platform_verifier() -> io::Result<Self> {
274 RustlsConnectorConfig::new_with_platform_verifier().connector_with_no_client_auth()
275 }
276
277 #[cfg(feature = "native-certs")]
278 pub fn new_with_native_certs() -> io::Result<Self> {
284 RustlsConnectorConfig::new_with_native_certs()?.connector_with_no_client_auth()
285 }
286
287 #[allow(clippy::result_large_err)]
294 pub fn connect<S: Read + Write + Send + 'static>(
295 &self,
296 domain: &str,
297 stream: S,
298 ) -> Result<TlsStream<S>, HandshakeError<S>> {
299 let session = ClientConnection::new(
300 self.0.clone(),
301 server_name(domain).map_err(HandshakeError::Failure)?,
302 )
303 .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?;
304 MidHandshakeTlsStream { session, stream }.handshake()
305 }
306
307 #[cfg(feature = "futures")]
308 pub async fn connect_async<S: AsyncRead + AsyncWrite + Send + Unpin + 'static>(
314 &self,
315 domain: &str,
316 stream: S,
317 ) -> io::Result<AsyncTlsStream<S>> {
318 futures_rustls::TlsConnector::from(self.0.clone())
319 .connect(server_name(domain)?, stream)
320 .await
321 }
322}
323
324fn server_name(domain: &str) -> io::Result<ServerName<'static>> {
325 Ok(ServerName::try_from(domain)
326 .map_err(|err| {
327 io::Error::new(
328 io::ErrorKind::InvalidData,
329 format!("Invalid domain name: {err:?}"),
330 )
331 })?
332 .to_owned())
333}
334
335#[derive(Debug)]
337pub struct MidHandshakeTlsStream<S: Read + Write> {
338 session: ClientConnection,
339 stream: S,
340}
341
342impl<S: Read + Write + Send + 'static> MidHandshakeTlsStream<S> {
343 pub fn get_ref(&self) -> &S {
345 &self.stream
346 }
347
348 pub fn get_mut(&mut self) -> &mut S {
350 &mut self.stream
351 }
352
353 #[allow(clippy::result_large_err)]
360 pub fn handshake(mut self) -> Result<TlsStream<S>, HandshakeError<S>> {
361 if let Err(e) = self.session.complete_io(&mut self.stream) {
362 if e.kind() == io::ErrorKind::WouldBlock {
363 if self.session.is_handshaking() {
364 return Err(HandshakeError::WouldBlock(self));
365 }
366 } else {
367 return Err(e.into());
368 }
369 }
370 Ok(TlsStream::new(self.session, self.stream))
371 }
372}
373
374impl<S: Read + Write> fmt::Display for MidHandshakeTlsStream<S> {
375 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
376 f.write_str("MidHandshakeTlsStream")
377 }
378}
379
380#[allow(clippy::large_enum_variant)]
382pub enum HandshakeError<S: Read + Write + Send + 'static> {
383 WouldBlock(MidHandshakeTlsStream<S>),
386 Failure(io::Error),
388}
389
390impl<S: Read + Write + Send + 'static> fmt::Display for HandshakeError<S> {
391 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
392 match self {
393 HandshakeError::WouldBlock(_) => f.write_str("WouldBlock hit during handshake"),
394 HandshakeError::Failure(err) => f.write_fmt(format_args!("IO error: {err}")),
395 }
396 }
397}
398
399impl<S: Read + Write + Send + 'static> fmt::Debug for HandshakeError<S> {
400 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
401 let mut d = f.debug_tuple("HandshakeError");
402 match self {
403 HandshakeError::WouldBlock(_) => d.field(&"WouldBlock"),
404 HandshakeError::Failure(err) => d.field(&err),
405 }
406 .finish()
407 }
408}
409
410impl<S: Read + Write + Send + 'static> Error for HandshakeError<S> {
411 fn source(&self) -> Option<&(dyn Error + 'static)> {
412 match self {
413 HandshakeError::Failure(err) => Some(err),
414 _ => None,
415 }
416 }
417}
418
419impl<S: Read + Send + Write + 'static> From<io::Error> for HandshakeError<S> {
420 fn from(err: io::Error) -> Self {
421 HandshakeError::Failure(err)
422 }
423}
424
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn empty_config_fails() {
        assert!(
            RustlsConnectorConfig::default()
                .connector_with_no_client_auth()
                .is_err()
        );
    }

    #[test]
    fn unparsable_certificates_fail() {
        // Garbage DER is accepted into the config but skipped when the root
        // store is built, leaving it empty.
        let config = RustlsConnectorConfig::default()
            .with_parsable_certificates(vec![CertificateDer::from(vec![0u8; 4])]);
        assert_eq!(config.store.len(), 1);
        assert!(config.connector_with_no_client_auth().is_err());
    }

    #[test]
    #[cfg(feature = "webpki-root-certs")]
    fn webpki_root_certs_connector_builds() {
        RustlsConnector::new_with_webpki_root_certs().unwrap();
    }

    #[test]
    #[cfg(feature = "platform-verifier")]
    fn platform_verifier_connector_builds() {
        RustlsConnector::new_with_platform_verifier().unwrap();
    }

    #[test]
    fn server_name_accepts_dns_and_ip() {
        assert!(server_name("example.com").is_ok());
        assert!(server_name("127.0.0.1").is_ok());
    }

    #[test]
    fn server_name_rejects_empty() {
        let err = server_name("").unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
        assert!(err.to_string().contains("Invalid domain name"));
    }

    #[test]
    fn handshake_error_failure_display() {
        let err: HandshakeError<std::net::TcpStream> =
            HandshakeError::Failure(io::Error::other("test error"));
        assert!(err.to_string().contains("test error"));
        assert!(format!("{err:?}").contains("test error"));
        assert!(err.source().is_some());
    }
}