tiberius_rustls/client/tls_stream/
rustls_tls_stream.rs1use crate::{
2 client::{config::Config, TrustConfig},
3 error::IoErrorKind,
4 Error,
5};
6use futures_util::io::{AsyncRead, AsyncWrite};
7use std::{
8 fs, io,
9 pin::Pin,
10 sync::Arc,
11 task::{Context, Poll},
12 time::SystemTime,
13};
14use tokio_rustls::{
15 rustls::{
16 client::{
17 HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier,
18 WantsTransparencyPolicyOrClientCert,
19 },
20 Certificate, ClientConfig, ConfigBuilder, DigitallySignedStruct, Error as RustlsError,
21 RootCertStore, ServerName, WantsVerifier,
22 },
23 TlsConnector,
24};
25use tokio_util::compat::{Compat, FuturesAsyncReadCompatExt, TokioAsyncReadCompatExt};
26use tracing::{event, Level};
27
28impl From<tokio_rustls::rustls::Error> for Error {
29 fn from(e: tokio_rustls::rustls::Error) -> Self {
30 crate::Error::Tls(e.to_string())
31 }
32}
33
34pub(crate) struct TlsStream<S: AsyncRead + AsyncWrite + Unpin + Send>(
35 Compat<tokio_rustls::client::TlsStream<Compat<S>>>,
36);
37
38struct NoCertVerifier;
39
40impl ServerCertVerifier for NoCertVerifier {
41 fn verify_server_cert(
42 &self,
43 _end_entity: &Certificate,
44 _intermediates: &[Certificate],
45 _server_name: &ServerName,
46 _scts: &mut dyn Iterator<Item = &[u8]>,
47 _ocsp_response: &[u8],
48 _now: SystemTime,
49 ) -> Result<ServerCertVerified, RustlsError> {
50 Ok(ServerCertVerified::assertion())
51 }
52
53 fn verify_tls12_signature(
54 &self,
55 _message: &[u8],
56 _cert: &Certificate,
57 _dss: &DigitallySignedStruct,
58 ) -> Result<HandshakeSignatureValid, RustlsError> {
59 Ok(HandshakeSignatureValid::assertion())
60 }
61}
62
63fn get_server_name(config: &Config) -> crate::Result<ServerName> {
64 match (ServerName::try_from(config.get_host()), &config.trust) {
65 (Ok(sn), _) => Ok(sn),
66 (Err(_), TrustConfig::TrustAll) => {
67 Ok(ServerName::try_from("placeholder.domain.com").unwrap())
68 }
69 (Err(e), _) => Err(crate::Error::Tls(e.to_string())),
70 }
71}
72
73impl<S: AsyncRead + AsyncWrite + Unpin + Send> TlsStream<S> {
74 pub(super) async fn new(config: &Config, stream: S) -> crate::Result<Self> {
75 event!(Level::INFO, "Performing a TLS handshake");
76
77 let builder = ClientConfig::builder().with_safe_defaults();
78
79 let client_config = match &config.trust {
80 TrustConfig::CaCertificateLocation(path) => {
81 if let Ok(buf) = fs::read(path) {
82 let cert = match path.extension() {
83 Some(ext)
84 if ext.to_ascii_lowercase() == "pem"
85 || ext.to_ascii_lowercase() == "crt" =>
86 {
87 let pem_cert = rustls_pemfile::certs(&mut buf.as_slice())?;
88 if pem_cert.len() != 1 {
89 return Err(crate::Error::Io {
90 kind: IoErrorKind::InvalidInput,
91 message: format!("Certificate file {} contain 0 or more than 1 certs", path.to_string_lossy()),
92 });
93 }
94
95 Certificate(pem_cert.into_iter().next().unwrap())
96 }
97 Some(ext) if ext.to_ascii_lowercase() == "der" => {
98 Certificate(buf)
99 }
100 Some(_) | None => return Err(crate::Error::Io {
101 kind: IoErrorKind::InvalidInput,
102 message: "Provided CA certificate with unsupported file-extension! Supported types are pem, crt and der.".to_string(),
103 }),
104 };
105 let mut cert_store = RootCertStore::empty();
106 cert_store.add(&cert)?;
107 builder
108 .with_root_certificates(cert_store)
109 .with_no_client_auth()
110 } else {
111 return Err(Error::Io {
112 kind: IoErrorKind::InvalidData,
113 message: "Could not read provided CA certificate!".to_string(),
114 });
115 }
116 }
117 TrustConfig::TrustAll => {
118 event!(
119 Level::WARN,
120 "Trusting the server certificate without validation."
121 );
122 let mut config = builder
123 .with_root_certificates(RootCertStore::empty())
124 .with_no_client_auth();
125 config
126 .dangerous()
127 .set_certificate_verifier(Arc::new(NoCertVerifier {}));
128 config
130 }
131 TrustConfig::Default => {
132 event!(Level::INFO, "Using default trust configuration.");
133 builder.with_native_roots().with_no_client_auth()
134 }
135 };
136
137 let connector = TlsConnector::from(Arc::new(client_config));
138
139 let tls_stream = connector
140 .connect(get_server_name(config)?, stream.compat())
141 .await?;
142
143 Ok(TlsStream(tls_stream.compat()))
144 }
145
146 pub(crate) fn get_mut(&mut self) -> &mut S {
147 self.0.get_mut().get_mut().0.get_mut()
148 }
149}
150
151impl<S: AsyncRead + AsyncWrite + Unpin + Send> AsyncRead for TlsStream<S> {
152 fn poll_read(
153 self: Pin<&mut Self>,
154 cx: &mut Context<'_>,
155 buf: &mut [u8],
156 ) -> Poll<io::Result<usize>> {
157 let inner = Pin::get_mut(self);
158 Pin::new(&mut inner.0).poll_read(cx, buf)
159 }
160}
161
162impl<S: AsyncRead + AsyncWrite + Unpin + Send> AsyncWrite for TlsStream<S> {
163 fn poll_write(
164 self: Pin<&mut Self>,
165 cx: &mut Context<'_>,
166 buf: &[u8],
167 ) -> Poll<io::Result<usize>> {
168 let inner = Pin::get_mut(self);
169 Pin::new(&mut inner.0).poll_write(cx, buf)
170 }
171
172 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
173 let inner = Pin::get_mut(self);
174 Pin::new(&mut inner.0).poll_flush(cx)
175 }
176
177 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
178 let inner = Pin::get_mut(self);
179 Pin::new(&mut inner.0).poll_close(cx)
180 }
181}
182
183trait ConfigBuilderExt {
184 fn with_native_roots(self) -> ConfigBuilder<ClientConfig, WantsTransparencyPolicyOrClientCert>;
185}
186
187impl ConfigBuilderExt for ConfigBuilder<ClientConfig, WantsVerifier> {
188 fn with_native_roots(self) -> ConfigBuilder<ClientConfig, WantsTransparencyPolicyOrClientCert> {
189 let mut roots = RootCertStore::empty();
190 let mut valid_count = 0;
191 let mut invalid_count = 0;
192
193 for cert in rustls_native_certs::load_native_certs().expect("could not load platform certs")
194 {
195 let cert = Certificate(cert.0);
196 match roots.add(&cert) {
197 Ok(_) => valid_count += 1,
198 Err(err) => {
199 tracing::event!(Level::TRACE, "invalid cert der {:?}", cert.0);
200 tracing::event!(Level::DEBUG, "certificate parsing failed: {:?}", err);
201 invalid_count += 1
202 }
203 }
204 }
205 tracing::event!(
206 Level::TRACE,
207 "with_native_roots processed {} valid and {} invalid certs",
208 valid_count,
209 invalid_count
210 );
211 assert!(!roots.is_empty(), "no CA certificates found");
212
213 self.with_root_certificates(roots)
214 }
215}