//! TLS client connector built on top of [`rustls`].
//!
//! [`RustlsConnectorConfig`] collects the certificates to trust (and, with the
//! `platform-verifier` feature, can opt into `rustls_platform_verifier`) and turns them
//! into a [`RustlsConnector`]. The connector performs blocking handshakes via
//! [`RustlsConnector::connect`] and, with the `futures` feature, asynchronous handshakes
//! via `RustlsConnector::connect_async`.
//!
//! The crates this one builds on (`rustls`, `rustls_pki_types`, `webpki`, and the
//! feature-gated certificate/verifier crates) are re-exported for convenience.
#![deny(missing_docs)]
2
3pub use rustls;
31#[cfg(feature = "native-certs")]
32pub use rustls_native_certs;
33pub use rustls_pki_types;
34#[cfg(feature = "platform-verifier")]
35pub use rustls_platform_verifier;
36pub use webpki;
37#[cfg(feature = "webpki-root-certs")]
38pub use webpki_root_certs;
39
40#[cfg(feature = "futures")]
41use futures_io::{AsyncRead, AsyncWrite};
42use rustls::{
43 ClientConfig, ClientConnection, ConfigBuilder, RootCertStore, StreamOwned,
44 client::WantsClientCert,
45};
46use rustls_pki_types::{CertificateDer, PrivateKeyDer, ServerName};
47
48use std::{
49 error::Error,
50 fmt,
51 io::{self, Read, Write},
52 sync::Arc,
53};
54
55pub type TlsStream<S> = StreamOwned<ClientConnection, S>;
57
/// An asynchronous TLS client stream over an `AsyncRead + AsyncWrite` transport `S`.
///
/// Returned by `RustlsConnector::connect_async`.
#[cfg(feature = "futures")]
pub type AsyncTlsStream<S> = futures_rustls::client::TlsStream<S>;
61
62#[derive(Clone, Default)]
64pub struct RustlsConnectorConfig {
65 store: Vec<CertificateDer<'static>>,
66 #[cfg(feature = "platform-verifier")]
67 platform_verifier: bool,
68}
69
70impl RustlsConnectorConfig {
71 #[cfg(feature = "webpki-root-certs")]
72 pub fn new_with_webpki_root_certs() -> Self {
74 Self::default().with_webpki_root_certs()
75 }
76
77 #[cfg(feature = "platform-verifier")]
78 pub fn new_with_platform_verifier() -> Self {
80 Self::default().with_platform_verifier()
81 }
82
83 #[cfg(feature = "native-certs")]
84 pub fn new_with_native_certs() -> io::Result<Self> {
90 Self::default().with_native_certs()
91 }
92
93 pub fn add_parsable_certificates(&mut self, mut der_certs: Vec<CertificateDer<'static>>) {
97 self.store.append(&mut der_certs)
98 }
99
100 pub fn with_parsable_certificates(mut self, der_certs: Vec<CertificateDer<'static>>) -> Self {
104 self.add_parsable_certificates(der_certs);
105 self
106 }
107
108 #[cfg(feature = "webpki-root-certs")]
109 pub fn with_webpki_root_certs(mut self) -> Self {
111 self.add_parsable_certificates(webpki_root_certs::TLS_SERVER_ROOT_CERTS.to_vec());
112 self
113 }
114
115 #[cfg(feature = "platform-verifier")]
116 pub fn with_platform_verifier(mut self) -> Self {
118 self.platform_verifier = true;
119 self
120 }
121
122 #[cfg(feature = "native-certs")]
123 pub fn with_native_certs(mut self) -> io::Result<Self> {
129 let certs_result = rustls_native_certs::load_native_certs();
130 for err in certs_result.errors {
131 log::warn!("Got error while loading some native certificates: {err:?}");
132 }
133 if certs_result.certs.is_empty() {
134 return Err(io::Error::other(
135 "Could not load any valid native certificates",
136 ));
137 }
138 self.add_parsable_certificates(certs_result.certs);
139 Ok(self)
140 }
141
142 fn builder(self) -> io::Result<ConfigBuilder<ClientConfig, WantsClientCert>> {
143 let builder = ClientConfig::builder();
144 #[cfg(feature = "platform-verifier")]
145 {
146 if self.platform_verifier {
147 let verifier = rustls_platform_verifier::Verifier::new_with_extra_roots(
148 self.store,
149 builder.crypto_provider().clone(),
150 )
151 .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?;
152 return Ok(builder
153 .dangerous()
154 .with_custom_certificate_verifier(Arc::new(verifier)));
155 }
156 }
157 let mut store = RootCertStore::empty();
158 let (_, ignored) = store.add_parsable_certificates(self.store);
159 if ignored > 0 {
160 log::warn!("{ignored} platform CA root certificates were ignored due to errors");
161 }
162 if store.is_empty() {
163 return Err(io::Error::other("Could not load any valid certificates"));
164 }
165 Ok(builder.with_root_certificates(store))
166 }
167
168 pub fn connector_with_no_client_auth(self) -> io::Result<RustlsConnector> {
174 Ok(self.builder()?.with_no_client_auth().into())
175 }
176
177 pub fn connector_with_single_cert(
186 self,
187 cert_chain: Vec<CertificateDer<'static>>,
188 key_der: PrivateKeyDer<'static>,
189 ) -> io::Result<RustlsConnector> {
190 Ok(self
191 .builder()?
192 .with_client_auth_cert(cert_chain, key_der)
193 .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?
194 .into())
195 }
196}
197
198
199#[derive(Clone)]
201pub struct RustlsConnector(Arc<ClientConfig>);
202
203impl From<ClientConfig> for RustlsConnector {
204 fn from(config: ClientConfig) -> Self {
205 Arc::new(config).into()
206 }
207}
208
209impl From<Arc<ClientConfig>> for RustlsConnector {
210 fn from(config: Arc<ClientConfig>) -> Self {
211 Self(config)
212 }
213}
214
215impl RustlsConnector {
216 #[cfg(feature = "webpki-root-certs")]
217 pub fn new_with_webpki_root_certs() -> io::Result<Self> {
223 RustlsConnectorConfig::new_with_webpki_root_certs().connector_with_no_client_auth()
224 }
225
226 #[cfg(feature = "platform-verifier")]
227 pub fn new_with_platform_verifier() -> io::Result<Self> {
233 RustlsConnectorConfig::new_with_platform_verifier().connector_with_no_client_auth()
234 }
235
236 #[cfg(feature = "native-certs")]
237 pub fn new_with_native_certs() -> io::Result<Self> {
243 RustlsConnectorConfig::new_with_native_certs()?.connector_with_no_client_auth()
244 }
245
246 #[allow(clippy::result_large_err)]
253 pub fn connect<S: Read + Write + Send + 'static>(
254 &self,
255 domain: &str,
256 stream: S,
257 ) -> Result<TlsStream<S>, HandshakeError<S>> {
258 let session = ClientConnection::new(
259 self.0.clone(),
260 server_name(domain).map_err(HandshakeError::Failure)?,
261 )
262 .map_err(|err| io::Error::new(io::ErrorKind::ConnectionAborted, err))?;
263 MidHandshakeTlsStream { session, stream }.handshake()
264 }
265
266 #[cfg(feature = "futures")]
267 pub async fn connect_async<S: AsyncRead + AsyncWrite + Send + Unpin + 'static>(
273 &self,
274 domain: &str,
275 stream: S,
276 ) -> io::Result<AsyncTlsStream<S>> {
277 futures_rustls::TlsConnector::from(self.0.clone())
278 .connect(server_name(domain)?, stream)
279 .await
280 }
281}
282
283fn server_name(domain: &str) -> io::Result<ServerName<'static>> {
284 Ok(ServerName::try_from(domain)
285 .map_err(|err| {
286 io::Error::new(
287 io::ErrorKind::InvalidData,
288 format!("Invalid domain name ({err:?}): {domain}"),
289 )
290 })?
291 .to_owned())
292}
293
294#[derive(Debug)]
296pub struct MidHandshakeTlsStream<S: Read + Write> {
297 session: ClientConnection,
298 stream: S,
299}
300
301impl<S: Read + Send + Write + 'static> MidHandshakeTlsStream<S> {
302 pub fn get_ref(&self) -> &S {
304 &self.stream
305 }
306
307 pub fn get_mut(&mut self) -> &mut S {
309 &mut self.stream
310 }
311
312 #[allow(clippy::result_large_err)]
319 pub fn handshake(mut self) -> Result<TlsStream<S>, HandshakeError<S>> {
320 if let Err(e) = self.session.complete_io(&mut self.stream) {
321 if e.kind() == io::ErrorKind::WouldBlock {
322 if self.session.is_handshaking() {
323 return Err(HandshakeError::WouldBlock(self));
324 }
325 } else {
326 return Err(e.into());
327 }
328 }
329 Ok(TlsStream::new(self.session, self.stream))
330 }
331}
332
333impl<S: Read + Write> fmt::Display for MidHandshakeTlsStream<S> {
334 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
335 f.write_str("MidHandshakeTlsStream")
336 }
337}
338
339#[allow(clippy::large_enum_variant)]
341pub enum HandshakeError<S: Read + Write + Send + 'static> {
342 WouldBlock(MidHandshakeTlsStream<S>),
345 Failure(io::Error),
347}
348
349impl<S: Read + Write + Send + 'static> fmt::Display for HandshakeError<S> {
350 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
351 match self {
352 HandshakeError::WouldBlock(_) => f.write_str("WouldBlock hit during handshake"),
353 HandshakeError::Failure(err) => f.write_fmt(format_args!("IO error: {err}")),
354 }
355 }
356}
357
358impl<S: Read + Write + Send + 'static> fmt::Debug for HandshakeError<S> {
359 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
360 let mut d = f.debug_tuple("HandshakeError");
361 match self {
362 HandshakeError::WouldBlock(_) => d.field(&"WouldBlock"),
363 HandshakeError::Failure(err) => d.field(&err),
364 }
365 .finish()
366 }
367}
368
369impl<S: Read + Write + Send + 'static> Error for HandshakeError<S> {
370 fn source(&self) -> Option<&(dyn Error + 'static)> {
371 match self {
372 HandshakeError::Failure(err) => Some(err),
373 _ => None,
374 }
375 }
376}
377
378impl<S: Read + Send + Write + 'static> From<io::Error> for HandshakeError<S> {
379 fn from(err: io::Error) -> Self {
380 HandshakeError::Failure(err)
381 }
382}
383
#[cfg(test)]
mod tests {
    use super::*;

    /// A configuration without any certificate must not produce a connector.
    #[test]
    fn empty_config_fails() {
        let outcome = RustlsConnectorConfig::default().connector_with_no_client_auth();
        assert!(outcome.is_err());
    }

    /// The bundled webpki roots are enough to build a connector.
    #[test]
    #[cfg(feature = "webpki-root-certs")]
    fn webpki_root_certs_connector_builds() {
        let _connector = RustlsConnector::new_with_webpki_root_certs().unwrap();
    }

    /// The platform verifier can be instantiated on the test host.
    #[test]
    #[cfg(feature = "platform-verifier")]
    fn platform_verifier_connector_builds() {
        let _connector = RustlsConnector::new_with_platform_verifier().unwrap();
    }

    /// `Failure` surfaces its inner error through Display, Debug and `source()`.
    #[test]
    fn handshake_error_failure_display() {
        let err = HandshakeError::<std::net::TcpStream>::Failure(io::Error::other("test error"));
        let shown = err.to_string();
        let debugged = format!("{err:?}");
        assert!(shown.contains("test error"));
        assert!(debugged.contains("test error"));
        assert!(err.source().is_some());
    }
}