tokio_postgres_generic_rustls/lib.rs
1//! # tokio-postgres-generic-rustls
//! _An implementation of TLS based on rustls for tokio-postgres_
3//!
//! This crate allows users to select a crypto backend, or bring their own, rather than relying on
//! primitives provided by `ring` directly. This is done by using `x509-cert` for certificate
//! parsing instead of `x509-certificate`, and by adding an abstraction for computing digests.
8//!
//! By default, tokio-postgres-generic-rustls does not provide a digest implementation, but several
//! implementations are available behind crate features:
11//!
//! | Feature      | Implementation     |
13//! | ------------ | ------------------ |
14//! | `aws-lc-rs` | `AwsLcRsDigest` |
15//! | `ring` | `RingDigest` |
16//! | `rustcrypto` | `RustcryptoDigest` |
17//!
18//! ## Usage
//! Using this crate is fairly straightforward. First, select your digest implementation via crate
//! features (or provide your own), then construct a rustls connector for tokio-postgres from your
//! rustls client configuration.
22//!
23//! The following example demonstrates providing a custom digest backend.
24//!
25//! ```rust
26//! use tokio_postgres_generic_rustls::{DigestImplementation, DigestAlgorithm, MakeRustlsConnect};
27//!
28//! #[derive(Clone)]
29//! struct DemoDigest;
30//!
31//! impl DigestImplementation for DemoDigest {
32//! fn digest(&self, algorithm: DigestAlgorithm, bytes: &[u8]) -> Vec<u8> {
33//! todo!("digest it")
34//! }
35//! }
36//!
37//! let cert_store = rustls::RootCertStore::empty();
38//!
39//! let config = rustls::ClientConfig::builder()
40//! .with_root_certificates(cert_store)
41//! .with_no_client_auth();
42//!
43//! let tls = MakeRustlsConnect::new(config, DemoDigest);
44//!
45//! let connect_future = tokio_postgres::connect("postgres://username:password@localhost:5432/db", tls);
46//!
47//! // connect_future.await;
48//! ```
49//!
50//! ## License
51//! This project is licensed under either of
52//!
53//! - Apache License, Version 2.0, (LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0)
54//! - MIT license (LICENSE-MIT or http://opensource.org/licenses/MIT)
55//!
56//! at your option.
57
58#![deny(unsafe_code)]
59#![deny(missing_docs)]
60
61use std::{future::Future, pin::Pin, sync::Arc};
62
63use rustls::{pki_types::ServerName, ClientConfig};
64use tokio::io::{AsyncRead, AsyncWrite};
65use tokio_postgres::tls::{ChannelBinding, MakeTlsConnect, TlsConnect};
66use tokio_rustls::{client::TlsStream, Connect, TlsConnector};
67use x509_cert::{
68 der::{
69 oid::db::rfc5912::{
70 ECDSA_WITH_SHA_256, ECDSA_WITH_SHA_384, SHA_1_WITH_RSA_ENCRYPTION,
71 SHA_256_WITH_RSA_ENCRYPTION, SHA_384_WITH_RSA_ENCRYPTION, SHA_512_WITH_RSA_ENCRYPTION,
72 },
73 Decode,
74 },
75 spki::ObjectIdentifier,
76 Certificate,
77};
78
79/// Trait used to provide a custom digest backend to tokio_postgres_generic_rustls
80///
81/// This trait is implementated for three types by default, each behind it's own feature flag. The
82/// provided backends are aws-lc-rs, ring, and rustcrypto.
83pub trait DigestImplementation {
84 /// Hash the provided bytes with the provided algorithm
85 ///
86 /// ```rust
87 /// use tokio_postgres_generic_rustls::{DigestImplementation, DigestAlgorithm};
88 ///
89 /// struct CustomAwsLcRsDigest;
90 ///
91 /// impl DigestImplementation for CustomAwsLcRsDigest {
92 /// fn digest(&self, digest_algorithm: DigestAlgorithm, bytes: &[u8]) -> Vec<u8> {
93 /// let digest_alg = match digest_algorithm {
94 /// // Note: SHA1 is upgraded to SHA256 as per https://datatracker.ietf.org/doc/html/rfc5929#section-4.1
95 /// DigestAlgorithm::Sha1 | DigestAlgorithm::Sha256 => &aws_lc_rs::digest::SHA256,
96 /// DigestAlgorithm::Sha384 => &aws_lc_rs::digest::SHA384,
97 /// DigestAlgorithm::Sha512 => &aws_lc_rs::digest::SHA512,
98 /// };
99 ///
100 /// aws_lc_rs::digest::digest(digest_alg, bytes).as_ref().into()
101 /// }
102 /// }
103 /// ```
104 fn digest(&self, algorithm: DigestAlgorithm, bytes: &[u8]) -> Vec<u8>;
105}
106
107#[cfg(feature = "aws-lc-rs")]
108pub use aws_lc_rs_backend::AwsLcRsDigest;
109
110#[cfg(feature = "ring")]
111pub use ring_backend::RingDigest;
112
113#[cfg(feature = "rustcrypto")]
114pub use rustcrypto_backend::RustcryptoDigest;
115
#[cfg(feature = "aws-lc-rs")]
mod aws_lc_rs_backend {
    use super::{DigestAlgorithm, DigestImplementation};

