1#![doc = include_str!("../README.md")]
2#![forbid(rust_2018_idioms)]
3#![forbid(missing_docs, unsafe_code)]
4#![warn(clippy::all, clippy::pedantic)]
5
6use std::sync::Arc;
7
8use rustls::ClientConfig;
9use tokio::io::{AsyncRead, AsyncWrite};
10use tokio_postgres::tls::MakeTlsConnect;
11
12mod private {
13 use std::{
14 convert::TryFrom,
15 future::Future,
16 io,
17 pin::Pin,
18 task::{Context, Poll},
19 };
20
21 use rustls::pki_types::ServerName;
22 use sha2::{Digest, Sha256, Sha384, Sha512};
23 use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
24 use tokio_postgres::tls::{ChannelBinding, TlsConnect};
25 use tokio_rustls::{TlsConnector, client::TlsStream};
26 use x509_cert::der::oid::db::rfc5912::{
27 ECDSA_WITH_SHA_256, ECDSA_WITH_SHA_384, ID_SHA_1, ID_SHA_256, ID_SHA_384, ID_SHA_512,
28 SHA_1_WITH_RSA_ENCRYPTION, SHA_256_WITH_RSA_ENCRYPTION, SHA_384_WITH_RSA_ENCRYPTION,
29 SHA_512_WITH_RSA_ENCRYPTION,
30 };
31 use x509_cert::{Certificate, der::Decode, der::oid::ObjectIdentifier};
32
33 pub enum TlsConnectFuture<S> {
34 Connect(Box<tokio_rustls::Connect<S>>),
35 Error(Option<io::Error>),
36 }
37
38 impl<S> Future for TlsConnectFuture<S>
39 where
40 S: AsyncRead + AsyncWrite + Unpin,
41 {
42 type Output = io::Result<RustlsStream<S>>;
43
44 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
45 match &mut *self {
46 Self::Connect(inner) => Pin::new(inner.as_mut()).poll(cx).map_ok(RustlsStream),
47 Self::Error(error) => Poll::Ready(Err(error
48 .take()
49 .expect("TlsConnectFuture polled after completion"))),
50 }
51 }
52 }
53
54 pub struct RustlsConnect(pub RustlsConnectData);
55
56 pub struct RustlsConnectData {
57 pub hostname: String,
58 pub connector: TlsConnector,
59 }
60
61 impl<S> TlsConnect<S> for RustlsConnect
62 where
63 S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
64 {
65 type Stream = RustlsStream<S>;
66 type Error = io::Error;
67 type Future = TlsConnectFuture<S>;
68
69 fn connect(self, stream: S) -> Self::Future {
70 match ServerName::try_from(self.0.hostname) {
71 Ok(hostname) => {
72 TlsConnectFuture::Connect(Box::new(self.0.connector.connect(hostname, stream)))
73 }
74 Err(error) => TlsConnectFuture::Error(Some(io::Error::new(
75 io::ErrorKind::InvalidInput,
76 error,
77 ))),
78 }
79 }
80 }
81
82 pub struct RustlsStream<S>(TlsStream<S>);
83
84 pub(super) enum ChannelBindingDigest {
85 Sha256,
86 Sha384,
87 Sha512,
88 }
89
90 impl ChannelBindingDigest {
91 pub(super) fn digest(&self, data: &[u8]) -> Vec<u8> {
92 match self {
93 Self::Sha256 => Sha256::digest(data).to_vec(),
94 Self::Sha384 => Sha384::digest(data).to_vec(),
95 Self::Sha512 => Sha512::digest(data).to_vec(),
96 }
97 }
98
99 #[cfg(test)]
100 pub(super) fn output_len(&self) -> usize {
101 match self {
102 Self::Sha256 => 32,
103 Self::Sha384 => 48,
104 Self::Sha512 => 64,
105 }
106 }
107 }
108
109 pub(super) fn channel_binding_digest(
110 signature_algorithm: ObjectIdentifier,
111 ) -> Option<ChannelBindingDigest> {
112 match signature_algorithm {
113 ID_SHA_1
115 | ID_SHA_256
116 | SHA_1_WITH_RSA_ENCRYPTION
117 | SHA_256_WITH_RSA_ENCRYPTION
118 | ECDSA_WITH_SHA_256 => Some(ChannelBindingDigest::Sha256),
119 ID_SHA_384 | SHA_384_WITH_RSA_ENCRYPTION | ECDSA_WITH_SHA_384 => {
120 Some(ChannelBindingDigest::Sha384)
121 }
122 ID_SHA_512 | SHA_512_WITH_RSA_ENCRYPTION => Some(ChannelBindingDigest::Sha512),
123 _ => None,
126 }
127 }
128
129 impl<S> tokio_postgres::tls::TlsStream for RustlsStream<S>
130 where
131 S: AsyncRead + AsyncWrite + Unpin,
132 {
133 fn channel_binding(&self) -> ChannelBinding {
134 let (_, session) = self.0.get_ref();
135 match session.peer_certificates() {
136 Some(certs) if !certs.is_empty() => Certificate::from_der(&certs[0])
137 .ok()
138 .and_then(|cert| channel_binding_digest(cert.signature_algorithm.oid))
139 .map_or_else(ChannelBinding::none, |algorithm| {
140 ChannelBinding::tls_server_end_point(algorithm.digest(certs[0].as_ref()))
141 }),
142 _ => ChannelBinding::none(),
143 }
144 }
145 }
146
147 impl<S> AsyncRead for RustlsStream<S>
148 where
149 S: AsyncRead + AsyncWrite + Unpin,
150 {
151 fn poll_read(
152 mut self: Pin<&mut Self>,
153 cx: &mut Context<'_>,
154 buf: &mut ReadBuf<'_>,
155 ) -> Poll<tokio::io::Result<()>> {
156 Pin::new(&mut self.0).poll_read(cx, buf)
157 }
158 }
159
160 impl<S> AsyncWrite for RustlsStream<S>
161 where
162 S: AsyncRead + AsyncWrite + Unpin,
163 {
164 fn poll_write(
165 mut self: Pin<&mut Self>,
166 cx: &mut Context<'_>,
167 buf: &[u8],
168 ) -> Poll<tokio::io::Result<usize>> {
169 Pin::new(&mut self.0).poll_write(cx, buf)
170 }
171
172 fn poll_flush(
173 mut self: Pin<&mut Self>,
174 cx: &mut Context<'_>,
175 ) -> Poll<tokio::io::Result<()>> {
176 Pin::new(&mut self.0).poll_flush(cx)
177 }
178
179 fn poll_shutdown(
180 mut self: Pin<&mut Self>,
181 cx: &mut Context<'_>,
182 ) -> Poll<tokio::io::Result<()>> {
183 Pin::new(&mut self.0).poll_shutdown(cx)
184 }
185 }
186}
187
188#[derive(Clone)]
192pub struct MakeRustlsConnect {
193 config: Arc<ClientConfig>,
194}
195
196impl MakeRustlsConnect {
197 #[must_use]
199 pub fn new(config: ClientConfig) -> Self {
200 Self {
201 config: Arc::new(config),
202 }
203 }
204
205 #[cfg(any(feature = "native-certs", feature = "webpki-roots"))]
206 fn from_root_certificates(roots: rustls::RootCertStore) -> Self {
207 Self::new(
208 ClientConfig::builder()
209 .with_root_certificates(roots)
210 .with_no_client_auth(),
211 )
212 }
213
214 #[cfg(feature = "webpki-roots")]
219 #[must_use]
220 pub fn with_webpki_roots() -> Self {
221 Self::from_root_certificates(rustls::RootCertStore {
222 roots: webpki_roots::TLS_SERVER_ROOTS.to_vec(),
223 })
224 }
225
226 #[cfg(feature = "native-certs")]
235 pub fn with_native_certs()
236 -> Result<(Self, Vec<rustls_native_certs::Error>), Vec<rustls_native_certs::Error>> {
237 let result = rustls_native_certs::load_native_certs();
238 if !result.certs.is_empty() {
239 let mut roots = rustls::RootCertStore::empty();
240 roots.add_parsable_certificates(result.certs);
241 Ok((Self::from_root_certificates(roots), result.errors))
242 } else {
243 Err(result.errors)
244 }
245 }
246}
247
248impl<S> MakeTlsConnect<S> for MakeRustlsConnect
249where
250 S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
251{
252 type Stream = private::RustlsStream<S>;
253 type TlsConnect = private::RustlsConnect;
254 type Error = std::convert::Infallible;
255
256 fn make_tls_connect(&mut self, hostname: &str) -> Result<Self::TlsConnect, Self::Error> {
257 Ok(private::RustlsConnect(private::RustlsConnectData {
258 hostname: hostname.to_owned(),
259 connector: Arc::clone(&self.config).into(),
260 }))
261 }
262}
263
#[cfg(test)]
mod tests {
    use super::*;
    #[cfg(any(feature = "aws-lc-rs", feature = "ring"))]
    use rustls::pki_types::{CertificateDer, UnixTime};
    #[cfg(any(feature = "aws-lc-rs", feature = "ring"))]
    use rustls::{
        Error, SignatureScheme,
        client::danger::ServerCertVerifier,
        client::danger::{HandshakeSignatureValid, ServerCertVerified},
        pki_types::ServerName,
    };
    #[cfg(any(feature = "aws-lc-rs", feature = "ring"))]
    use tokio::io::DuplexStream;
    use x509_cert::der::oid::db::{rfc5912::SHA_512_WITH_RSA_ENCRYPTION, rfc8410::ID_ED_25519};

