tcp_stream/
native_tls_impl.rs

1use crate::{
2    HandshakeError, HandshakeResult, Identity, MidHandshakeTlsStream, TLSConfig, TcpStream,
3};
4
5#[cfg(feature = "native-tls-futures")]
6use crate::AsyncTcpStream;
7
8use native_tls::Certificate;
9use std::io;
10
11/// Reexport native-tls's `TlsConnector`
12pub use native_tls::TlsConnector as NativeTlsConnector;
13
14/// Reexport native-tls's `TlsConnectorBuilder`
15pub use native_tls::TlsConnectorBuilder as NativeTlsConnectorBuilder;
16
17/// A `TcpStream` wrapped by native-tls
18pub type NativeTlsStream = native_tls::TlsStream<TcpStream>;
19
20/// A `MidHandshakeTlsStream` from native-tls
21pub type NativeTlsMidHandshakeTlsStream = native_tls::MidHandshakeTlsStream<TcpStream>;
22
23/// A `HandshakeError` from native-tls
24pub type NativeTlsHandshakeError = native_tls::HandshakeError<TcpStream>;
25
26fn native_tls_connector_builder(
27    config: TLSConfig<'_, '_, '_>,
28) -> io::Result<NativeTlsConnectorBuilder> {
29    let mut builder = NativeTlsConnector::builder();
30    if let Some(identity) = config.identity {
31        let native_identity = match identity {
32            Identity::PKCS8 { pem, key } => native_tls::Identity::from_pkcs8(pem, key),
33            Identity::PKCS12 { der, password } => native_tls::Identity::from_pkcs12(der, password),
34        };
35        builder.identity(native_identity.map_err(io::Error::other)?);
36    }
37    if let Some(cert_chain) = config.cert_chain {
38        let mut cert_chain = std::io::BufReader::new(cert_chain.as_bytes());
39        for cert in rustls_pemfile::certs(&mut cert_chain).collect::<Result<Vec<_>, _>>()? {
40            builder
41                .add_root_certificate(Certificate::from_der(&cert[..]).map_err(io::Error::other)?);
42        }
43    }
44    Ok(builder)
45}
46
47fn native_tls_connector(config: TLSConfig<'_, '_, '_>) -> io::Result<NativeTlsConnector> {
48    native_tls_connector_builder(config)?
49        .build()
50        .map_err(io::Error::other)
51}
52
53#[allow(dead_code)]
54pub(crate) fn into_native_tls_impl(
55    s: TcpStream,
56    domain: &str,
57    config: TLSConfig<'_, '_, '_>,
58) -> HandshakeResult {
59    s.into_native_tls(&native_tls_connector(config)?, domain)
60}
61
/// Async counterpart of `into_native_tls_impl`.
///
/// Passes a connector *builder* (not a built connector) derived from `config` to
/// `AsyncTcpStream::into_native_tls`, then awaits the result.
///
/// # Errors
///
/// Fails if the builder cannot be created from `config`, or if the async upgrade fails.
#[cfg(feature = "native-tls-futures")]
// NOTE(review): the `dead_code` allowance suggests this is only called under some feature
// combinations — confirm against the crate's cfg wiring and document it there.
#[allow(dead_code)]
pub(crate) async fn into_native_tls_impl_async(
    s: AsyncTcpStream,
    domain: &str,
    config: TLSConfig<'_, '_, '_>,
) -> io::Result<AsyncTcpStream> {
    let builder = native_tls_connector_builder(config)?;
    s.into_native_tls(builder, domain).await
}
72
73impl From<NativeTlsStream> for TcpStream {
74    fn from(s: NativeTlsStream) -> Self {
75        TcpStream::NativeTls(Box::new(s))
76    }
77}
78
79impl From<NativeTlsMidHandshakeTlsStream> for MidHandshakeTlsStream {
80    fn from(mid: NativeTlsMidHandshakeTlsStream) -> Self {
81        MidHandshakeTlsStream::NativeTls(mid)
82    }
83}
84
85impl From<NativeTlsHandshakeError> for HandshakeError {
86    fn from(error: NativeTlsHandshakeError) -> Self {
87        match error {
88            native_tls::HandshakeError::WouldBlock(mid) => HandshakeError::WouldBlock(mid.into()),
89            native_tls::HandshakeError::Failure(failure) => {
90                HandshakeError::Failure(io::Error::other(failure))
91            }
92        }
93    }
94}