tcp_stream/rustls_impl.rs

1use crate::{
2    HandshakeError, HandshakeResult, Identity, MidHandshakeTlsStream, TLSConfig, TcpStream,
3};
4
5use rustls_connector::rustls_pki_types::{
6    CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer, pem::PemObject,
7};
8use std::io;
9
10/// Reexport rustls-connector's `TlsConnector`
11pub use rustls_connector::{RustlsConnector, RustlsConnectorConfig};
12
13/// A `TcpStream` wrapped by rustls
14pub type RustlsStream = rustls_connector::TlsStream<TcpStream>;
15
16/// A `MidHandshakeTlsStream` from rustls-connector
17pub type RustlsMidHandshakeTlsStream = rustls_connector::MidHandshakeTlsStream<TcpStream>;
18
19/// A `HandshakeError` from rustls-connector
20pub type RustlsHandshakeError = rustls_connector::HandshakeError<TcpStream>;
21
22fn update_rustls_config(
23    c: &mut RustlsConnectorConfig,
24    config: &TLSConfig<'_, '_, '_>,
25) -> io::Result<()> {
26    if let Some(cert_chain) = config.cert_chain {
27        let mut cert_chain = std::io::BufReader::new(cert_chain.as_bytes());
28        let certs = rustls_pemfile::certs(&mut cert_chain)
29            .collect::<Result<Vec<_>, _>>()
30            .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?;
31        c.add_parsable_certificates(certs);
32    }
33    Ok(())
34}
35
36fn rustls_identity(
37    identity: Identity<'_, '_>,
38) -> io::Result<(Vec<CertificateDer<'static>>, PrivateKeyDer<'static>)> {
39    let (certs, key) = match identity {
40        Identity::PKCS12 { der, password } => {
41            let pfx =
42                p12_keystore::KeyStore::from_pkcs12(der, password).map_err(io::Error::other)?;
43            let Some((_, keychain)) = pfx.private_key_chain() else {
44                return Err(io::Error::other("No private key in pkcs12 DER"));
45            };
46            let certs = keychain
47                .chain()
48                .iter()
49                .map(|cert| CertificateDer::from(cert.as_der().to_vec()))
50                .collect();
51            (
52                certs,
53                PrivateKeyDer::from(PrivatePkcs8KeyDer::from(keychain.key().to_vec())),
54            )
55        }
56        Identity::PKCS8 { pem, key } => {
57            let mut cert_reader = std::io::BufReader::new(pem);
58            let certs = rustls_pemfile::certs(&mut cert_reader)
59                .collect::<Result<Vec<_>, _>>()
60                .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?;
61            (
62                certs,
63                PrivateKeyDer::from_pem_slice(key)
64                    .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?,
65            )
66        }
67    };
68    Ok((certs, key))
69}
70
71fn rustls_connector(
72    mut c: RustlsConnectorConfig,
73    config: TLSConfig<'_, '_, '_>,
74) -> io::Result<RustlsConnector> {
75    update_rustls_config(&mut c, &config)?;
76
77    let connector = if let Some(identity) = config.identity {
78        let (certs, key) = rustls_identity(identity)?;
79        c.connector_with_single_cert(certs, key)
80            .map_err(io::Error::other)?
81    } else {
82        c.connector_with_no_client_auth()
83    };
84    Ok(connector)
85}
86
87pub(crate) fn into_rustls_impl(
88    s: TcpStream,
89    c: RustlsConnectorConfig,
90    domain: &str,
91    config: TLSConfig<'_, '_, '_>,
92) -> HandshakeResult {
93    s.into_rustls(&rustls_connector(c, config)?, domain)
94}
95
96impl From<RustlsStream> for TcpStream {
97    fn from(s: RustlsStream) -> Self {
98        TcpStream::Rustls(Box::new(s))
99    }
100}
101
102impl From<RustlsMidHandshakeTlsStream> for MidHandshakeTlsStream {
103    fn from(mid: RustlsMidHandshakeTlsStream) -> Self {
104        MidHandshakeTlsStream::Rustls(mid)
105    }
106}
107
108impl From<RustlsHandshakeError> for HandshakeError {
109    fn from(error: RustlsHandshakeError) -> Self {
110        match error {
111            rustls_connector::HandshakeError::WouldBlock(mid) => {
112                HandshakeError::WouldBlock((*mid).into())
113            }
114            rustls_connector::HandshakeError::Failure(failure) => HandshakeError::Failure(failure),
115        }
116    }
117}