    /// Builds a client config for an explicit crypto provider with an empty root store.
    #[cfg(any(feature = "aws-lc-rs", feature = "ring"))]
    fn client_config_with_provider(
        provider: rustls::crypto::CryptoProvider,
    ) -> rustls::ClientConfig {
        rustls::ClientConfig::builder_with_provider(provider.into())
            .with_safe_default_protocol_versions()
            .expect("default protocol versions")
            .with_root_certificates(rustls::RootCertStore::empty())
            .with_no_client_auth()
    }

    #[cfg(feature = "aws-lc-rs")]
    fn aws_lc_rs_client_config() -> rustls::ClientConfig {
        client_config_with_provider(rustls::crypto::aws_lc_rs::default_provider())
    }

    #[cfg(feature = "ring")]
    fn ring_client_config() -> rustls::ClientConfig {
        client_config_with_provider(rustls::crypto::ring::default_provider())
    }

    /// Test-only verifier that accepts every certificate and handshake signature.
    ///
    /// This lets the integration test run regardless of the server's certificate chain.
    #[cfg(any(feature = "aws-lc-rs", feature = "ring"))]
    #[derive(Debug)]
    struct AcceptAllVerifier;

    #[cfg(any(feature = "aws-lc-rs", feature = "ring"))]
    impl ServerCertVerifier for AcceptAllVerifier {
        fn verify_server_cert(
            &self,
            _end_entity: &CertificateDer<'_>,
            _intermediates: &[CertificateDer<'_>],
            _server_name: &ServerName<'_>,
            _ocsp_response: &[u8],
            _now: UnixTime,
        ) -> Result<ServerCertVerified, Error> {
            Ok(ServerCertVerified::assertion())
        }

        fn verify_tls12_signature(
            &self,
            _message: &[u8],
            _cert: &CertificateDer<'_>,
            _dss: &rustls::DigitallySignedStruct,
        ) -> Result<HandshakeSignatureValid, Error> {
            Ok(HandshakeSignatureValid::assertion())
        }

        fn verify_tls13_signature(
            &self,
            _message: &[u8],
            _cert: &CertificateDer<'_>,
            _dss: &rustls::DigitallySignedStruct,
        ) -> Result<HandshakeSignatureValid, Error> {
            Ok(HandshakeSignatureValid::assertion())
        }

        fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
            vec![
                SignatureScheme::ECDSA_NISTP384_SHA384,
                SignatureScheme::ECDSA_NISTP256_SHA256,
                SignatureScheme::RSA_PSS_SHA512,
                SignatureScheme::RSA_PSS_SHA384,
                SignatureScheme::RSA_PSS_SHA256,
                SignatureScheme::ED25519,
            ]
        }
    }

