// tcp_stream/native_tls_impl.rs

1use crate::{
2    HandshakeError, HandshakeResult, Identity, MidHandshakeTlsStream, TLSConfig, TcpStream,
3};
4
5use native_tls::Certificate;
6use std::io;
7
8/// Reexport native-tls's `TlsConnector`
9pub use native_tls::TlsConnector as NativeTlsConnector;
10
11/// A `TcpStream` wrapped by native-tls
12pub type NativeTlsStream = native_tls::TlsStream<TcpStream>;
13
14/// A `MidHandshakeTlsStream` from native-tls
15pub type NativeTlsMidHandshakeTlsStream = native_tls::MidHandshakeTlsStream<TcpStream>;
16
17/// A `HandshakeError` from native-tls
18pub type NativeTlsHandshakeError = native_tls::HandshakeError<TcpStream>;
19
20pub(crate) fn into_native_tls_impl(
21    s: TcpStream,
22    domain: &str,
23    config: TLSConfig<'_, '_, '_>,
24) -> HandshakeResult {
25    let mut builder = NativeTlsConnector::builder();
26    if let Some(identity) = config.identity {
27        let native_identity = match identity {
28            Identity::PKCS8 { pem, key } => native_tls::Identity::from_pkcs8(pem, key),
29            Identity::PKCS12 { der, password } => native_tls::Identity::from_pkcs12(der, password),
30        };
31        builder.identity(native_identity.map_err(io::Error::other)?);
32    }
33    if let Some(cert_chain) = config.cert_chain {
34        let mut cert_chain = std::io::BufReader::new(cert_chain.as_bytes());
35        for cert in rustls_pemfile::certs(&mut cert_chain).collect::<Result<Vec<_>, _>>()? {
36            builder
37                .add_root_certificate(Certificate::from_der(&cert[..]).map_err(io::Error::other)?);
38        }
39    }
40    s.into_native_tls(&builder.build().map_err(io::Error::other)?, domain)
41}
42
43impl From<NativeTlsStream> for TcpStream {
44    fn from(s: NativeTlsStream) -> Self {
45        TcpStream::NativeTls(Box::new(s))
46    }
47}
48
49impl From<NativeTlsMidHandshakeTlsStream> for MidHandshakeTlsStream {
50    fn from(mid: NativeTlsMidHandshakeTlsStream) -> Self {
51        MidHandshakeTlsStream::NativeTls(mid)
52    }
53}
54
55impl From<NativeTlsHandshakeError> for HandshakeError {
56    fn from(error: NativeTlsHandshakeError) -> Self {
57        match error {
58            native_tls::HandshakeError::WouldBlock(mid) => HandshakeError::WouldBlock(mid.into()),
59            native_tls::HandshakeError::Failure(failure) => {
60                HandshakeError::Failure(io::Error::other(failure))
61            }
62        }
63    }
64}