1#![deny(missing_docs)]
2
3pub use rustls;
31#[cfg(feature = "native-certs")]
32pub use rustls_native_certs;
33pub use rustls_pki_types;
34pub use webpki;
35#[cfg(feature = "webpki-roots-certs")]
36pub use webpki_roots;
37
38#[cfg(feature = "futures")]
39use futures_io::{AsyncRead, AsyncWrite};
40use rustls::{ClientConfig, ClientConnection, RootCertStore, StreamOwned};
41use rustls_pki_types::{CertificateDer, PrivateKeyDer, ServerName};
42
43use std::{
44 error::Error,
45 fmt::{self, Debug},
46 io::{self, Read, Write},
47 sync::Arc,
48};
49
50pub type TlsStream<S> = StreamOwned<ClientConnection, S>;
52
/// An asynchronous TLS client stream, as produced by [`RustlsConnector::connect_async`].
///
/// Only available with the `futures` feature.
#[cfg(feature = "futures")]
pub type AsyncTlsStream<IO> = futures_rustls::client::TlsStream<IO>;
56
57#[derive(Clone)]
59pub struct RustlsConnectorConfig(RootCertStore);
60
61impl RustlsConnectorConfig {
62 #[cfg(feature = "webpki-roots-certs")]
63 pub fn new_with_webpki_roots_certs() -> Self {
65 Self(RootCertStore {
66 roots: webpki_roots::TLS_SERVER_ROOTS.to_vec(),
67 })
68 }
69
70 #[cfg(feature = "native-certs")]
71 pub fn new_with_native_certs() -> io::Result<Self> {
77 let mut root_store = RootCertStore::empty();
78 let mut certs_result = rustls_native_certs::load_native_certs();
79 if let Some(err) = certs_result.errors.pop() {
80 return Err(io::Error::other(err));
81 }
82 for cert in certs_result.certs {
83 if let Err(err) = root_store.add(cert) {
84 log::warn!("Got error while importing some native certificates: {err:?}");
85 }
86 }
87 Ok(Self(root_store))
88 }
89
90 pub fn add_parsable_certificates(
96 &mut self,
97 der_certs: Vec<CertificateDer<'_>>,
98 ) -> (usize, usize) {
99 self.0.add_parsable_certificates(der_certs)
100 }
101
102 pub fn connector_with_no_client_auth(self) -> RustlsConnector {
104 ClientConfig::builder()
105 .with_root_certificates(self.0)
106 .with_no_client_auth()
107 .into()
108 }
109
110 pub fn connector_with_single_cert(
117 self,
118 cert_chain: Vec<CertificateDer<'static>>,
119 key_der: PrivateKeyDer<'static>,
120 ) -> io::Result<RustlsConnector> {
121 Ok(ClientConfig::builder()
122 .with_root_certificates(self.0)
123 .with_client_auth_cert(cert_chain, key_der)
124 .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?
125 .into())
126 }
127}
128
129impl Default for RustlsConnectorConfig {
130 fn default() -> Self {
131 Self(RootCertStore::empty())
132 }
133}
134
135#[derive(Clone)]
137pub struct RustlsConnector(Arc<ClientConfig>);
138
139impl Default for RustlsConnector {
140 fn default() -> Self {
141 RustlsConnectorConfig::default().connector_with_no_client_auth()
142 }
143}
144
145impl From<ClientConfig> for RustlsConnector {
146 fn from(config: ClientConfig) -> Self {
147 Arc::new(config).into()
148 }
149}
150
151impl From<Arc<ClientConfig>> for RustlsConnector {
152 fn from(config: Arc<ClientConfig>) -> Self {
153 Self(config)
154 }
155}
156
157impl RustlsConnector {
158 #[cfg(feature = "webpki-roots-certs")]
159 pub fn new_with_webpki_roots_certs() -> Self {
161 RustlsConnectorConfig::new_with_webpki_roots_certs().connector_with_no_client_auth()
162 }
163
164 #[cfg(feature = "native-certs")]
165 pub fn new_with_native_certs() -> io::Result<Self> {
171 Ok(RustlsConnectorConfig::new_with_native_certs()?.connector_with_no_client_auth())
172 }
173
174 pub fn connect<S: Debug + Read + Send + Sync + Write + 'static>(
181 &self,
182 domain: &str,
183 stream: S,
184 ) -> Result<TlsStream<S>, HandshakeError<S>> {
185 let session = ClientConnection::new(
186 self.0.clone(),
187 server_name(domain).map_err(HandshakeError::Failure)?,
188 )
189 .map_err(|err| io::Error::new(io::ErrorKind::ConnectionAborted, err))?;
190 MidHandshakeTlsStream { session, stream }.handshake()
191 }
192
193 #[cfg(feature = "futures")]
194 pub async fn connect_async<
200 S: Debug + AsyncRead + AsyncWrite + Send + Sync + Unpin + 'static,
201 >(
202 &self,
203 domain: &str,
204 stream: S,
205 ) -> io::Result<AsyncTlsStream<S>> {
206 futures_rustls::TlsConnector::from(self.0.clone())
207 .connect(server_name(domain)?, stream)
208 .await
209 }
210}
211
212fn server_name(domain: &str) -> io::Result<ServerName<'static>> {
213 Ok(ServerName::try_from(domain)
214 .map_err(|err| {
215 io::Error::new(
216 io::ErrorKind::InvalidData,
217 format!("Invalid domain name ({err:?}): {domain}"),
218 )
219 })?
220 .to_owned())
221}
222
223#[derive(Debug)]
225pub struct MidHandshakeTlsStream<S: Read + Write> {
226 session: ClientConnection,
227 stream: S,
228}
229
230impl<S: Debug + Read + Send + Sync + Write + 'static> MidHandshakeTlsStream<S> {
231 pub fn get_ref(&self) -> &S {
233 &self.stream
234 }
235
236 pub fn get_mut(&mut self) -> &mut S {
238 &mut self.stream
239 }
240
241 pub fn handshake(mut self) -> Result<TlsStream<S>, HandshakeError<S>> {
248 if let Err(e) = self.session.complete_io(&mut self.stream) {
249 if e.kind() == io::ErrorKind::WouldBlock {
250 if self.session.is_handshaking() {
251 return Err(HandshakeError::WouldBlock(Box::new(self)));
252 }
253 } else {
254 return Err(e.into());
255 }
256 }
257 Ok(TlsStream::new(self.session, self.stream))
258 }
259}
260
261impl<S: Read + Write> fmt::Display for MidHandshakeTlsStream<S> {
262 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
263 f.write_str("MidHandshakeTlsStream")
264 }
265}
266
267#[derive(Debug)]
269pub enum HandshakeError<S: Read + Send + Sync + Write + 'static> {
270 WouldBlock(Box<MidHandshakeTlsStream<S>>),
273 Failure(io::Error),
275}
276
277impl<S: Debug + Read + Send + Sync + Write + 'static> fmt::Display for HandshakeError<S> {
278 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
279 match self {
280 HandshakeError::WouldBlock(_) => f.write_str("WouldBlock hit during handshake"),
281 HandshakeError::Failure(err) => f.write_fmt(format_args!("IO error: {err}")),
282 }
283 }
284}
285
286impl<S: Debug + Read + Send + Sync + Write + 'static> Error for HandshakeError<S> {
287 fn source(&self) -> Option<&(dyn Error + 'static)> {
288 match self {
289 HandshakeError::Failure(err) => Some(err),
290 _ => None,
291 }
292 }
293}
294
295impl<S: Debug + Read + Send + Sync + Write + 'static> From<io::Error> for HandshakeError<S> {
296 fn from(err: io::Error) -> Self {
297 HandshakeError::Failure(err)
298 }
299}