tcp_stream/
rustls_impl.rs

1use crate::{
2    HandshakeError, HandshakeResult, Identity, MidHandshakeTlsStream, TLSConfig, TcpStream,
3};
4
5#[cfg(feature = "rustls-futures")]
6use crate::AsyncTcpStream;
7
8use rustls_connector::rustls_pki_types::{
9    CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer, pem::PemObject,
10};
11use std::io;
12
13/// Reexport rustls-connector's `TlsConnector`
14pub use rustls_connector::{RustlsConnector, RustlsConnectorConfig};
15
16/// A `TcpStream` wrapped by rustls
17pub type RustlsStream = rustls_connector::TlsStream<TcpStream>;
18
19/// A `MidHandshakeTlsStream` from rustls-connector
20pub type RustlsMidHandshakeTlsStream = rustls_connector::MidHandshakeTlsStream<TcpStream>;
21
22/// A `HandshakeError` from rustls-connector
23pub type RustlsHandshakeError = rustls_connector::HandshakeError<TcpStream>;
24
25fn update_rustls_config(
26    c: &mut RustlsConnectorConfig,
27    config: &TLSConfig<'_, '_, '_>,
28) -> io::Result<()> {
29    if let Some(cert_chain) = config.cert_chain {
30        let mut cert_chain = std::io::BufReader::new(cert_chain.as_bytes());
31        let certs = rustls_pemfile::certs(&mut cert_chain)
32            .collect::<Result<Vec<_>, _>>()
33            .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?;
34        c.add_parsable_certificates(certs);
35    }
36    Ok(())
37}
38
39fn rustls_identity(
40    identity: Identity<'_, '_>,
41) -> io::Result<(Vec<CertificateDer<'static>>, PrivateKeyDer<'static>)> {
42    let (certs, key) = match identity {
43        Identity::PKCS12 { der, password } => {
44            let pfx =
45                p12_keystore::KeyStore::from_pkcs12(der, password).map_err(io::Error::other)?;
46            let Some((_, keychain)) = pfx.private_key_chain() else {
47                return Err(io::Error::other("No private key in pkcs12 DER"));
48            };
49            let certs = keychain
50                .chain()
51                .iter()
52                .map(|cert| CertificateDer::from(cert.as_der().to_vec()))
53                .collect();
54            (
55                certs,
56                PrivateKeyDer::from(PrivatePkcs8KeyDer::from(keychain.key().to_vec())),
57            )
58        }
59        Identity::PKCS8 { pem, key } => {
60            let mut cert_reader = std::io::BufReader::new(pem);
61            let certs = rustls_pemfile::certs(&mut cert_reader)
62                .collect::<Result<Vec<_>, _>>()
63                .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?;
64            (
65                certs,
66                PrivateKeyDer::from_pem_slice(key)
67                    .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?,
68            )
69        }
70    };
71    Ok((certs, key))
72}
73
74fn rustls_connector(
75    mut c: RustlsConnectorConfig,
76    config: TLSConfig<'_, '_, '_>,
77) -> io::Result<RustlsConnector> {
78    update_rustls_config(&mut c, &config)?;
79
80    let connector = if let Some(identity) = config.identity {
81        let (certs, key) = rustls_identity(identity)?;
82        c.connector_with_single_cert(certs, key)
83            .map_err(io::Error::other)?
84    } else {
85        c.connector_with_no_client_auth()
86    };
87    Ok(connector)
88}
89
90#[allow(dead_code)]
91pub(crate) fn into_rustls_impl(
92    s: TcpStream,
93    c: RustlsConnectorConfig,
94    domain: &str,
95    config: TLSConfig<'_, '_, '_>,
96) -> HandshakeResult {
97    s.into_rustls(&rustls_connector(c, config)?, domain)
98}
99
/// Async counterpart of the synchronous rustls upgrade. It builds a connector from
/// `c` and `config`, then awaits the TLS upgrade of `s` for the server name `domain`.
///
/// # Errors
///
/// Returns any error from building the connector or from the upgrade itself.
// NOTE(review): `dead_code` is presumably allowed because callers only exist under
// some feature combinations — confirm.
#[cfg(feature = "rustls-futures")]
#[allow(dead_code)]
pub(crate) async fn into_rustls_impl_async(
    s: AsyncTcpStream,
    c: RustlsConnectorConfig,
    domain: &str,
    config: TLSConfig<'_, '_, '_>,
) -> io::Result<AsyncTcpStream> {
    let connector = rustls_connector(c, config)?;
    s.into_rustls(&connector, domain).await
}
110
111impl From<RustlsStream> for TcpStream {
112    fn from(s: RustlsStream) -> Self {
113        TcpStream::Rustls(Box::new(s))
114    }
115}
116
117impl From<RustlsMidHandshakeTlsStream> for MidHandshakeTlsStream {
118    fn from(mid: RustlsMidHandshakeTlsStream) -> Self {
119        MidHandshakeTlsStream::Rustls(mid)
120    }
121}
122
123impl From<RustlsHandshakeError> for HandshakeError {
124    fn from(error: RustlsHandshakeError) -> Self {
125        match error {
126            rustls_connector::HandshakeError::WouldBlock(mid) => {
127                HandshakeError::WouldBlock((*mid).into())
128            }
129            rustls_connector::HandshakeError::Failure(failure) => HandshakeError::Failure(failure),
130        }
131    }
132}