1#![deny(missing_docs)]
2#![allow(clippy::result_large_err, clippy::large_enum_variant)]
3
4pub use rustls;
32#[cfg(feature = "native-certs")]
33pub use rustls_native_certs;
34pub use rustls_pki_types;
35pub use webpki;
36#[cfg(feature = "webpki-roots-certs")]
37pub use webpki_root_certs;
38
39#[cfg(feature = "futures")]
40use futures_io::{AsyncRead, AsyncWrite};
41use rustls::{
42 ClientConfig, ClientConnection, ConfigBuilder, RootCertStore, StreamOwned,
43 client::WantsClientCert,
44};
45use rustls_pki_types::{CertificateDer, PrivateKeyDer, ServerName};
46
47use std::{
48 error::Error,
49 fmt,
50 io::{self, Read, Write},
51 sync::Arc,
52};
53
54pub type TlsStream<S> = StreamOwned<ClientConnection, S>;
56
/// An asynchronous TLS stream over the transport `S`, as returned by
/// `RustlsConnector::connect_async` (requires the `futures` feature).
#[cfg(feature = "futures")]
pub type AsyncTlsStream<S> = futures_rustls::client::TlsStream<S>;
60
61#[derive(Clone)]
63pub struct RustlsConnectorConfig(RootCertStore);
64
65impl RustlsConnectorConfig {
66 #[cfg(feature = "webpki-roots-certs")]
67 pub fn new_with_webpki_roots_certs() -> Self {
69 Self::default().with_webpki_root_certs()
70 }
71
72 #[cfg(feature = "native-certs")]
73 pub fn new_with_native_certs() -> io::Result<Self> {
79 let mut config = Self::default();
80 let (_, ignored) = config.register_native_certs()?;
81 if ignored > 0 {
82 log::warn!("{ignored} platform CA root certificates were ignored due to errors");
83 }
84 Ok(config)
85 }
86
87 pub fn add_parsable_certificates<'a>(
93 &mut self,
94 der_certs: impl IntoIterator<Item = CertificateDer<'a>>,
95 ) -> (usize, usize) {
96 self.0.add_parsable_certificates(der_certs)
97 }
98
99 #[cfg(feature = "webpki-roots-certs")]
100 pub fn with_webpki_root_certs(mut self) -> Self {
102 self.add_parsable_certificates(webpki_root_certs::TLS_SERVER_ROOT_CERTS.iter().cloned());
103 self
104 }
105
106 #[cfg(feature = "native-certs")]
107 pub fn register_native_certs(&mut self) -> io::Result<(usize, usize)> {
113 let certs_result = rustls_native_certs::load_native_certs();
114 for err in certs_result.errors {
115 log::warn!("Got error while loading some native certificates: {err:?}");
116 }
117 let (added, ignored) = self.add_parsable_certificates(certs_result.certs);
118 if self.0.is_empty() {
119 return Err(io::Error::other(
120 "Could not load any valid native certificates",
121 ));
122 }
123 Ok((added, ignored))
124 }
125
126 fn builder(self) -> ConfigBuilder<ClientConfig, WantsClientCert> {
127 ClientConfig::builder().with_root_certificates(self.0)
128 }
129
130 pub fn connector_with_no_client_auth(self) -> RustlsConnector {
132 self.builder().with_no_client_auth().into()
133 }
134
135 pub fn connector_with_single_cert(
142 self,
143 cert_chain: Vec<CertificateDer<'static>>,
144 key_der: PrivateKeyDer<'static>,
145 ) -> io::Result<RustlsConnector> {
146 Ok(self
147 .builder()
148 .with_client_auth_cert(cert_chain, key_der)
149 .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?
150 .into())
151 }
152}
153
154impl Default for RustlsConnectorConfig {
155 fn default() -> Self {
156 Self(RootCertStore::empty())
157 }
158}
159
160#[derive(Clone)]
162pub struct RustlsConnector(Arc<ClientConfig>);
163
164impl Default for RustlsConnector {
165 fn default() -> Self {
166 RustlsConnectorConfig::default().connector_with_no_client_auth()
167 }
168}
169
170impl From<ClientConfig> for RustlsConnector {
171 fn from(config: ClientConfig) -> Self {
172 Arc::new(config).into()
173 }
174}
175
176impl From<Arc<ClientConfig>> for RustlsConnector {
177 fn from(config: Arc<ClientConfig>) -> Self {
178 Self(config)
179 }
180}
181
182impl RustlsConnector {
183 #[cfg(feature = "webpki-roots-certs")]
184 pub fn new_with_webpki_roots_certs() -> Self {
186 RustlsConnectorConfig::new_with_webpki_roots_certs().connector_with_no_client_auth()
187 }
188
189 #[cfg(feature = "native-certs")]
190 pub fn new_with_native_certs() -> io::Result<Self> {
196 Ok(RustlsConnectorConfig::new_with_native_certs()?.connector_with_no_client_auth())
197 }
198
199 pub fn connect<S: Read + Write + Send + 'static>(
206 &self,
207 domain: &str,
208 stream: S,
209 ) -> Result<TlsStream<S>, HandshakeError<S>> {
210 let session = ClientConnection::new(
211 self.0.clone(),
212 server_name(domain).map_err(HandshakeError::Failure)?,
213 )
214 .map_err(|err| io::Error::new(io::ErrorKind::ConnectionAborted, err))?;
215 MidHandshakeTlsStream { session, stream }.handshake()
216 }
217
218 #[cfg(feature = "futures")]
219 pub async fn connect_async<S: AsyncRead + AsyncWrite + Send + Unpin + 'static>(
225 &self,
226 domain: &str,
227 stream: S,
228 ) -> io::Result<AsyncTlsStream<S>> {
229 futures_rustls::TlsConnector::from(self.0.clone())
230 .connect(server_name(domain)?, stream)
231 .await
232 }
233}
234
235fn server_name(domain: &str) -> io::Result<ServerName<'static>> {
236 Ok(ServerName::try_from(domain)
237 .map_err(|err| {
238 io::Error::new(
239 io::ErrorKind::InvalidData,
240 format!("Invalid domain name ({err:?}): {domain}"),
241 )
242 })?
243 .to_owned())
244}
245
246#[derive(Debug)]
248pub struct MidHandshakeTlsStream<S: Read + Write> {
249 session: ClientConnection,
250 stream: S,
251}
252
253impl<S: Read + Send + Write + 'static> MidHandshakeTlsStream<S> {
254 pub fn get_ref(&self) -> &S {
256 &self.stream
257 }
258
259 pub fn get_mut(&mut self) -> &mut S {
261 &mut self.stream
262 }
263
264 pub fn handshake(mut self) -> Result<TlsStream<S>, HandshakeError<S>> {
271 if let Err(e) = self.session.complete_io(&mut self.stream) {
272 if e.kind() == io::ErrorKind::WouldBlock {
273 if self.session.is_handshaking() {
274 return Err(HandshakeError::WouldBlock(self));
275 }
276 } else {
277 return Err(e.into());
278 }
279 }
280 Ok(TlsStream::new(self.session, self.stream))
281 }
282}
283
284impl<S: Read + Write> fmt::Display for MidHandshakeTlsStream<S> {
285 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
286 f.write_str("MidHandshakeTlsStream")
287 }
288}
289
290pub enum HandshakeError<S: Read + Write + Send + 'static> {
292 WouldBlock(MidHandshakeTlsStream<S>),
295 Failure(io::Error),
297}
298
299impl<S: Read + Write + Send + 'static> fmt::Display for HandshakeError<S> {
300 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
301 match self {
302 HandshakeError::WouldBlock(_) => f.write_str("WouldBlock hit during handshake"),
303 HandshakeError::Failure(err) => f.write_fmt(format_args!("IO error: {err}")),
304 }
305 }
306}
307
308impl<S: Read + Write + Send + 'static> fmt::Debug for HandshakeError<S> {
309 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
310 let mut d = f.debug_tuple("HandshakeError");
311 match self {
312 HandshakeError::WouldBlock(_) => d.field(&"WouldBlock"),
313 HandshakeError::Failure(err) => d.field(&err),
314 }
315 .finish()
316 }
317}
318
319impl<S: Read + Write + Send + 'static> Error for HandshakeError<S> {
320 fn source(&self) -> Option<&(dyn Error + 'static)> {
321 match self {
322 HandshakeError::Failure(err) => Some(err),
323 _ => None,
324 }
325 }
326}
327
328impl<S: Read + Send + Write + 'static> From<io::Error> for HandshakeError<S> {
329 fn from(err: io::Error) -> Self {
330 HandshakeError::Failure(err)
331 }
332}