tcp_stream/
native_tls_impl.rs

1use crate::{
2    HandshakeError, HandshakeResult, Identity, MidHandshakeTlsStream, TLSConfig, TcpStream,
3};
4
5use native_tls::Certificate;
6use std::io;
7
/// Re-export of native-tls's `TlsConnector` under a backend-specific name.
pub use native_tls::TlsConnector as NativeTlsConnector;

/// A `TcpStream` wrapped by native-tls.
///
/// Alias for `native_tls::TlsStream` specialised to this crate's `TcpStream`.
pub type NativeTlsStream = native_tls::TlsStream<TcpStream>;

/// A `MidHandshakeTlsStream` from native-tls.
///
/// Alias for `native_tls::MidHandshakeTlsStream` specialised to this crate's `TcpStream`.
pub type NativeTlsMidHandshakeTlsStream = native_tls::MidHandshakeTlsStream<TcpStream>;

/// A `HandshakeError` from native-tls.
///
/// Alias for `native_tls::HandshakeError` specialised to this crate's `TcpStream`.
pub type NativeTlsHandshakeError = native_tls::HandshakeError<TcpStream>;
19
20fn native_tls_connector(config: TLSConfig<'_, '_, '_>) -> io::Result<NativeTlsConnector> {
21    let mut builder = NativeTlsConnector::builder();
22    if let Some(identity) = config.identity {
23        let native_identity = match identity {
24            Identity::PKCS8 { pem, key } => native_tls::Identity::from_pkcs8(pem, key),
25            Identity::PKCS12 { der, password } => native_tls::Identity::from_pkcs12(der, password),
26        };
27        builder.identity(native_identity.map_err(io::Error::other)?);
28    }
29    if let Some(cert_chain) = config.cert_chain {
30        let mut cert_chain = std::io::BufReader::new(cert_chain.as_bytes());
31        for cert in rustls_pemfile::certs(&mut cert_chain).collect::<Result<Vec<_>, _>>()? {
32            builder
33                .add_root_certificate(Certificate::from_der(&cert[..]).map_err(io::Error::other)?);
34        }
35    }
36    builder.build().map_err(io::Error::other)
37}
38
39pub(crate) fn into_native_tls_impl(
40    s: TcpStream,
41    domain: &str,
42    config: TLSConfig<'_, '_, '_>,
43) -> HandshakeResult {
44    s.into_native_tls(&native_tls_connector(config)?, domain)
45}
46
47impl From<NativeTlsStream> for TcpStream {
48    fn from(s: NativeTlsStream) -> Self {
49        TcpStream::NativeTls(Box::new(s))
50    }
51}
52
53impl From<NativeTlsMidHandshakeTlsStream> for MidHandshakeTlsStream {
54    fn from(mid: NativeTlsMidHandshakeTlsStream) -> Self {
55        MidHandshakeTlsStream::NativeTls(mid)
56    }
57}
58
59impl From<NativeTlsHandshakeError> for HandshakeError {
60    fn from(error: NativeTlsHandshakeError) -> Self {
61        match error {
62            native_tls::HandshakeError::WouldBlock(mid) => HandshakeError::WouldBlock(mid.into()),
63            native_tls::HandshakeError::Failure(failure) => {
64                HandshakeError::Failure(io::Error::other(failure))
65            }
66        }
67    }
68}