    /// A digest backend provided by aws-lc-rs
    ///
    /// Usage:
    /// ```rust
    /// # let cert_store = rustls::RootCertStore::empty();
    /// #
    /// # let rustls_config = rustls::ClientConfig::builder()
    /// #     .with_root_certificates(cert_store)
    /// #     .with_no_client_auth();
    /// #
    /// use tokio_postgres_generic_rustls::{MakeRustlsConnect, AwsLcRsDigest};
    ///
    /// let tls = MakeRustlsConnect::new(rustls_config, AwsLcRsDigest);
    ///
    /// let connect_future = tokio_postgres::connect("postgres://username:password@localhost:5432/db", tls);
    /// ```
    #[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
    pub struct AwsLcRsDigest;

    /// Map a requested digest onto the aws-lc-rs algorithm that computes it.
    fn select(requested: DigestAlgorithm) -> &'static aws_lc_rs::digest::Algorithm {
        match requested {
            DigestAlgorithm::Sha384 => &aws_lc_rs::digest::SHA384,
            DigestAlgorithm::Sha512 => &aws_lc_rs::digest::SHA512,
            // RFC 5929 section 4.1: SHA-1-signed certificates still bind with SHA-256.
            DigestAlgorithm::Sha1 | DigestAlgorithm::Sha256 => &aws_lc_rs::digest::SHA256,
        }
    }

    impl DigestImplementation for AwsLcRsDigest {
        fn digest(&self, algorithm: DigestAlgorithm, bytes: &[u8]) -> Vec<u8> {
            aws_lc_rs::digest::digest(select(algorithm), bytes)
                .as_ref()
                .to_vec()
        }
    }
}
152
#[cfg(feature = "ring")]
mod ring_backend {
    use super::{DigestAlgorithm, DigestImplementation};

    /// A digest backend provided by ring
    ///
    /// Usage:
    /// ```rust
    /// # let cert_store = rustls::RootCertStore::empty();
    /// #
    /// # let rustls_config = rustls::ClientConfig::builder()
    /// #     .with_root_certificates(cert_store)
    /// #     .with_no_client_auth();
    /// #
    /// use tokio_postgres_generic_rustls::{MakeRustlsConnect, RingDigest};
    ///
    /// let tls = MakeRustlsConnect::new(rustls_config, RingDigest);
    ///
    /// let connect_future = tokio_postgres::connect("postgres://username:password@localhost:5432/db", tls);
    /// ```
    #[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
    pub struct RingDigest;

