tcp_stream/
rustls_impl.rs1use crate::{
2 HandshakeError, HandshakeResult, Identity, MidHandshakeTlsStream, TLSConfig, TcpStream,
3};
4
5use rustls_connector::rustls_pki_types::{
6 CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer, pem::PemObject,
7};
8use std::io;
9
10pub use rustls_connector::{RustlsConnector, RustlsConnectorConfig};
12
13pub type RustlsStream = rustls_connector::TlsStream<TcpStream>;
15
16pub type RustlsMidHandshakeTlsStream = rustls_connector::MidHandshakeTlsStream<TcpStream>;
18
19pub type RustlsHandshakeError = rustls_connector::HandshakeError<TcpStream>;
21
22fn update_rustls_config(
23 c: &mut RustlsConnectorConfig,
24 config: &TLSConfig<'_, '_, '_>,
25) -> io::Result<()> {
26 if let Some(cert_chain) = config.cert_chain {
27 let mut cert_chain = std::io::BufReader::new(cert_chain.as_bytes());
28 let certs = rustls_pemfile::certs(&mut cert_chain)
29 .collect::<Result<Vec<_>, _>>()
30 .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?;
31 c.add_parsable_certificates(certs);
32 }
33 Ok(())
34}
35
36fn rustls_identity(
37 identity: Identity<'_, '_>,
38) -> io::Result<(Vec<CertificateDer<'static>>, PrivateKeyDer<'static>)> {
39 let (certs, key) = match identity {
40 Identity::PKCS12 { der, password } => {
41 let pfx =
42 p12_keystore::KeyStore::from_pkcs12(der, password).map_err(io::Error::other)?;
43 let Some((_, keychain)) = pfx.private_key_chain() else {
44 return Err(io::Error::other("No private key in pkcs12 DER"));
45 };
46 let certs = keychain
47 .chain()
48 .iter()
49 .map(|cert| CertificateDer::from(cert.as_der().to_vec()))
50 .collect();
51 (
52 certs,
53 PrivateKeyDer::from(PrivatePkcs8KeyDer::from(keychain.key().to_vec())),
54 )
55 }
56 Identity::PKCS8 { pem, key } => {
57 let mut cert_reader = std::io::BufReader::new(pem);
58 let certs = rustls_pemfile::certs(&mut cert_reader)
59 .collect::<Result<Vec<_>, _>>()
60 .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?;
61 (
62 certs,
63 PrivateKeyDer::from_pem_slice(key)
64 .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?,
65 )
66 }
67 };
68 Ok((certs, key))
69}
70
71fn rustls_connector(
72 mut c: RustlsConnectorConfig,
73 config: TLSConfig<'_, '_, '_>,
74) -> io::Result<RustlsConnector> {
75 update_rustls_config(&mut c, &config)?;
76
77 let connector = if let Some(identity) = config.identity {
78 let (certs, key) = rustls_identity(identity)?;
79 c.connector_with_single_cert(certs, key)
80 .map_err(io::Error::other)?
81 } else {
82 c.connector_with_no_client_auth()
83 };
84 Ok(connector)
85}
86
87pub(crate) fn into_rustls_impl(
88 s: TcpStream,
89 c: RustlsConnectorConfig,
90 domain: &str,
91 config: TLSConfig<'_, '_, '_>,
92) -> HandshakeResult {
93 s.into_rustls(&rustls_connector(c, config)?, domain)
94}
95
96impl From<RustlsStream> for TcpStream {
97 fn from(s: RustlsStream) -> Self {
98 TcpStream::Rustls(Box::new(s))
99 }
100}
101
102impl From<RustlsMidHandshakeTlsStream> for MidHandshakeTlsStream {
103 fn from(mid: RustlsMidHandshakeTlsStream) -> Self {
104 MidHandshakeTlsStream::Rustls(mid)
105 }
106}
107
108impl From<RustlsHandshakeError> for HandshakeError {
109 fn from(error: RustlsHandshakeError) -> Self {
110 match error {
111 rustls_connector::HandshakeError::WouldBlock(mid) => {
112 HandshakeError::WouldBlock((*mid).into())
113 }
114 rustls_connector::HandshakeError::Failure(failure) => HandshakeError::Failure(failure),
115 }
116 }
117}