1#![deny(missing_docs)]
2
3pub use rustls;
31#[cfg(feature = "native-certs")]
32pub use rustls_native_certs;
33pub use rustls_pki_types;
34pub use webpki;
35#[cfg(feature = "webpki-roots-certs")]
36pub use webpki_roots;
37
38use rustls::{ClientConfig, ClientConnection, RootCertStore, StreamOwned};
39use rustls_pki_types::{CertificateDer, PrivateKeyDer, ServerName};
40
41use std::{
42 error::Error,
43 fmt::{self, Debug},
44 io::{self, Read, Write},
45 sync::Arc,
46};
47
48pub type TlsStream<S> = StreamOwned<ClientConnection, S>;
50
51#[derive(Clone)]
53pub struct RustlsConnectorConfig(RootCertStore);
54
55impl RustlsConnectorConfig {
56 #[cfg(feature = "webpki-roots-certs")]
57 pub fn new_with_webpki_roots_certs() -> Self {
59 Self(RootCertStore {
60 roots: webpki_roots::TLS_SERVER_ROOTS.iter().cloned().collect(),
61 })
62 }
63
64 #[cfg(feature = "native-certs")]
65 pub fn new_with_native_certs() -> io::Result<Self> {
71 let mut root_store = RootCertStore::empty();
72 for cert in rustls_native_certs::load_native_certs().expect("could not load platform certs")
73 {
74 if let Err(err) = root_store.add(cert) {
75 log::warn!(
76 "Got error while importing some native certificates: {:?}",
77 err
78 );
79 }
80 }
81 Ok(Self(root_store))
82 }
83
84 pub fn add_parsable_certificates(
90 &mut self,
91 der_certs: Vec<CertificateDer<'_>>,
92 ) -> (usize, usize) {
93 self.0.add_parsable_certificates(der_certs)
94 }
95
96 pub fn connector_with_no_client_auth(self) -> RustlsConnector {
98 ClientConfig::builder()
99 .with_root_certificates(self.0)
100 .with_no_client_auth()
101 .into()
102 }
103
104 pub fn connector_with_single_cert(
111 self,
112 cert_chain: Vec<CertificateDer<'static>>,
113 key_der: PrivateKeyDer<'static>,
114 ) -> io::Result<RustlsConnector> {
115 Ok(ClientConfig::builder()
116 .with_root_certificates(self.0)
117 .with_client_auth_cert(cert_chain, key_der)
118 .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?
119 .into())
120 }
121}
122
123impl Default for RustlsConnectorConfig {
124 fn default() -> Self {
125 Self(RootCertStore::empty())
126 }
127}
128
129#[derive(Clone)]
131pub struct RustlsConnector(Arc<ClientConfig>);
132
133impl Default for RustlsConnector {
134 fn default() -> Self {
135 RustlsConnectorConfig::default().connector_with_no_client_auth()
136 }
137}
138
139impl From<ClientConfig> for RustlsConnector {
140 fn from(config: ClientConfig) -> Self {
141 Arc::new(config).into()
142 }
143}
144
145impl From<Arc<ClientConfig>> for RustlsConnector {
146 fn from(config: Arc<ClientConfig>) -> Self {
147 Self(config)
148 }
149}
150
151impl RustlsConnector {
152 #[cfg(feature = "webpki-roots-certs")]
153 pub fn new_with_webpki_roots_certs() -> Self {
155 RustlsConnectorConfig::new_with_webpki_roots_certs().connector_with_no_client_auth()
156 }
157
158 #[cfg(feature = "native-certs")]
159 pub fn new_with_native_certs() -> io::Result<Self> {
165 Ok(RustlsConnectorConfig::new_with_native_certs()?.connector_with_no_client_auth())
166 }
167
168 pub fn connect<S: Debug + Read + Send + Sync + Write + 'static>(
175 &self,
176 domain: &str,
177 stream: S,
178 ) -> Result<TlsStream<S>, HandshakeError<S>> {
179 let session = ClientConnection::new(
180 self.0.clone(),
181 ServerName::try_from(domain)
182 .map_err(|err| {
183 HandshakeError::Failure(io::Error::new(
184 io::ErrorKind::InvalidData,
185 format!("Invalid domain name ({:?}): {}", err, domain),
186 ))
187 })?
188 .to_owned(),
189 )
190 .map_err(|err| io::Error::new(io::ErrorKind::ConnectionAborted, err))?;
191 MidHandshakeTlsStream { session, stream }.handshake()
192 }
193}
194
195#[derive(Debug)]
197pub struct MidHandshakeTlsStream<S: Read + Write> {
198 session: ClientConnection,
199 stream: S,
200}
201
202impl<S: Debug + Read + Send + Sync + Write + 'static> MidHandshakeTlsStream<S> {
203 pub fn get_ref(&self) -> &S {
205 &self.stream
206 }
207
208 pub fn get_mut(&mut self) -> &mut S {
210 &mut self.stream
211 }
212
213 pub fn handshake(mut self) -> Result<TlsStream<S>, HandshakeError<S>> {
220 if let Err(e) = self.session.complete_io(&mut self.stream) {
221 if e.kind() == io::ErrorKind::WouldBlock {
222 if self.session.is_handshaking() {
223 return Err(HandshakeError::WouldBlock(Box::new(self)));
224 }
225 } else {
226 return Err(e.into());
227 }
228 }
229 Ok(TlsStream::new(self.session, self.stream))
230 }
231}
232
233impl<S: Read + Write> fmt::Display for MidHandshakeTlsStream<S> {
234 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
235 f.write_str("MidHandshakeTlsStream")
236 }
237}
238
239#[derive(Debug)]
241pub enum HandshakeError<S: Read + Send + Sync + Write + 'static> {
242 WouldBlock(Box<MidHandshakeTlsStream<S>>),
245 Failure(io::Error),
247}
248
249impl<S: Debug + Read + Send + Sync + Write + 'static> fmt::Display for HandshakeError<S> {
250 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
251 match self {
252 HandshakeError::WouldBlock(_) => f.write_str("WouldBlock hit during handshake"),
253 HandshakeError::Failure(err) => f.write_fmt(format_args!("IO error: {}", err)),
254 }
255 }
256}
257
258impl<S: Debug + Read + Send + Sync + Write + 'static> Error for HandshakeError<S> {
259 fn source(&self) -> Option<&(dyn Error + 'static)> {
260 match self {
261 HandshakeError::Failure(err) => Some(err),
262 _ => None,
263 }
264 }
265}
266
267impl<S: Debug + Read + Send + Sync + Write + 'static> From<io::Error> for HandshakeError<S> {
268 fn from(err: io::Error) -> Self {
269 HandshakeError::Failure(err)
270 }
271}