    impl DigestImplementation for RingDigest {
        fn digest(&self, algorithm: DigestAlgorithm, bytes: &[u8]) -> Vec<u8> {
            let ring_algorithm: &'static ring::digest::Algorithm = match algorithm {
                DigestAlgorithm::Sha512 => &ring::digest::SHA512,
                DigestAlgorithm::Sha384 => &ring::digest::SHA384,
                // RFC 5929 section 4.1: SHA-1-signed certificates still bind with SHA-256.
                DigestAlgorithm::Sha1 | DigestAlgorithm::Sha256 => &ring::digest::SHA256,
            };

            let output = ring::digest::digest(ring_algorithm, bytes);
            output.as_ref().to_vec()
        }
    }
}
189
#[cfg(feature = "rustcrypto")]
mod rustcrypto_backend {
    use super::{DigestAlgorithm, DigestImplementation};
    use sha2::{Digest, Sha256, Sha384, Sha512};

    /// A digest backend provided by rustcrypto
    ///
    /// Usage:
    /// ```rust
    /// # let cert_store = rustls::RootCertStore::empty();
    /// #
    /// # let rustls_config = rustls::ClientConfig::builder()
    /// #     .with_root_certificates(cert_store)
    /// #     .with_no_client_auth();
    /// #
    /// use tokio_postgres_generic_rustls::{MakeRustlsConnect, RustcryptoDigest};
    ///
    /// let tls = MakeRustlsConnect::new(rustls_config, RustcryptoDigest);
    ///
    /// let connect_future = tokio_postgres::connect("postgres://username:password@localhost:5432/db", tls);
    /// ```
    #[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
    pub struct RustcryptoDigest;

