1#![deny(missing_docs)]
2#![allow(clippy::result_large_err, clippy::large_enum_variant)]
3
4pub use rustls;
32#[cfg(feature = "native-certs")]
33pub use rustls_native_certs;
34pub use rustls_pki_types;
35pub use webpki;
36#[cfg(feature = "webpki-roots-certs")]
37pub use webpki_roots;
38
39#[cfg(feature = "futures")]
40use futures_io::{AsyncRead, AsyncWrite};
41use rustls::{ClientConfig, ClientConnection, RootCertStore, StreamOwned};
42use rustls_pki_types::{CertificateDer, PrivateKeyDer, ServerName};
43
44use std::{
45 error::Error,
46 fmt,
47 io::{self, Read, Write},
48 sync::Arc,
49};
50
/// A synchronous TLS client stream: rustls' [`StreamOwned`] driving a
/// [`ClientConnection`] over the underlying transport `S`.
///
/// Returned by [`RustlsConnector::connect`] once the handshake has completed.
pub type TlsStream<S> = StreamOwned<ClientConnection, S>;
53
/// An asynchronous TLS client stream over the transport `S`, as returned by
/// [`RustlsConnector::connect_async`].
#[cfg(feature = "futures")]
pub type AsyncTlsStream<S> = futures_rustls::client::TlsStream<S>;
57
/// Builder-style configuration for a [`RustlsConnector`].
///
/// Holds the [`RootCertStore`] of trust anchors that the resulting connector
/// uses to verify servers.
#[derive(Clone)]
pub struct RustlsConnectorConfig(RootCertStore);
61
62impl RustlsConnectorConfig {
63 #[cfg(feature = "webpki-roots-certs")]
64 pub fn new_with_webpki_roots_certs() -> Self {
66 Self(RootCertStore {
67 roots: webpki_roots::TLS_SERVER_ROOTS.to_vec(),
68 })
69 }
70
71 #[cfg(feature = "native-certs")]
72 pub fn new_with_native_certs() -> io::Result<Self> {
78 let mut root_store = RootCertStore::empty();
79 let mut certs_result = rustls_native_certs::load_native_certs();
80 if let Some(err) = certs_result.errors.pop() {
81 return Err(io::Error::other(err));
82 }
83 for cert in certs_result.certs {
84 if let Err(err) = root_store.add(cert) {
85 log::warn!("Got error while importing some native certificates: {err:?}");
86 }
87 }
88 Ok(Self(root_store))
89 }
90
91 pub fn add_parsable_certificates(
97 &mut self,
98 der_certs: Vec<CertificateDer<'_>>,
99 ) -> (usize, usize) {
100 self.0.add_parsable_certificates(der_certs)
101 }
102
103 pub fn connector_with_no_client_auth(self) -> RustlsConnector {
105 ClientConfig::builder()
106 .with_root_certificates(self.0)
107 .with_no_client_auth()
108 .into()
109 }
110
111 pub fn connector_with_single_cert(
118 self,
119 cert_chain: Vec<CertificateDer<'static>>,
120 key_der: PrivateKeyDer<'static>,
121 ) -> io::Result<RustlsConnector> {
122 Ok(ClientConfig::builder()
123 .with_root_certificates(self.0)
124 .with_client_auth_cert(cert_chain, key_der)
125 .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?
126 .into())
127 }
128}
129
130impl Default for RustlsConnectorConfig {
131 fn default() -> Self {
132 Self(RootCertStore::empty())
133 }
134}
135
/// A TLS client connector wrapping a shared rustls [`ClientConfig`].
///
/// Cloning only clones the inner `Arc`, so all clones share the same configuration.
#[derive(Clone)]
pub struct RustlsConnector(Arc<ClientConfig>);
139
140impl Default for RustlsConnector {
141 fn default() -> Self {
142 RustlsConnectorConfig::default().connector_with_no_client_auth()
143 }
144}
145
146impl From<ClientConfig> for RustlsConnector {
147 fn from(config: ClientConfig) -> Self {
148 Arc::new(config).into()
149 }
150}
151
152impl From<Arc<ClientConfig>> for RustlsConnector {
153 fn from(config: Arc<ClientConfig>) -> Self {
154 Self(config)
155 }
156}
157
158impl RustlsConnector {
159 #[cfg(feature = "webpki-roots-certs")]
160 pub fn new_with_webpki_roots_certs() -> Self {
162 RustlsConnectorConfig::new_with_webpki_roots_certs().connector_with_no_client_auth()
163 }
164
165 #[cfg(feature = "native-certs")]
166 pub fn new_with_native_certs() -> io::Result<Self> {
172 Ok(RustlsConnectorConfig::new_with_native_certs()?.connector_with_no_client_auth())
173 }
174
175 pub fn connect<S: Read + Write + Send + 'static>(
182 &self,
183 domain: &str,
184 stream: S,
185 ) -> Result<TlsStream<S>, HandshakeError<S>> {
186 let session = ClientConnection::new(
187 self.0.clone(),
188 server_name(domain).map_err(HandshakeError::Failure)?,
189 )
190 .map_err(|err| io::Error::new(io::ErrorKind::ConnectionAborted, err))?;
191 MidHandshakeTlsStream { session, stream }.handshake()
192 }
193
194 #[cfg(feature = "futures")]
195 pub async fn connect_async<S: AsyncRead + AsyncWrite + Send + Unpin + 'static>(
201 &self,
202 domain: &str,
203 stream: S,
204 ) -> io::Result<AsyncTlsStream<S>> {
205 futures_rustls::TlsConnector::from(self.0.clone())
206 .connect(server_name(domain)?, stream)
207 .await
208 }
209}
210
211fn server_name(domain: &str) -> io::Result<ServerName<'static>> {
212 Ok(ServerName::try_from(domain)
213 .map_err(|err| {
214 io::Error::new(
215 io::ErrorKind::InvalidData,
216 format!("Invalid domain name ({err:?}): {domain}"),
217 )
218 })?
219 .to_owned())
220}
221
222#[derive(Debug)]
224pub struct MidHandshakeTlsStream<S: Read + Write> {
225 session: ClientConnection,
226 stream: S,
227}
228
229impl<S: Read + Send + Write + 'static> MidHandshakeTlsStream<S> {
230 pub fn get_ref(&self) -> &S {
232 &self.stream
233 }
234
235 pub fn get_mut(&mut self) -> &mut S {
237 &mut self.stream
238 }
239
240 pub fn handshake(mut self) -> Result<TlsStream<S>, HandshakeError<S>> {
247 if let Err(e) = self.session.complete_io(&mut self.stream) {
248 if e.kind() == io::ErrorKind::WouldBlock {
249 if self.session.is_handshaking() {
250 return Err(HandshakeError::WouldBlock(self));
251 }
252 } else {
253 return Err(e.into());
254 }
255 }
256 Ok(TlsStream::new(self.session, self.stream))
257 }
258}
259
260impl<S: Read + Write> fmt::Display for MidHandshakeTlsStream<S> {
261 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
262 f.write_str("MidHandshakeTlsStream")
263 }
264}
265
266pub enum HandshakeError<S: Read + Write + Send + 'static> {
268 WouldBlock(MidHandshakeTlsStream<S>),
271 Failure(io::Error),
273}
274
275impl<S: Read + Write + Send + 'static> fmt::Display for HandshakeError<S> {
276 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
277 match self {
278 HandshakeError::WouldBlock(_) => f.write_str("WouldBlock hit during handshake"),
279 HandshakeError::Failure(err) => f.write_fmt(format_args!("IO error: {err}")),
280 }
281 }
282}
283
284impl<S: Read + Write + Send + 'static> fmt::Debug for HandshakeError<S> {
285 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
286 let mut d = f.debug_tuple("HandshakeError");
287 match self {
288 HandshakeError::WouldBlock(_) => d.field(&"WouldBlock"),
289 HandshakeError::Failure(err) => d.field(&err),
290 }
291 .finish()
292 }
293}
294
295impl<S: Read + Write + Send + 'static> Error for HandshakeError<S> {
296 fn source(&self) -> Option<&(dyn Error + 'static)> {
297 match self {
298 HandshakeError::Failure(err) => Some(err),
299 _ => None,
300 }
301 }
302}
303
304impl<S: Read + Send + Write + 'static> From<io::Error> for HandshakeError<S> {
305 fn from(err: io::Error) -> Self {
306 HandshakeError::Failure(err)
307 }
308}