Skip to main content

tcp_stream/
rustls_impl.rs

1use crate::{
2    HandshakeError, HandshakeResult, Identity, MidHandshakeTlsStream, StdTcpStream, TLSConfig,
3    TcpStream,
4};
5
6#[cfg(feature = "rustls-futures")]
7use {
8    crate::AsyncTcpStream,
9    futures_io::{AsyncRead, AsyncWrite},
10};
11
12use rustls_connector::rustls_pki_types::{
13    CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer, pem::PemObject,
14};
15use std::io;
16
/// Reexport rustls-connector's `RustlsConnector` and `RustlsConnectorConfig`
18pub use rustls_connector::{RustlsConnector, RustlsConnectorConfig};
19
/// A `StdTcpStream` wrapped by rustls-connector's `TlsStream`
pub type RustlsStream = rustls_connector::TlsStream<StdTcpStream>;

/// rustls-connector's `MidHandshakeTlsStream`, specialised to `StdTcpStream`
pub type RustlsMidHandshakeTlsStream = rustls_connector::MidHandshakeTlsStream<StdTcpStream>;

/// rustls-connector's `HandshakeError`, specialised to `StdTcpStream`
pub type RustlsHandshakeError = rustls_connector::HandshakeError<StdTcpStream>;

#[cfg(feature = "rustls-futures")]
/// An async stream `S` wrapped by rustls-connector's `AsyncTlsStream`
/// (only available with the `rustls-futures` feature)
pub type RustlsAsyncStream<S> = rustls_connector::AsyncTlsStream<S>;
32
33fn update_rustls_config(
34    c: &mut RustlsConnectorConfig,
35    config: &TLSConfig<'_, '_, '_>,
36) -> io::Result<()> {
37    if let Some(cert_chain) = config.cert_chain {
38        let mut cert_chain = std::io::BufReader::new(cert_chain.as_bytes());
39        let certs = CertificateDer::pem_reader_iter(&mut cert_chain)
40            .collect::<Result<Vec<_>, _>>()
41            .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?;
42        c.add_parsable_certificates(certs);
43    }
44    Ok(())
45}
46
47fn rustls_identity(
48    identity: Identity<'_, '_>,
49) -> io::Result<(Vec<CertificateDer<'static>>, PrivateKeyDer<'static>)> {
50    let (certs, key) = match identity {
51        Identity::PKCS12 { der, password } => {
52            let pfx =
53                p12_keystore::KeyStore::from_pkcs12(der, password).map_err(io::Error::other)?;
54            let Some((_, keychain)) = pfx.private_key_chain() else {
55                return Err(io::Error::other("No private key in pkcs12 DER"));
56            };
57            let certs = keychain
58                .chain()
59                .iter()
60                .map(|cert| CertificateDer::from(cert.as_der().to_vec()))
61                .collect();
62            (
63                certs,
64                PrivateKeyDer::from(PrivatePkcs8KeyDer::from(keychain.key().to_vec())),
65            )
66        }
67        Identity::PKCS8 { pem, key } => {
68            let mut cert_reader = std::io::BufReader::new(pem);
69            let certs = CertificateDer::pem_reader_iter(&mut cert_reader)
70                .collect::<Result<Vec<_>, _>>()
71                .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?;
72            (
73                certs,
74                PrivateKeyDer::from_pem_slice(key)
75                    .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?,
76            )
77        }
78    };
79    Ok((certs, key))
80}
81
82fn rustls_connector(
83    mut c: RustlsConnectorConfig,
84    config: TLSConfig<'_, '_, '_>,
85) -> io::Result<RustlsConnector> {
86    update_rustls_config(&mut c, &config)?;
87
88    let connector = if let Some(identity) = config.identity {
89        let (certs, key) = rustls_identity(identity)?;
90        c.connector_with_single_cert(certs, key)?
91    } else {
92        c.connector_with_no_client_auth()?
93    };
94    Ok(connector)
95}
96
97#[allow(dead_code)]
98pub(crate) fn into_rustls_impl(
99    s: TcpStream,
100    c: RustlsConnectorConfig,
101    domain: &str,
102    config: TLSConfig<'_, '_, '_>,
103) -> HandshakeResult {
104    s.into_rustls(&rustls_connector(c, config)?, domain)
105}
106
/// Async counterpart of `into_rustls_impl`: performs the rustls handshake on `s` for `domain`.
///
/// A connector is built from `c` and `config`, then handed to `AsyncTcpStream::into_rustls`,
/// whose future is awaited before returning.
///
/// # Errors
///
/// Returns the error from building the connector or from the handshake itself.
// NOTE(review): `dead_code` is presumably allowed because some feature combinations never
// call this function; confirm.
#[cfg(feature = "rustls-futures")]
#[allow(dead_code)]
pub(crate) async fn into_rustls_impl_async<S: AsyncRead + AsyncWrite + Send + Unpin + 'static>(
    s: AsyncTcpStream<S>,
    c: RustlsConnectorConfig,
    domain: &str,
    config: TLSConfig<'_, '_, '_>,
) -> io::Result<AsyncTcpStream<S>> {
    let connector = rustls_connector(c, config)?;
    s.into_rustls(&connector, domain).await
}
117
118impl From<RustlsStream> for TcpStream {
119    fn from(s: RustlsStream) -> Self {
120        Self::Rustls(s)
121    }
122}
123
124impl From<RustlsMidHandshakeTlsStream> for MidHandshakeTlsStream {
125    fn from(mid: RustlsMidHandshakeTlsStream) -> Self {
126        Self::Rustls(mid)
127    }
128}
129
130impl From<RustlsHandshakeError> for HandshakeError {
131    fn from(error: RustlsHandshakeError) -> Self {
132        match error {
133            rustls_connector::HandshakeError::WouldBlock(mid) => Self::WouldBlock(mid.into()),
134            rustls_connector::HandshakeError::Failure(failure) => Self::Failure(failure),
135        }
136    }
137}