    impl DigestImplementation for RustcryptoDigest {
        fn digest(&self, algorithm: DigestAlgorithm, bytes: &[u8]) -> Vec<u8> {
            match algorithm {
                // Note: SHA1 is upgraded to SHA256 as per https://datatracker.ietf.org/doc/html/rfc5929#section-4.1
                DigestAlgorithm::Sha1 | DigestAlgorithm::Sha256 => {
                    Sha256::digest(bytes).as_slice().to_vec()
                }
                DigestAlgorithm::Sha384 => Sha384::digest(bytes).as_slice().to_vec(),
                DigestAlgorithm::Sha512 => Sha512::digest(bytes).as_slice().to_vec(),
            }
        }
    }
}
226
227#[derive(Clone)]
228/// The primary interface for consumers of this crate
229///
230/// This type can be provided to tokio-postgres' `connect` method in order to add TLS to a postgres
231/// connection.
232/// ```rust
233/// # use tokio_postgres_generic_rustls::{DigestImplementation, DigestAlgorithm};
234/// #
235/// # #[derive(Clone)]
236/// # struct DemoDigest;
237/// #
238/// # impl DigestImplementation for DemoDigest {
239/// # fn digest(&self, algorithm: DigestAlgorithm, bytes: &[u8]) -> Vec<u8> {
240/// # todo!("digest it")
241/// # }
242/// # }
243/// #
244/// # let cert_store = rustls::RootCertStore::empty();
245/// #
246/// # let rustls_config = rustls::ClientConfig::builder()
247/// # .with_root_certificates(cert_store)
248/// # .with_no_client_auth();
249/// #
250/// use tokio_postgres_generic_rustls::MakeRustlsConnect;
251///
252/// let tls = MakeRustlsConnect::new(rustls_config, DemoDigest);
253///
254/// let connect_future = tokio_postgres::connect("postgres://username:password@localhost:5432/db", tls);
255/// ```
256pub struct MakeRustlsConnect<D> {
257 config: Arc<ClientConfig>,
258 digest_impl: D,
259}
260
261impl<D> MakeRustlsConnect<D>
262where
263 D: DigestImplementation,
264{
265 /// Create a new MakeRustlsConnect instance
266 pub fn new(config: ClientConfig, digest_impl: D) -> Self {
267 Self {
268 config: Arc::new(config),
269 digest_impl,
270 }
271 }
272}
273
274impl<D, S> MakeTlsConnect<S> for MakeRustlsConnect<D>
275where
276 D: DigestImplementation + Clone + Unpin,
277 S: AsyncRead + AsyncWrite + Unpin + Send,
278{
279 type Stream = RustlsStream<D, S>;
280 type TlsConnect = RustlsConnect<D>;
281 type Error = rustls::pki_types::InvalidDnsNameError;
282
283 fn make_tls_connect(&mut self, domain: &str) -> Result<Self::TlsConnect, Self::Error> {
284 ServerName::try_from(domain).map(|dns_name| RustlsConnect {
285 dns_name: dns_name.to_owned(),
286 connector: Arc::clone(&self.config).into(),
287 digest_impl: self.digest_impl.clone(),
288 })
289 }
290}
291
292#[doc(hidden)]
293pub struct RustlsConnect<D> {
294 dns_name: ServerName<'static>,
295 connector: TlsConnector,
296 digest_impl: D,
297}
298
299#[doc(hidden)]
300pub struct ConnectFuture<D, S> {
301 connect: Connect<S>,
302 digest_impl: D,
303}
304
305impl<D, S> Future for ConnectFuture<D, S>
306where
307 D: DigestImplementation + Clone + Unpin,
308 S: AsyncRead + AsyncWrite + Unpin,
309{
310 type Output = std::io::Result<RustlsStream<D, S>>;
311
312 fn poll(
313 self: Pin<&mut Self>,
314 cx: &mut std::task::Context<'_>,
315 ) -> std::task::Poll<Self::Output> {
316 let this = self.get_mut();
317
318 let res = std::task::ready!(Pin::new(&mut this.connect).poll(cx));
319
320 std::task::Poll::Ready(res.map(|io| RustlsStream {
321 io,
322 digest_impl: this.digest_impl.clone(),
323 }))
324 }
325}
326
327impl<D, S> TlsConnect<S> for RustlsConnect<D>
328where
329 D: DigestImplementation + Clone + Unpin,
330 S: AsyncRead + AsyncWrite + Unpin + Send,
331{
332 type Stream = RustlsStream<D, S>;
333 type Error = std::io::Error;
334 type Future = ConnectFuture<D, S>;
335
336 fn connect(self, stream: S) -> Self::Future {
337 ConnectFuture {
338 connect: self.connector.connect(self.dns_name, stream),
339 digest_impl: self.digest_impl.clone(),
340 }
341 }
342}
343
344enum SignatureAlgorithm {
345 // 1.2.840.113549.1.1.5
346 Sha1Rsa,
347
348 // 1.2.840.113549.1.1.11
349 Sha256Rsa,
350
351 // 1.2.840.113549.1.1.12
352 Sha384Rsa,
353
354 // 1.2.840.113549.1.1.13
355 Sha512Rsa,
356
357 // 1.2.840.10045.4.3.2
358 EcdsaSha256,
359
360 // 1.2.840.10045.4.3.3
361 EcdsaSha384,
362}
363
364impl SignatureAlgorithm {
365 fn try_from_identifier(oid: &ObjectIdentifier) -> Option<Self> {
366 if oid == &SHA_1_WITH_RSA_ENCRYPTION {
367 Some(Self::Sha1Rsa)
368 } else if oid == &SHA_256_WITH_RSA_ENCRYPTION {
369 Some(Self::Sha256Rsa)
370 } else if oid == &SHA_384_WITH_RSA_ENCRYPTION {
371 Some(Self::Sha384Rsa)
372 } else if oid == &SHA_512_WITH_RSA_ENCRYPTION {
373 Some(Self::Sha512Rsa)
374 } else if oid == &ECDSA_WITH_SHA_256 {
375 Some(Self::EcdsaSha256)
376 } else if oid == &ECDSA_WITH_SHA_384 {
377 Some(Self::EcdsaSha384)
378 } else {
379 None
380 }
381 }
382
383 fn digest_algorithm(self) -> DigestAlgorithm {
384 match self {
385 Self::Sha1Rsa => DigestAlgorithm::Sha1,
386 Self::Sha256Rsa | Self::EcdsaSha256 => DigestAlgorithm::Sha256,
387 Self::Sha384Rsa | Self::EcdsaSha384 => DigestAlgorithm::Sha384,
388 Self::Sha512Rsa => DigestAlgorithm::Sha512,
389 }
390 }
391}
392
/// Digest algorithms that can be used in `tls-server-end-point` channel bindings.
///
/// This type is only useful when defining a custom digest implementation, and can otherwise be
/// ignored.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum DigestAlgorithm {
    /// The server certificate is signed using SHA-1. As required by
    /// <https://datatracker.ietf.org/doc/html/rfc5929#section-4.1>, SHA-256 hashing should be used
    /// instead.
    Sha1,

    /// The server certificate is signed using SHA-256.
    Sha256,

    /// The server certificate is signed using SHA-384.
    Sha384,

    /// The server certificate is signed using SHA-512.
    Sha512,
}
412
413#[doc(hidden)]
414pub struct RustlsStream<D, S> {
415 io: TlsStream<S>,
416 digest_impl: D,
417}
418
419impl<D, S> tokio_postgres::tls::TlsStream for RustlsStream<D, S>
420where
421 D: DigestImplementation + Unpin,
422 S: AsyncRead + AsyncWrite + Unpin,
423{
424 fn channel_binding(&self) -> tokio_postgres::tls::ChannelBinding {
425 let (_, session) = self.io.get_ref();
426
427 match session.peer_certificates() {
428 Some(certs) if !certs.is_empty() => Certificate::from_der(&certs[0])
429 .ok()
430 .and_then(|cert| {
431 SignatureAlgorithm::try_from_identifier(&cert.signature_algorithm.oid)
432 })
433 .map(|signature_algorithm| {
434 let digest = self
435 .digest_impl
436 .digest(signature_algorithm.digest_algorithm(), &certs[0]);
437
438 ChannelBinding::tls_server_end_point(digest)
439 })
440 .unwrap_or_else(ChannelBinding::none),
441 _ => ChannelBinding::none(),
442 }
443 }
444}
445
446impl<D, S> AsyncRead for RustlsStream<D, S>
447where
448 D: Unpin,
449 S: AsyncRead + AsyncWrite + Unpin,
450{
451 fn poll_read(
452 self: Pin<&mut Self>,
453 cx: &mut std::task::Context<'_>,
454 buf: &mut tokio::io::ReadBuf<'_>,
455 ) -> std::task::Poll<std::io::Result<()>> {
456 Pin::new(&mut self.get_mut().io).poll_read(cx, buf)
457 }
458}
459
460impl<D, S> AsyncWrite for RustlsStream<D, S>
461where
462 D: Unpin,
463 S: AsyncRead + AsyncWrite + Unpin,
464{
465 fn poll_write(
466 self: Pin<&mut Self>,
467 cx: &mut std::task::Context<'_>,
468 buf: &[u8],
469 ) -> std::task::Poll<Result<usize, std::io::Error>> {
470 Pin::new(&mut self.get_mut().io).poll_write(cx, buf)
471 }
472
473 fn poll_flush(
474 self: Pin<&mut Self>,
475 cx: &mut std::task::Context<'_>,
476 ) -> std::task::Poll<Result<(), std::io::Error>> {
477 Pin::new(&mut self.get_mut().io).poll_flush(cx)
478 }
479
480 fn poll_shutdown(
481 self: Pin<&mut Self>,
482 cx: &mut std::task::Context<'_>,
483 ) -> std::task::Poll<Result<(), std::io::Error>> {
484 Pin::new(&mut self.get_mut().io).poll_shutdown(cx)
485 }
486
487 fn is_write_vectored(&self) -> bool {
488 self.io.is_write_vectored()
489 }
490
491 fn poll_write_vectored(
492 self: Pin<&mut Self>,
493 cx: &mut std::task::Context<'_>,
494 bufs: &[std::io::IoSlice<'_>],
495 ) -> std::task::Poll<Result<usize, std::io::Error>> {
496 Pin::new(&mut self.get_mut().io).poll_write_vectored(cx, bufs)
497 }
498}