tcp_stream/
native_tls_impl.rs1use crate::{
2 HandshakeError, HandshakeResult, Identity, MidHandshakeTlsStream, TLSConfig, TcpStream,
3};
4
5#[cfg(feature = "native-tls-futures")]
6use crate::AsyncTcpStream;
7
8use native_tls::Certificate;
9use std::io;
10
11pub use native_tls::TlsConnector as NativeTlsConnector;
13
14pub use native_tls::TlsConnectorBuilder as NativeTlsConnectorBuilder;
16
17pub type NativeTlsStream = native_tls::TlsStream<TcpStream>;
19
20pub type NativeTlsMidHandshakeTlsStream = native_tls::MidHandshakeTlsStream<TcpStream>;
22
23pub type NativeTlsHandshakeError = native_tls::HandshakeError<TcpStream>;
25
26fn native_tls_connector_builder(
27 config: TLSConfig<'_, '_, '_>,
28) -> io::Result<NativeTlsConnectorBuilder> {
29 let mut builder = NativeTlsConnector::builder();
30 if let Some(identity) = config.identity {
31 let native_identity = match identity {
32 Identity::PKCS8 { pem, key } => native_tls::Identity::from_pkcs8(pem, key),
33 Identity::PKCS12 { der, password } => native_tls::Identity::from_pkcs12(der, password),
34 };
35 builder.identity(native_identity.map_err(io::Error::other)?);
36 }
37 if let Some(cert_chain) = config.cert_chain {
38 let mut cert_chain = std::io::BufReader::new(cert_chain.as_bytes());
39 for cert in rustls_pemfile::certs(&mut cert_chain).collect::<Result<Vec<_>, _>>()? {
40 builder
41 .add_root_certificate(Certificate::from_der(&cert[..]).map_err(io::Error::other)?);
42 }
43 }
44 Ok(builder)
45}
46
47fn native_tls_connector(config: TLSConfig<'_, '_, '_>) -> io::Result<NativeTlsConnector> {
48 native_tls_connector_builder(config)?
49 .build()
50 .map_err(io::Error::other)
51}
52
53#[allow(dead_code)]
54pub(crate) fn into_native_tls_impl(
55 s: TcpStream,
56 domain: &str,
57 config: TLSConfig<'_, '_, '_>,
58) -> HandshakeResult {
59 s.into_native_tls(&native_tls_connector(config)?, domain)
60}
61
/// Async variant: builds a native-tls connector *builder* from `config` and passes it,
/// together with `domain`, to `AsyncTcpStream::into_native_tls` on `s`, awaiting the result.
///
/// # Errors
///
/// Returns the builder-configuration error, or whatever the awaited conversion yields.
#[cfg(feature = "native-tls-futures")]
// NOTE(review): dead_code is presumably allowed because callers exist only under some
// feature combinations — confirm.
#[allow(dead_code)]
pub(crate) async fn into_native_tls_impl_async(
    s: AsyncTcpStream,
    domain: &str,
    config: TLSConfig<'_, '_, '_>,
) -> io::Result<AsyncTcpStream> {
    let connector_builder = native_tls_connector_builder(config)?;
    s.into_native_tls(connector_builder, domain).await
}

73impl From<NativeTlsStream> for TcpStream {
74 fn from(s: NativeTlsStream) -> Self {
75 TcpStream::NativeTls(Box::new(s))
76 }
77}
78
79impl From<NativeTlsMidHandshakeTlsStream> for MidHandshakeTlsStream {
80 fn from(mid: NativeTlsMidHandshakeTlsStream) -> Self {
81 MidHandshakeTlsStream::NativeTls(mid)
82 }
83}
84
85impl From<NativeTlsHandshakeError> for HandshakeError {
86 fn from(error: NativeTlsHandshakeError) -> Self {
87 match error {
88 native_tls::HandshakeError::WouldBlock(mid) => HandshakeError::WouldBlock(mid.into()),
89 native_tls::HandshakeError::Failure(failure) => {
90 HandshakeError::Failure(io::Error::other(failure))
91 }
92 }
93 }
94}