tcp_stream/
openssl_impl.rs

1use crate::{
2    HandshakeError, HandshakeResult, Identity, MidHandshakeTlsStream, TLSConfig, TcpStream,
3};
4
5use openssl::x509::X509;
6use std::io;
7
8/// Reexport openssl's `TlsConnector`
9pub use openssl::ssl::{SslConnector as OpenSslConnector, SslMethod as OpenSslMethod};
10
11/// A `TcpStream` wrapped by openssl
12pub type OpenSslStream = openssl::ssl::SslStream<TcpStream>;
13
14/// A `MidHandshakeTlsStream` from openssl
15pub type OpenSslMidHandshakeTlsStream = openssl::ssl::MidHandshakeSslStream<TcpStream>;
16
17/// A `HandshakeError` from openssl
18pub type OpenSslHandshakeError = openssl::ssl::HandshakeError<TcpStream>;
19
20/// An `ErrorStack` from openssl
21pub type OpenSslErrorStack = openssl::error::ErrorStack;
22
23fn openssl_connector(config: TLSConfig<'_, '_, '_>) -> io::Result<OpenSslConnector> {
24    let mut builder = OpenSslConnector::builder(OpenSslMethod::tls())?;
25    if let Some(identity) = config.identity {
26        let (cert, pkey, chain) = match identity {
27            Identity::PKCS8 { pem, key } => {
28                let pkey = openssl::pkey::PKey::private_key_from_pem(key)?;
29                let mut chain = openssl::x509::X509::stack_from_pem(pem)?.into_iter();
30                let cert = chain.next();
31                (cert, Some(pkey), Some(chain.collect()))
32            }
33            Identity::PKCS12 { der, password } => {
34                let mut openssl_identity =
35                    openssl::pkcs12::Pkcs12::from_der(der)?.parse2(password)?;
36                (
37                    openssl_identity.cert,
38                    openssl_identity.pkey,
39                    openssl_identity
40                        .ca
41                        .take()
42                        .map(|stack| stack.into_iter().collect::<Vec<_>>()),
43                )
44            }
45        };
46        if let Some(cert) = cert.as_ref() {
47            builder.set_certificate(cert)?;
48        }
49        if let Some(pkey) = pkey.as_ref() {
50            builder.set_private_key(pkey)?;
51        }
52        if let Some(chain) = chain.as_ref() {
53            for cert in chain.iter().rev() {
54                builder.add_extra_chain_cert(cert.to_owned())?;
55            }
56        }
57    }
58    if let Some(cert_chain) = config.cert_chain.as_ref() {
59        for cert in X509::stack_from_pem(cert_chain.as_bytes())?.drain(..).rev() {
60            builder.cert_store_mut().add_cert(cert)?;
61        }
62    }
63    Ok(builder.build())
64}
65
66#[allow(dead_code)]
67pub(crate) fn into_openssl_impl(
68    s: TcpStream,
69    domain: &str,
70    config: TLSConfig<'_, '_, '_>,
71) -> HandshakeResult {
72    s.into_openssl(&openssl_connector(config)?, domain)
73}
74
75impl From<OpenSslStream> for TcpStream {
76    fn from(s: OpenSslStream) -> Self {
77        TcpStream::OpenSsl(Box::new(s))
78    }
79}
80
81impl From<OpenSslMidHandshakeTlsStream> for MidHandshakeTlsStream {
82    fn from(mid: OpenSslMidHandshakeTlsStream) -> Self {
83        MidHandshakeTlsStream::Openssl(mid)
84    }
85}
86
87impl From<OpenSslHandshakeError> for HandshakeError {
88    fn from(error: OpenSslHandshakeError) -> Self {
89        match error {
90            openssl::ssl::HandshakeError::WouldBlock(mid) => HandshakeError::WouldBlock(mid.into()),
91            openssl::ssl::HandshakeError::Failure(failure) => {
92                HandshakeError::Failure(io::Error::other(failure.into_error()))
93            }
94            openssl::ssl::HandshakeError::SetupFailure(failure) => failure.into(),
95        }
96    }
97}
98
99impl From<OpenSslErrorStack> for HandshakeError {
100    fn from(error: OpenSslErrorStack) -> Self {
101        Self::Failure(error.into())
102    }
103}