tiberius_rustls/client/tls_stream/rustls_tls_stream.rs

1use crate::{
2    client::{config::Config, TrustConfig},
3    error::IoErrorKind,
4    Error,
5};
6use futures_util::io::{AsyncRead, AsyncWrite};
7use std::{
8    fs, io,
9    pin::Pin,
10    sync::Arc,
11    task::{Context, Poll},
12    time::SystemTime,
13};
14use tokio_rustls::{
15    rustls::{
16        client::{
17            HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier,
18            WantsTransparencyPolicyOrClientCert,
19        },
20        Certificate, ClientConfig, ConfigBuilder, DigitallySignedStruct, Error as RustlsError,
21        RootCertStore, ServerName, WantsVerifier,
22    },
23    TlsConnector,
24};
25use tokio_util::compat::{Compat, FuturesAsyncReadCompatExt, TokioAsyncReadCompatExt};
26use tracing::{event, Level};
27
28impl From<tokio_rustls::rustls::Error> for Error {
29    fn from(e: tokio_rustls::rustls::Error) -> Self {
30        crate::Error::Tls(e.to_string())
31    }
32}
33
34pub(crate) struct TlsStream<S: AsyncRead + AsyncWrite + Unpin + Send>(
35    Compat<tokio_rustls::client::TlsStream<Compat<S>>>,
36);
37
38struct NoCertVerifier;
39
40impl ServerCertVerifier for NoCertVerifier {
41    fn verify_server_cert(
42        &self,
43        _end_entity: &Certificate,
44        _intermediates: &[Certificate],
45        _server_name: &ServerName,
46        _scts: &mut dyn Iterator<Item = &[u8]>,
47        _ocsp_response: &[u8],
48        _now: SystemTime,
49    ) -> Result<ServerCertVerified, RustlsError> {
50        Ok(ServerCertVerified::assertion())
51    }
52
53    fn verify_tls12_signature(
54        &self,
55        _message: &[u8],
56        _cert: &Certificate,
57        _dss: &DigitallySignedStruct,
58    ) -> Result<HandshakeSignatureValid, RustlsError> {
59        Ok(HandshakeSignatureValid::assertion())
60    }
61}
62
63fn get_server_name(config: &Config) -> crate::Result<ServerName> {
64    match (ServerName::try_from(config.get_host()), &config.trust) {
65        (Ok(sn), _) => Ok(sn),
66        (Err(_), TrustConfig::TrustAll) => {
67            Ok(ServerName::try_from("placeholder.domain.com").unwrap())
68        }
69        (Err(e), _) => Err(crate::Error::Tls(e.to_string())),
70    }
71}
72
73impl<S: AsyncRead + AsyncWrite + Unpin + Send> TlsStream<S> {
74    pub(super) async fn new(config: &Config, stream: S) -> crate::Result<Self> {
75        event!(Level::INFO, "Performing a TLS handshake");
76
77        let builder = ClientConfig::builder().with_safe_defaults();
78
79        let client_config = match &config.trust {
80            TrustConfig::CaCertificateLocation(path) => {
81                if let Ok(buf) = fs::read(path) {
82                    let cert = match path.extension() {
83                            Some(ext)
84                            if ext.to_ascii_lowercase() == "pem"
85                                || ext.to_ascii_lowercase() == "crt" =>
86                                {
87                                    let pem_cert = rustls_pemfile::certs(&mut buf.as_slice())?;
88                                    if pem_cert.len() != 1 {
89                                        return Err(crate::Error::Io {
90                                            kind: IoErrorKind::InvalidInput,
91                                            message: format!("Certificate file {} contain 0 or more than 1 certs", path.to_string_lossy()),
92                                        });
93                                    }
94
95                                    Certificate(pem_cert.into_iter().next().unwrap())
96                                }
97                            Some(ext) if ext.to_ascii_lowercase() == "der" => {
98                                Certificate(buf)
99                            }
100                            Some(_) | None => return Err(crate::Error::Io {
101                                kind: IoErrorKind::InvalidInput,
102                                message: "Provided CA certificate with unsupported file-extension! Supported types are pem, crt and der.".to_string(),
103                            }),
104                        };
105                    let mut cert_store = RootCertStore::empty();
106                    cert_store.add(&cert)?;
107                    builder
108                        .with_root_certificates(cert_store)
109                        .with_no_client_auth()
110                } else {
111                    return Err(Error::Io {
112                        kind: IoErrorKind::InvalidData,
113                        message: "Could not read provided CA certificate!".to_string(),
114                    });
115                }
116            }
117            TrustConfig::TrustAll => {
118                event!(
119                    Level::WARN,
120                    "Trusting the server certificate without validation."
121                );
122                let mut config = builder
123                    .with_root_certificates(RootCertStore::empty())
124                    .with_no_client_auth();
125                config
126                    .dangerous()
127                    .set_certificate_verifier(Arc::new(NoCertVerifier {}));
128                // config.enable_sni = false;
129                config
130            }
131            TrustConfig::Default => {
132                event!(Level::INFO, "Using default trust configuration.");
133                builder.with_native_roots().with_no_client_auth()
134            }
135        };
136
137        let connector = TlsConnector::from(Arc::new(client_config));
138
139        let tls_stream = connector
140            .connect(get_server_name(config)?, stream.compat())
141            .await?;
142
143        Ok(TlsStream(tls_stream.compat()))
144    }
145
146    pub(crate) fn get_mut(&mut self) -> &mut S {
147        self.0.get_mut().get_mut().0.get_mut()
148    }
149}
150
151impl<S: AsyncRead + AsyncWrite + Unpin + Send> AsyncRead for TlsStream<S> {
152    fn poll_read(
153        self: Pin<&mut Self>,
154        cx: &mut Context<'_>,
155        buf: &mut [u8],
156    ) -> Poll<io::Result<usize>> {
157        let inner = Pin::get_mut(self);
158        Pin::new(&mut inner.0).poll_read(cx, buf)
159    }
160}
161
162impl<S: AsyncRead + AsyncWrite + Unpin + Send> AsyncWrite for TlsStream<S> {
163    fn poll_write(
164        self: Pin<&mut Self>,
165        cx: &mut Context<'_>,
166        buf: &[u8],
167    ) -> Poll<io::Result<usize>> {
168        let inner = Pin::get_mut(self);
169        Pin::new(&mut inner.0).poll_write(cx, buf)
170    }
171
172    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
173        let inner = Pin::get_mut(self);
174        Pin::new(&mut inner.0).poll_flush(cx)
175    }
176
177    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
178        let inner = Pin::get_mut(self);
179        Pin::new(&mut inner.0).poll_close(cx)
180    }
181}
182
183trait ConfigBuilderExt {
184    fn with_native_roots(self) -> ConfigBuilder<ClientConfig, WantsTransparencyPolicyOrClientCert>;
185}
186
187impl ConfigBuilderExt for ConfigBuilder<ClientConfig, WantsVerifier> {
188    fn with_native_roots(self) -> ConfigBuilder<ClientConfig, WantsTransparencyPolicyOrClientCert> {
189        let mut roots = RootCertStore::empty();
190        let mut valid_count = 0;
191        let mut invalid_count = 0;
192
193        for cert in rustls_native_certs::load_native_certs().expect("could not load platform certs")
194        {
195            let cert = Certificate(cert.0);
196            match roots.add(&cert) {
197                Ok(_) => valid_count += 1,
198                Err(err) => {
199                    tracing::event!(Level::TRACE, "invalid cert der {:?}", cert.0);
200                    tracing::event!(Level::DEBUG, "certificate parsing failed: {:?}", err);
201                    invalid_count += 1
202                }
203            }
204        }
205        tracing::event!(
206            Level::TRACE,
207            "with_native_roots processed {} valid and {} invalid certs",
208            valid_count,
209            invalid_count
210        );
211        assert!(!roots.is_empty(), "no CA certificates found");
212
213        self.with_root_certificates(roots)
214    }
215}