    /// Connects to a local PostgreSQL server over TLS and runs `SELECT 1`.
    ///
    /// NOTE(review): needs a server on localhost:5432 that admits `user=postgres` with no
    /// password in the connection string; confirm against the CI setup.
    #[cfg(any(feature = "aws-lc-rs", feature = "ring"))]
    async fn connect_works(mut config: rustls::ClientConfig) {
        // With both provider features enabled, both tests run in the same process and
        // only the first `try_init` can install the global logger. Unwrapping made the
        // second test panic, so the "already initialised" error is ignored.
        let _ = env_logger::builder().is_test(true).try_init();

        config
            .dangerous()
            .set_certificate_verifier(Arc::new(AcceptAllVerifier));
        let tls = super::MakeRustlsConnect::new(config);
        let (client, conn) = tokio_postgres::connect(
            "sslmode=require host=localhost port=5432 user=postgres",
            tls,
        )
        .await
        .expect("connect");
        tokio::spawn(async move {
            if let Err(e) = conn.await {
                panic!("{e:?}");
            }
        });
        let stmt = client.prepare("SELECT 1").await.expect("prepare");
        let _ = client.query(&stmt, &[]).await.expect("query");
    }

    #[cfg(feature = "aws-lc-rs")]
    #[tokio::test]
    async fn it_works_with_aws_lc_rs() {
        connect_works(aws_lc_rs_client_config()).await;
    }

