tcp_stream/
rustls_impl.rs

1use crate::{
2    HandshakeError, HandshakeResult, Identity, MidHandshakeTlsStream, StdTcpStream, TLSConfig,
3    TcpStream,
4};
5
6#[cfg(feature = "rustls-futures")]
7use {
8    crate::AsyncTcpStream,
9    futures_io::{AsyncRead, AsyncWrite},
10};
11
12use rustls_connector::rustls_pki_types::{
13    CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer, pem::PemObject,
14};
15use std::io;
16
17/// Reexport rustls-connector's `TlsConnector`
18pub use rustls_connector::{RustlsConnector, RustlsConnectorConfig};
19
20/// A `TcpStream` wrapped by rustls
21pub type RustlsStream = rustls_connector::TlsStream<StdTcpStream>;
22
23/// A `MidHandshakeTlsStream` from rustls-connector
24pub type RustlsMidHandshakeTlsStream = rustls_connector::MidHandshakeTlsStream<StdTcpStream>;
25
26/// A `HandshakeError` from rustls-connector
27pub type RustlsHandshakeError = rustls_connector::HandshakeError<StdTcpStream>;
28
29#[cfg(feature = "rustls-futures")]
30/// An async `TcpStream` wrapped by rustls
31pub type RustlsAsyncStream<S> = rustls_connector::AsyncTlsStream<S>;
32
33fn update_rustls_config(
34    c: &mut RustlsConnectorConfig,
35    config: &TLSConfig<'_, '_, '_>,
36) -> io::Result<()> {
37    if let Some(cert_chain) = config.cert_chain {
38        let mut cert_chain = std::io::BufReader::new(cert_chain.as_bytes());
39        let certs = rustls_pemfile::certs(&mut cert_chain)
40            .collect::<Result<Vec<_>, _>>()
41            .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?;
42        c.add_parsable_certificates(certs);
43    }
44    Ok(())
45}
46
47fn rustls_identity(
48    identity: Identity<'_, '_>,
49) -> io::Result<(Vec<CertificateDer<'static>>, PrivateKeyDer<'static>)> {
50    let (certs, key) = match identity {
51        Identity::PKCS12 { der, password } => {
52            let pfx =
53                p12_keystore::KeyStore::from_pkcs12(der, password).map_err(io::Error::other)?;
54            let Some((_, keychain)) = pfx.private_key_chain() else {
55                return Err(io::Error::other("No private key in pkcs12 DER"));
56            };
57            let certs = keychain
58                .chain()
59                .iter()
60                .map(|cert| CertificateDer::from(cert.as_der().to_vec()))
61                .collect();
62            (
63                certs,
64                PrivateKeyDer::from(PrivatePkcs8KeyDer::from(keychain.key().to_vec())),
65            )
66        }
67        Identity::PKCS8 { pem, key } => {
68            let mut cert_reader = std::io::BufReader::new(pem);
69            let certs = rustls_pemfile::certs(&mut cert_reader)
70                .collect::<Result<Vec<_>, _>>()
71                .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?;
72            (
73                certs,
74                PrivateKeyDer::from_pem_slice(key)
75                    .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?,
76            )
77        }
78    };
79    Ok((certs, key))
80}
81
82fn rustls_connector(
83    mut c: RustlsConnectorConfig,
84    config: TLSConfig<'_, '_, '_>,
85) -> io::Result<RustlsConnector> {
86    update_rustls_config(&mut c, &config)?;
87
88    let connector = if let Some(identity) = config.identity {
89        let (certs, key) = rustls_identity(identity)?;
90        c.connector_with_single_cert(certs, key)
91            .map_err(io::Error::other)?
92    } else {
93        c.connector_with_no_client_auth()
94    };
95    Ok(connector)
96}
97
98#[allow(dead_code)]
99pub(crate) fn into_rustls_impl(
100    s: TcpStream,
101    c: RustlsConnectorConfig,
102    domain: &str,
103    config: TLSConfig<'_, '_, '_>,
104) -> HandshakeResult {
105    s.into_rustls(&rustls_connector(c, config)?, domain)
106}
107
/// Async counterpart of `into_rustls_impl`: upgrades `s` to a rustls TLS stream for
/// `domain`.
///
/// The connector is built from `c` and `config` via `rustls_connector`. An error while
/// building it is returned immediately. Otherwise the future from
/// `AsyncTcpStream::into_rustls` is awaited and its result returned.
// NOTE(review): `dead_code` is allowed here, presumably because callers are
// feature-gated elsewhere — confirm and record the specific reason.
#[cfg(feature = "rustls-futures")]
#[allow(dead_code)]
pub(crate) async fn into_rustls_impl_async<S>(
    s: AsyncTcpStream<S>,
    c: RustlsConnectorConfig,
    domain: &str,
    config: TLSConfig<'_, '_, '_>,
) -> io::Result<AsyncTcpStream<S>>
where
    S: AsyncRead + AsyncWrite + Send + Unpin + 'static,
{
    let connector = rustls_connector(c, config)?;
    s.into_rustls(&connector, domain).await
}
118
119impl From<RustlsStream> for TcpStream {
120    fn from(s: RustlsStream) -> Self {
121        Self::Rustls(s)
122    }
123}
124
125impl From<RustlsMidHandshakeTlsStream> for MidHandshakeTlsStream {
126    fn from(mid: RustlsMidHandshakeTlsStream) -> Self {
127        Self::Rustls(mid)
128    }
129}
130
131impl From<RustlsHandshakeError> for HandshakeError {
132    fn from(error: RustlsHandshakeError) -> Self {
133        match error {
134            rustls_connector::HandshakeError::WouldBlock(mid) => Self::WouldBlock(mid.into()),
135            rustls_connector::HandshakeError::Failure(failure) => Self::Failure(failure),
136        }
137    }
138}