1#![deny(missing_docs)]
2
3pub use rustls;
31#[cfg(feature = "native-certs")]
32pub use rustls_native_certs;
33pub use rustls_pki_types;
34pub use webpki;
35#[cfg(feature = "webpki-roots-certs")]
36pub use webpki_roots;
37
38#[cfg(feature = "futures")]
39use futures_io::{AsyncRead, AsyncWrite};
40use rustls::{ClientConfig, ClientConnection, RootCertStore, StreamOwned};
41use rustls_pki_types::{CertificateDer, PrivateKeyDer, ServerName};
42
43use std::{
44 error::Error,
45 fmt,
46 io::{self, Read, Write},
47 sync::Arc,
48};
49
/// The TLS stream produced by a completed blocking handshake
/// ([`RustlsConnector::connect`] / [`MidHandshakeTlsStream::handshake`]):
/// a rustls [`StreamOwned`] pairing a [`ClientConnection`] with the
/// underlying stream `S`.
pub type TlsStream<S> = StreamOwned<ClientConnection, S>;
52
#[cfg(feature = "futures")]
/// The TLS stream produced by `RustlsConnector::connect_async`; an alias for
/// `futures_rustls::client::TlsStream<S>`.
pub type AsyncTlsStream<S> = futures_rustls::client::TlsStream<S>;
56
/// Builder-style configuration for a [`RustlsConnector`].
///
/// Wraps the [`RootCertStore`] that the resulting connector passes to
/// rustls via `with_root_certificates` when verifying servers.
#[derive(Clone)]
pub struct RustlsConnectorConfig(RootCertStore);
60
61impl RustlsConnectorConfig {
62 #[cfg(feature = "webpki-roots-certs")]
63 pub fn new_with_webpki_roots_certs() -> Self {
65 Self(RootCertStore {
66 roots: webpki_roots::TLS_SERVER_ROOTS.to_vec(),
67 })
68 }
69
70 #[cfg(feature = "native-certs")]
71 pub fn new_with_native_certs() -> io::Result<Self> {
77 let mut root_store = RootCertStore::empty();
78 let mut certs_result = rustls_native_certs::load_native_certs();
79 if let Some(err) = certs_result.errors.pop() {
80 return Err(io::Error::other(err));
81 }
82 for cert in certs_result.certs {
83 if let Err(err) = root_store.add(cert) {
84 log::warn!("Got error while importing some native certificates: {err:?}");
85 }
86 }
87 Ok(Self(root_store))
88 }
89
90 pub fn add_parsable_certificates(
96 &mut self,
97 der_certs: Vec<CertificateDer<'_>>,
98 ) -> (usize, usize) {
99 self.0.add_parsable_certificates(der_certs)
100 }
101
102 pub fn connector_with_no_client_auth(self) -> RustlsConnector {
104 ClientConfig::builder()
105 .with_root_certificates(self.0)
106 .with_no_client_auth()
107 .into()
108 }
109
110 pub fn connector_with_single_cert(
117 self,
118 cert_chain: Vec<CertificateDer<'static>>,
119 key_der: PrivateKeyDer<'static>,
120 ) -> io::Result<RustlsConnector> {
121 Ok(ClientConfig::builder()
122 .with_root_certificates(self.0)
123 .with_client_auth_cert(cert_chain, key_der)
124 .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?
125 .into())
126 }
127}
128
129impl Default for RustlsConnectorConfig {
130 fn default() -> Self {
131 Self(RootCertStore::empty())
132 }
133}
134
/// A TLS client connector backed by a shared rustls [`ClientConfig`].
///
/// Opens TLS streams with [`RustlsConnector::connect`] (blocking) or
/// `connect_async` (with the `futures` feature). Cloning only clones the
/// inner [`Arc`], so all clones share one configuration.
#[derive(Clone)]
pub struct RustlsConnector(Arc<ClientConfig>);
138
139impl Default for RustlsConnector {
140 fn default() -> Self {
141 RustlsConnectorConfig::default().connector_with_no_client_auth()
142 }
143}
144
145impl From<ClientConfig> for RustlsConnector {
146 fn from(config: ClientConfig) -> Self {
147 Arc::new(config).into()
148 }
149}
150
151impl From<Arc<ClientConfig>> for RustlsConnector {
152 fn from(config: Arc<ClientConfig>) -> Self {
153 Self(config)
154 }
155}
156
157impl RustlsConnector {
158 #[cfg(feature = "webpki-roots-certs")]
159 pub fn new_with_webpki_roots_certs() -> Self {
161 RustlsConnectorConfig::new_with_webpki_roots_certs().connector_with_no_client_auth()
162 }
163
164 #[cfg(feature = "native-certs")]
165 pub fn new_with_native_certs() -> io::Result<Self> {
171 Ok(RustlsConnectorConfig::new_with_native_certs()?.connector_with_no_client_auth())
172 }
173
174 pub fn connect<S: Read + Write + Send + 'static>(
181 &self,
182 domain: &str,
183 stream: S,
184 ) -> Result<TlsStream<S>, HandshakeError<S>> {
185 let session = ClientConnection::new(
186 self.0.clone(),
187 server_name(domain).map_err(HandshakeError::Failure)?,
188 )
189 .map_err(|err| io::Error::new(io::ErrorKind::ConnectionAborted, err))?;
190 MidHandshakeTlsStream { session, stream }.handshake()
191 }
192
193 #[cfg(feature = "futures")]
194 pub async fn connect_async<S: AsyncRead + AsyncWrite + Send + Unpin + 'static>(
200 &self,
201 domain: &str,
202 stream: S,
203 ) -> io::Result<AsyncTlsStream<S>> {
204 futures_rustls::TlsConnector::from(self.0.clone())
205 .connect(server_name(domain)?, stream)
206 .await
207 }
208}
209
210fn server_name(domain: &str) -> io::Result<ServerName<'static>> {
211 Ok(ServerName::try_from(domain)
212 .map_err(|err| {
213 io::Error::new(
214 io::ErrorKind::InvalidData,
215 format!("Invalid domain name ({err:?}): {domain}"),
216 )
217 })?
218 .to_owned())
219}
220
/// A client TLS session whose handshake has not completed yet, together with
/// the underlying stream it runs over.
///
/// Handed back inside [`HandshakeError::WouldBlock`]; call
/// [`MidHandshakeTlsStream::handshake`] again to resume once the stream is ready.
#[derive(Debug)]
pub struct MidHandshakeTlsStream<S: Read + Write> {
    // The rustls client state driven by `handshake`.
    session: ClientConnection,
    // The raw transport the TLS records are exchanged over.
    stream: S,
}
227
228impl<S: Read + Send + Write + 'static> MidHandshakeTlsStream<S> {
229 pub fn get_ref(&self) -> &S {
231 &self.stream
232 }
233
234 pub fn get_mut(&mut self) -> &mut S {
236 &mut self.stream
237 }
238
239 pub fn handshake(mut self) -> Result<TlsStream<S>, HandshakeError<S>> {
246 if let Err(e) = self.session.complete_io(&mut self.stream) {
247 if e.kind() == io::ErrorKind::WouldBlock {
248 if self.session.is_handshaking() {
249 return Err(HandshakeError::WouldBlock(Box::new(self)));
250 }
251 } else {
252 return Err(e.into());
253 }
254 }
255 Ok(TlsStream::new(self.session, self.stream))
256 }
257}
258
259impl<S: Read + Write> fmt::Display for MidHandshakeTlsStream<S> {
260 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
261 f.write_str("MidHandshakeTlsStream")
262 }
263}
264
/// Error returned by [`RustlsConnector::connect`] and
/// [`MidHandshakeTlsStream::handshake`].
pub enum HandshakeError<S: Read + Write + Send + 'static> {
    /// The underlying stream reported [`io::ErrorKind::WouldBlock`] while the
    /// session was still handshaking. The partially completed handshake is
    /// returned so it can be resumed with [`MidHandshakeTlsStream::handshake`].
    WouldBlock(Box<MidHandshakeTlsStream<S>>),
    /// Connection setup or the handshake failed with this I/O error.
    Failure(io::Error),
}
273
274impl<S: Read + Write + Send + 'static> fmt::Display for HandshakeError<S> {
275 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
276 match self {
277 HandshakeError::WouldBlock(_) => f.write_str("WouldBlock hit during handshake"),
278 HandshakeError::Failure(err) => f.write_fmt(format_args!("IO error: {err}")),
279 }
280 }
281}
282
283impl<S: Read + Write + Send + 'static> fmt::Debug for HandshakeError<S> {
284 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
285 let mut d = f.debug_tuple("HandshakeError");
286 match self {
287 HandshakeError::WouldBlock(_) => d.field(&"WouldBlock"),
288 HandshakeError::Failure(err) => d.field(&err),
289 }
290 .finish()
291 }
292}
293
294impl<S: Read + Write + Send + 'static> Error for HandshakeError<S> {
295 fn source(&self) -> Option<&(dyn Error + 'static)> {
296 match self {
297 HandshakeError::Failure(err) => Some(err),
298 _ => None,
299 }
300 }
301}
302
303impl<S: Read + Send + Write + 'static> From<io::Error> for HandshakeError<S> {
304 fn from(err: io::Error) -> Self {
305 HandshakeError::Failure(err)
306 }
307}