Skip to main content

tcp_stream/
native_tls_impl.rs

1use crate::{
2    HandshakeError, HandshakeResult, Identity, MidHandshakeTlsStream, StdTcpStream, TLSConfig,
3    TcpStream,
4};
5
6use rustls_pki_types::{CertificateDer, pem::PemObject};
7
8#[cfg(feature = "native-tls-futures")]
9use {
10    crate::AsyncTcpStream,
11    futures_io::{AsyncRead, AsyncWrite},
12};
13
14use native_tls::Certificate;
15use std::io;
16
17/// Reexport native-tls's `TlsConnector`
18pub use native_tls::TlsConnector as NativeTlsConnector;
19
20/// Reexport native-tls's `TlsConnectorBuilder`
21pub use native_tls::TlsConnectorBuilder as NativeTlsConnectorBuilder;
22
23/// A `TcpStream` wrapped by native-tls
24pub type NativeTlsStream = native_tls::TlsStream<StdTcpStream>;
25
26/// A `MidHandshakeTlsStream` from native-tls
27pub type NativeTlsMidHandshakeTlsStream = native_tls::MidHandshakeTlsStream<StdTcpStream>;
28
29/// A `HandshakeError` from native-tls
30pub type NativeTlsHandshakeError = native_tls::HandshakeError<StdTcpStream>;
31
32#[cfg(feature = "native-tls-futures")]
33/// An async `TcpStream` wrapped by native-tls
34pub type NativeTlsAsyncStream<S> = async_native_tls::TlsStream<S>;
35
36fn native_tls_connector_builder(
37    config: TLSConfig<'_, '_, '_>,
38) -> io::Result<NativeTlsConnectorBuilder> {
39    let mut builder = NativeTlsConnector::builder();
40    if let Some(identity) = config.identity {
41        let native_identity = match identity {
42            Identity::PKCS8 { pem, key } => native_tls::Identity::from_pkcs8(pem, key),
43            Identity::PKCS12 { der, password } => native_tls::Identity::from_pkcs12(der, password),
44        };
45        builder.identity(native_identity.map_err(io::Error::other)?);
46    }
47    if let Some(cert_chain) = config.cert_chain {
48        let mut cert_chain = std::io::BufReader::new(cert_chain.as_bytes());
49        for cert in CertificateDer::pem_reader_iter(&mut cert_chain)
50            .collect::<Result<Vec<_>, _>>()
51            .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?
52        {
53            builder
54                .add_root_certificate(Certificate::from_der(&cert[..]).map_err(io::Error::other)?);
55        }
56    }
57    Ok(builder)
58}
59
60fn native_tls_connector(config: TLSConfig<'_, '_, '_>) -> io::Result<NativeTlsConnector> {
61    native_tls_connector_builder(config)?
62        .build()
63        .map_err(io::Error::other)
64}
65
66#[allow(dead_code)]
67pub(crate) fn into_native_tls_impl(
68    s: TcpStream,
69    domain: &str,
70    config: TLSConfig<'_, '_, '_>,
71) -> HandshakeResult {
72    s.into_native_tls(&native_tls_connector(config)?, domain)
73}
74
#[cfg(feature = "native-tls-futures")]
// NOTE(review): presumably only called under some feature combinations, hence the
// `dead_code` allowance — confirm against the callers.
#[allow(dead_code)]
/// Asynchronously upgrades `s` to TLS for `domain`.
///
/// Builds a connector builder from `config` with `native_tls_connector_builder`,
/// then awaits `AsyncTcpStream::into_native_tls` with it.
///
/// # Errors
///
/// Returns builder-construction errors, or any error produced by the async upgrade.
pub(crate) async fn into_native_tls_impl_async<
    S: AsyncRead + AsyncWrite + Send + Unpin + 'static,
>(
    s: AsyncTcpStream<S>,
    domain: &str,
    config: TLSConfig<'_, '_, '_>,
) -> io::Result<AsyncTcpStream<S>> {
    let builder = native_tls_connector_builder(config)?;
    s.into_native_tls(builder, domain).await
}
87
88impl From<NativeTlsStream> for TcpStream {
89    fn from(s: NativeTlsStream) -> Self {
90        Self::NativeTls(s)
91    }
92}
93
94impl From<NativeTlsMidHandshakeTlsStream> for MidHandshakeTlsStream {
95    fn from(mid: NativeTlsMidHandshakeTlsStream) -> Self {
96        Self::NativeTls(mid)
97    }
98}
99
100impl From<NativeTlsHandshakeError> for HandshakeError {
101    fn from(error: NativeTlsHandshakeError) -> Self {
102        match error {
103            native_tls::HandshakeError::WouldBlock(mid) => Self::WouldBlock(mid.into()),
104            native_tls::HandshakeError::Failure(failure) => {
105                Self::Failure(io::Error::other(failure))
106            }
107        }
108    }
109}