1#![deny(missing_docs)]
2#![allow(clippy::result_large_err, clippy::large_enum_variant)]
3
4pub use rustls;
32#[cfg(feature = "native-certs")]
33pub use rustls_native_certs;
34pub use rustls_pki_types;
35#[cfg(feature = "platform-verifier")]
36pub use rustls_platform_verifier;
37pub use webpki;
38#[cfg(feature = "webpki-root-certs")]
39pub use webpki_root_certs;
40
41#[cfg(feature = "futures")]
42use futures_io::{AsyncRead, AsyncWrite};
43use rustls::{
44 ClientConfig, ClientConnection, ConfigBuilder, RootCertStore, StreamOwned,
45 client::WantsClientCert,
46};
47use rustls_pki_types::{CertificateDer, PrivateKeyDer, ServerName};
48
49use std::{
50 error::Error,
51 fmt,
52 io::{self, Read, Write},
53 sync::Arc,
54};
55
56pub type TlsStream<S> = StreamOwned<ClientConnection, S>;
58
/// A TLS client stream over an asynchronous `AsyncRead + AsyncWrite` transport `S`.
///
/// Returned by [`RustlsConnector::connect_async`].
#[cfg(feature = "futures")]
pub type AsyncTlsStream<S> = futures_rustls::client::TlsStream<S>;
62
63#[derive(Clone)]
65pub struct RustlsConnectorConfig {
66 store: Vec<CertificateDer<'static>>,
67 platform_verifier: bool,
68}
69
70impl RustlsConnectorConfig {
71 #[cfg(feature = "webpki-root-certs")]
72 pub fn new_with_webpki_root_certs() -> Self {
74 Self::default().with_webpki_root_certs()
75 }
76
77 #[cfg(feature = "platform-verifier")]
78 pub fn new_with_platform_verifier() -> Self {
80 Self::default().with_platform_verifier()
81 }
82
83 #[cfg(feature = "native-certs")]
84 pub fn new_with_native_certs() -> io::Result<Self> {
90 Self::default().with_native_certs()
91 }
92
93 pub fn add_parsable_certificates<'a>(&mut self, mut der_certs: Vec<CertificateDer<'static>>) {
97 self.store.append(&mut der_certs)
98 }
99
100 #[cfg(feature = "webpki-root-certs")]
101 pub fn with_webpki_root_certs(mut self) -> Self {
103 self.add_parsable_certificates(webpki_root_certs::TLS_SERVER_ROOT_CERTS.to_vec());
104 self
105 }
106
107 #[cfg(feature = "platform-verifier")]
108 pub fn with_platform_verifier(mut self) -> Self {
110 self.platform_verifier = true;
111 self
112 }
113
114 #[cfg(feature = "native-certs")]
115 pub fn with_native_certs(mut self) -> io::Result<Self> {
121 let certs_result = rustls_native_certs::load_native_certs();
122 for err in certs_result.errors {
123 log::warn!("Got error while loading some native certificates: {err:?}");
124 }
125 if certs_result.certs.is_empty() {
126 return Err(io::Error::other(
127 "Could not load any valid native certificates",
128 ));
129 }
130 self.add_parsable_certificates(certs_result.certs);
131 Ok(self)
132 }
133
134 fn builder(self) -> io::Result<ConfigBuilder<ClientConfig, WantsClientCert>> {
135 let builder = ClientConfig::builder();
136 #[cfg(feature = "platform-verifier")]
137 {
138 if self.platform_verifier {
139 let verifier = rustls_platform_verifier::Verifier::new_with_extra_roots(
140 self.store,
141 builder.crypto_provider().clone(),
142 )
143 .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?;
144 return Ok(builder
145 .dangerous()
146 .with_custom_certificate_verifier(Arc::new(verifier)));
147 }
148 }
149 let mut store = RootCertStore::empty();
150 let (_, ignored) = store.add_parsable_certificates(self.store);
151 if ignored > 0 {
152 log::warn!("{ignored} platform CA root certificates were ignored due to errors");
153 }
154 if store.is_empty() {
155 return Err(io::Error::other("Could not load any valid certificates"));
156 }
157 Ok(builder.with_root_certificates(store))
158 }
159
160 pub fn connector_with_no_client_auth(self) -> io::Result<RustlsConnector> {
166 Ok(self.builder()?.with_no_client_auth().into())
167 }
168
169 pub fn connector_with_single_cert(
178 self,
179 cert_chain: Vec<CertificateDer<'static>>,
180 key_der: PrivateKeyDer<'static>,
181 ) -> io::Result<RustlsConnector> {
182 Ok(self
183 .builder()?
184 .with_client_auth_cert(cert_chain, key_der)
185 .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?
186 .into())
187 }
188}
189
190impl Default for RustlsConnectorConfig {
191 fn default() -> Self {
192 Self {
193 store: Vec::new(),
194 platform_verifier: false,
195 }
196 }
197}
198
199#[derive(Clone)]
201pub struct RustlsConnector(Arc<ClientConfig>);
202
203impl Default for RustlsConnector {
204 fn default() -> Self {
205 RustlsConnectorConfig::default()
206 .connector_with_no_client_auth()
207 .expect("no error codepath for default RustlsConnectorConfig")
208 }
209}
210
211impl From<ClientConfig> for RustlsConnector {
212 fn from(config: ClientConfig) -> Self {
213 Arc::new(config).into()
214 }
215}
216
217impl From<Arc<ClientConfig>> for RustlsConnector {
218 fn from(config: Arc<ClientConfig>) -> Self {
219 Self(config)
220 }
221}
222
223impl RustlsConnector {
224 #[cfg(feature = "webpki-root-certs")]
225 pub fn new_with_webpki_root_certs() -> io::Result<Self> {
231 RustlsConnectorConfig::new_with_webpki_root_certs().connector_with_no_client_auth()
232 }
233
234 #[cfg(feature = "platform-verifier")]
235 pub fn new_with_platform_verifier() -> io::Result<Self> {
241 RustlsConnectorConfig::new_with_platform_verifier().connector_with_no_client_auth()
242 }
243
244 #[cfg(feature = "native-certs")]
245 pub fn new_with_native_certs() -> io::Result<Self> {
251 RustlsConnectorConfig::new_with_native_certs()?.connector_with_no_client_auth()
252 }
253
254 pub fn connect<S: Read + Write + Send + 'static>(
261 &self,
262 domain: &str,
263 stream: S,
264 ) -> Result<TlsStream<S>, HandshakeError<S>> {
265 let session = ClientConnection::new(
266 self.0.clone(),
267 server_name(domain).map_err(HandshakeError::Failure)?,
268 )
269 .map_err(|err| io::Error::new(io::ErrorKind::ConnectionAborted, err))?;
270 MidHandshakeTlsStream { session, stream }.handshake()
271 }
272
273 #[cfg(feature = "futures")]
274 pub async fn connect_async<S: AsyncRead + AsyncWrite + Send + Unpin + 'static>(
280 &self,
281 domain: &str,
282 stream: S,
283 ) -> io::Result<AsyncTlsStream<S>> {
284 futures_rustls::TlsConnector::from(self.0.clone())
285 .connect(server_name(domain)?, stream)
286 .await
287 }
288}
289
290fn server_name(domain: &str) -> io::Result<ServerName<'static>> {
291 Ok(ServerName::try_from(domain)
292 .map_err(|err| {
293 io::Error::new(
294 io::ErrorKind::InvalidData,
295 format!("Invalid domain name ({err:?}): {domain}"),
296 )
297 })?
298 .to_owned())
299}
300
301#[derive(Debug)]
303pub struct MidHandshakeTlsStream<S: Read + Write> {
304 session: ClientConnection,
305 stream: S,
306}
307
308impl<S: Read + Send + Write + 'static> MidHandshakeTlsStream<S> {
309 pub fn get_ref(&self) -> &S {
311 &self.stream
312 }
313
314 pub fn get_mut(&mut self) -> &mut S {
316 &mut self.stream
317 }
318
319 pub fn handshake(mut self) -> Result<TlsStream<S>, HandshakeError<S>> {
326 if let Err(e) = self.session.complete_io(&mut self.stream) {
327 if e.kind() == io::ErrorKind::WouldBlock {
328 if self.session.is_handshaking() {
329 return Err(HandshakeError::WouldBlock(self));
330 }
331 } else {
332 return Err(e.into());
333 }
334 }
335 Ok(TlsStream::new(self.session, self.stream))
336 }
337}
338
339impl<S: Read + Write> fmt::Display for MidHandshakeTlsStream<S> {
340 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
341 f.write_str("MidHandshakeTlsStream")
342 }
343}
344
345pub enum HandshakeError<S: Read + Write + Send + 'static> {
347 WouldBlock(MidHandshakeTlsStream<S>),
350 Failure(io::Error),
352}
353
354impl<S: Read + Write + Send + 'static> fmt::Display for HandshakeError<S> {
355 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
356 match self {
357 HandshakeError::WouldBlock(_) => f.write_str("WouldBlock hit during handshake"),
358 HandshakeError::Failure(err) => f.write_fmt(format_args!("IO error: {err}")),
359 }
360 }
361}
362
363impl<S: Read + Write + Send + 'static> fmt::Debug for HandshakeError<S> {
364 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
365 let mut d = f.debug_tuple("HandshakeError");
366 match self {
367 HandshakeError::WouldBlock(_) => d.field(&"WouldBlock"),
368 HandshakeError::Failure(err) => d.field(&err),
369 }
370 .finish()
371 }
372}
373
374impl<S: Read + Write + Send + 'static> Error for HandshakeError<S> {
375 fn source(&self) -> Option<&(dyn Error + 'static)> {
376 match self {
377 HandshakeError::Failure(err) => Some(err),
378 _ => None,
379 }
380 }
381}
382
383impl<S: Read + Send + Write + 'static> From<io::Error> for HandshakeError<S> {
384 fn from(err: io::Error) -> Self {
385 HandshakeError::Failure(err)
386 }
387}