    #[cfg(feature = "ring")]
    #[tokio::test]
    async fn it_works_with_ring() {
        connect_works(ring_client_config()).await;
    }

    /// A Unix-socket path is not a valid TLS server name, but `make_tls_connect` must
    /// still succeed because TLS is never negotiated over such a connection.
    #[cfg(any(feature = "aws-lc-rs", feature = "ring"))]
    fn accepts_unix_socket_hostname_before_tls_is_used(config: rustls::ClientConfig) {
        let mut tls = super::MakeRustlsConnect::new(config);

        let tls_connect =
            <super::MakeRustlsConnect as MakeTlsConnect<DuplexStream>>::make_tls_connect(
                &mut tls,
                "/var/run/postgresql",
            );

        assert!(tls_connect.is_ok());
    }

    #[cfg(feature = "aws-lc-rs")]
    #[test]
    fn accepts_unix_socket_hostname_before_tls_is_used_with_aws_lc_rs() {
        accepts_unix_socket_hostname_before_tls_is_used(aws_lc_rs_client_config());
    }

    #[cfg(feature = "ring")]
    #[test]
    fn accepts_unix_socket_hostname_before_tls_is_used_with_ring() {
        accepts_unix_socket_hostname_before_tls_is_used(ring_client_config());
    }

    #[test]
    fn ed25519_has_no_channel_binding_digest() {
        assert!(private::channel_binding_digest(ID_ED_25519).is_none());
    }

    #[test]
    fn sha512_with_rsa_has_channel_binding_digest() {
        let algorithm = private::channel_binding_digest(SHA_512_WITH_RSA_ENCRYPTION)
            .expect("SHA-512 signature algorithm should map to a digest");

        assert_eq!(algorithm.output_len(), 64);
    }
}