Skip to main content

tokio_postgres_generic_rustls/
lib.rs

1//! # tokio-postgres-generic-rustls
//! _An implementation of TLS based on rustls for tokio-postgres_
3//!
4//! This crate allows users to select a crypto backend, or bring their own, rather than relying on
5//! primitives provided by `ring` directly. This is done through the use of x509-cert for
6//! certificate parsing rather than X509-certificate, while also adding an abstraction for
7//! computing digests.
8//!
9//! By default, tokio-postgres-generic-rustls does not provide a digest implementation, but one or
10//! more are provided behind crate features.
11//!
//! | Feature      | Implementation     |
//! | ------------ | ------------------ |
//! | `aws-lc-rs`  | `AwsLcRsDigest`    |
//! | `graviola`   | `GraviolaDigest`   |
//! | `ring`       | `RingDigest`       |
//! | `rustcrypto` | `RustcryptoDigest` |
18//!
19//! ## Usage
//! Using this crate is fairly straightforward. First, select your digest implementation via crate
//! features (or provide your own), then construct a rustls connector for tokio-postgres with your
//! rustls client configuration.
23//!
24//! The following example demonstrates providing a custom digest backend.
25//!
26//! ```rust
27//! use tokio_postgres_generic_rustls::{DigestImplementation, DigestAlgorithm, MakeRustlsConnect};
28//!
29//! #[derive(Clone)]
30//! struct DemoDigest;
31//!
32//! impl DigestImplementation for DemoDigest {
33//!     fn digest(&self, algorithm: DigestAlgorithm, bytes: &[u8]) -> Vec<u8> {
34//!         todo!("digest it")
35//!     }
36//! }
37//!
38//! let cert_store = rustls::RootCertStore::empty();
39//!
40//! let config = rustls::ClientConfig::builder()
41//!     .with_root_certificates(cert_store)
42//!     .with_no_client_auth();
43//!
44//! let tls = MakeRustlsConnect::new(config, DemoDigest);
45//!
46//! let connect_future = tokio_postgres::connect("postgres://username:password@localhost:5432/db", tls);
47//!
48//! // connect_future.await;
49//! ```
50//!
51//! ## License
52//! This project is licensed under either of
53//!
54//! - Apache License, Version 2.0, (LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0)
55//! - MIT license (LICENSE-MIT or http://opensource.org/licenses/MIT)
56//!
57//! at your option.
58
59#![deny(unsafe_code)]
60#![deny(missing_docs)]
61
use std::{future::Future, pin::Pin, sync::Arc};

use rustls::{pki_types::ServerName, ClientConfig};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_postgres::tls::{ChannelBinding, MakeTlsConnect, TlsConnect};
use tokio_rustls::{client::TlsStream, Connect, TlsConnector};
use x509_cert::{
    der::{
        oid::db::rfc5912::{
            ECDSA_WITH_SHA_256, ECDSA_WITH_SHA_384, ECDSA_WITH_SHA_512, SHA_1_WITH_RSA_ENCRYPTION,
            SHA_256_WITH_RSA_ENCRYPTION, SHA_384_WITH_RSA_ENCRYPTION, SHA_512_WITH_RSA_ENCRYPTION,
        },
        Decode,
    },
    spki::ObjectIdentifier,
    Certificate,
};
79
80/// Trait used to provide a custom digest backend to tokio_postgres_generic_rustls
81///
82/// This trait is implementated for three types by default, each behind it's own feature flag. The
83/// provided backends are aws-lc-rs, ring, and rustcrypto.
84pub trait DigestImplementation {
85    /// Hash the provided bytes with the provided algorithm
86    ///
87    /// ```rust
88    /// use tokio_postgres_generic_rustls::{DigestImplementation, DigestAlgorithm};
89    ///
90    /// struct CustomAwsLcRsDigest;
91    ///
92    /// impl DigestImplementation for CustomAwsLcRsDigest {
93    ///     fn digest(&self, digest_algorithm: DigestAlgorithm, bytes: &[u8]) -> Vec<u8> {
94    ///         let digest_alg = match digest_algorithm {
95    ///             // Note: SHA1 is upgraded to SHA256 as per https://datatracker.ietf.org/doc/html/rfc5929#section-4.1
96    ///             DigestAlgorithm::Sha1 | DigestAlgorithm::Sha256 => &aws_lc_rs::digest::SHA256,
97    ///             DigestAlgorithm::Sha384 => &aws_lc_rs::digest::SHA384,
98    ///             DigestAlgorithm::Sha512 => &aws_lc_rs::digest::SHA512,
99    ///         };
100    ///
101    ///         aws_lc_rs::digest::digest(digest_alg, bytes).as_ref().into()
102    ///     }
103    /// }
104    /// ```
105    fn digest(&self, algorithm: DigestAlgorithm, bytes: &[u8]) -> Vec<u8>;
106}
107
108#[cfg(feature = "aws-lc-rs")]
109pub use aws_lc_rs_backend::AwsLcRsDigest;
110
111#[cfg(feature = "graviola")]
112pub use graviola_backend::GraviolaDigest;
113
114#[cfg(feature = "ring")]
115pub use ring_backend::RingDigest;
116
117#[cfg(feature = "rustcrypto")]
118pub use rustcrypto_backend::RustcryptoDigest;
119
#[cfg(feature = "aws-lc-rs")]
mod aws_lc_rs_backend {
    use super::{DigestAlgorithm, DigestImplementation};

    /// Digest backend built on the `aws-lc-rs` crate.
    ///
    /// Usage:
    /// ```rust
    /// # let cert_store = rustls::RootCertStore::empty();
    /// #
    /// # let rustls_config = rustls::ClientConfig::builder()
    /// #     .with_root_certificates(cert_store)
    /// #     .with_no_client_auth();
    /// #
    /// use tokio_postgres_generic_rustls::{MakeRustlsConnect, AwsLcRsDigest};
    ///
    /// let tls = MakeRustlsConnect::new(rustls_config, AwsLcRsDigest);
    ///
    /// let connect_future = tokio_postgres::connect("postgres://username:password@localhost:5432/db", tls);
    /// ```
    #[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
    pub struct AwsLcRsDigest;

    impl DigestImplementation for AwsLcRsDigest {
        fn digest(&self, algorithm: DigestAlgorithm, bytes: &[u8]) -> Vec<u8> {
            use aws_lc_rs::digest::{SHA256, SHA384, SHA512};

            // RFC 5929 §4.1: a SHA-1 certificate signature is bound with SHA-256 instead.
            let hash: &'static aws_lc_rs::digest::Algorithm = match algorithm {
                DigestAlgorithm::Sha512 => &SHA512,
                DigestAlgorithm::Sha384 => &SHA384,
                DigestAlgorithm::Sha1 | DigestAlgorithm::Sha256 => &SHA256,
            };

            let output = aws_lc_rs::digest::digest(hash, bytes);
            output.as_ref().to_vec()
        }
    }
}
156
#[cfg(feature = "graviola")]
mod graviola_backend {
    use super::{DigestAlgorithm, DigestImplementation};

    /// A digest backend provided by graviola
    ///
    /// Usage:
    /// ```rust
    /// # let cert_store = rustls::RootCertStore::empty();
    /// #
    /// # let rustls_config = rustls::ClientConfig::builder()
    /// #     .with_root_certificates(cert_store)
    /// #     .with_no_client_auth();
    /// #
    /// use tokio_postgres_generic_rustls::{MakeRustlsConnect, GraviolaDigest};
    ///
    /// let tls = MakeRustlsConnect::new(rustls_config, GraviolaDigest);
    ///
    /// let connect_future = tokio_postgres::connect("postgres://username:password@localhost:5432/db", tls);
    /// ```
    #[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
    pub struct GraviolaDigest;

    impl DigestImplementation for GraviolaDigest {
        fn digest(&self, algorithm: DigestAlgorithm, bytes: &[u8]) -> Vec<u8> {
            use graviola::hashing::Hash;

            match algorithm {
                // Note: SHA1 is upgraded to SHA256 as per https://datatracker.ietf.org/doc/html/rfc5929#section-4.1
                DigestAlgorithm::Sha1 | DigestAlgorithm::Sha256 => {
                    graviola::hashing::Sha256::hash(bytes).as_ref().to_vec()
                }
                DigestAlgorithm::Sha384 => graviola::hashing::Sha384::hash(bytes).as_ref().to_vec(),
                DigestAlgorithm::Sha512 => graviola::hashing::Sha512::hash(bytes).as_ref().to_vec(),
            }
        }
    }
}
195
#[cfg(feature = "ring")]
mod ring_backend {
    use super::{DigestAlgorithm, DigestImplementation};

    /// Digest backend built on the `ring` crate.
    ///
    /// Usage:
    /// ```rust
    /// # let cert_store = rustls::RootCertStore::empty();
    /// #
    /// # let rustls_config = rustls::ClientConfig::builder()
    /// #     .with_root_certificates(cert_store)
    /// #     .with_no_client_auth();
    /// #
    /// use tokio_postgres_generic_rustls::{MakeRustlsConnect, RingDigest};
    ///
    /// let tls = MakeRustlsConnect::new(rustls_config, RingDigest);
    ///
    /// let connect_future = tokio_postgres::connect("postgres://username:password@localhost:5432/db", tls);
    /// ```
    #[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
    pub struct RingDigest;

    impl DigestImplementation for RingDigest {
        fn digest(&self, algorithm: DigestAlgorithm, bytes: &[u8]) -> Vec<u8> {
            // RFC 5929 §4.1: a SHA-1 certificate signature is bound with SHA-256 instead.
            let hash: &'static ring::digest::Algorithm = match algorithm {
                DigestAlgorithm::Sha512 => &ring::digest::SHA512,
                DigestAlgorithm::Sha384 => &ring::digest::SHA384,
                DigestAlgorithm::Sha1 | DigestAlgorithm::Sha256 => &ring::digest::SHA256,
            };

            let output = ring::digest::digest(hash, bytes);
            output.as_ref().to_vec()
        }
    }
}
232
#[cfg(feature = "rustcrypto")]
mod rustcrypto_backend {
    use super::{DigestAlgorithm, DigestImplementation};
    use sha2::{Digest, Sha256, Sha384, Sha512};

    /// A digest backend provided by RustCrypto's `sha2` crate
    ///
    /// Usage:
    /// ```rust
    /// # let cert_store = rustls::RootCertStore::empty();
    /// #
    /// # let rustls_config = rustls::ClientConfig::builder()
    /// #     .with_root_certificates(cert_store)
    /// #     .with_no_client_auth();
    /// #
    /// use tokio_postgres_generic_rustls::{MakeRustlsConnect, RustcryptoDigest};
    ///
    /// let tls = MakeRustlsConnect::new(rustls_config, RustcryptoDigest);
    ///
    /// let connect_future = tokio_postgres::connect("postgres://username:password@localhost:5432/db", tls);
    /// ```
    #[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
    pub struct RustcryptoDigest;

    impl DigestImplementation for RustcryptoDigest {
        fn digest(&self, algorithm: DigestAlgorithm, bytes: &[u8]) -> Vec<u8> {
            match algorithm {
                // Note: SHA1 is upgraded to SHA256 as per https://datatracker.ietf.org/doc/html/rfc5929#section-4.1
                DigestAlgorithm::Sha1 | DigestAlgorithm::Sha256 => {
                    Sha256::digest(bytes).as_slice().into()
                }
                DigestAlgorithm::Sha384 => Sha384::digest(bytes).as_slice().into(),
                DigestAlgorithm::Sha512 => Sha512::digest(bytes).as_slice().into(),
            }
        }
    }
}
269
270#[derive(Clone)]
271/// The primary interface for consumers of this crate
272///
273/// This type can be provided to tokio-postgres' `connect` method in order to add TLS to a postgres
274/// connection.
275/// ```rust
276/// # use tokio_postgres_generic_rustls::{DigestImplementation, DigestAlgorithm};
277/// #
278/// # #[derive(Clone)]
279/// # struct DemoDigest;
280/// #
281/// # impl DigestImplementation for DemoDigest {
282/// #     fn digest(&self, algorithm: DigestAlgorithm, bytes: &[u8]) -> Vec<u8> {
283/// #         todo!("digest it")
284/// #     }
285/// # }
286/// #
287/// # let cert_store = rustls::RootCertStore::empty();
288/// #
289/// # let rustls_config = rustls::ClientConfig::builder()
290/// #     .with_root_certificates(cert_store)
291/// #     .with_no_client_auth();
292/// #
293/// use tokio_postgres_generic_rustls::MakeRustlsConnect;
294///
295/// let tls = MakeRustlsConnect::new(rustls_config, DemoDigest);
296///
297/// let connect_future = tokio_postgres::connect("postgres://username:password@localhost:5432/db", tls);
298/// ```
299pub struct MakeRustlsConnect<D> {
300    config: Arc<ClientConfig>,
301    digest_impl: D,
302}
303
304impl<D> MakeRustlsConnect<D>
305where
306    D: DigestImplementation,
307{
308    /// Create a new MakeRustlsConnect instance
309    pub fn new(config: ClientConfig, digest_impl: D) -> Self {
310        Self {
311            config: Arc::new(config),
312            digest_impl,
313        }
314    }
315}
316
317impl<D, S> MakeTlsConnect<S> for MakeRustlsConnect<D>
318where
319    D: DigestImplementation + Clone + Unpin,
320    S: AsyncRead + AsyncWrite + Unpin + Send,
321{
322    type Stream = RustlsStream<D, S>;
323    type TlsConnect = RustlsConnect<D>;
324    type Error = rustls::pki_types::InvalidDnsNameError;
325
326    fn make_tls_connect(&mut self, domain: &str) -> Result<Self::TlsConnect, Self::Error> {
327        ServerName::try_from(domain).map(|dns_name| RustlsConnect {
328            dns_name: dns_name.to_owned(),
329            connector: Arc::clone(&self.config).into(),
330            digest_impl: self.digest_impl.clone(),
331        })
332    }
333}
334
335#[doc(hidden)]
336pub struct RustlsConnect<D> {
337    dns_name: ServerName<'static>,
338    connector: TlsConnector,
339    digest_impl: D,
340}
341
342#[doc(hidden)]
343pub struct ConnectFuture<D, S> {
344    connect: Connect<S>,
345    digest_impl: D,
346}
347
348impl<D, S> Future for ConnectFuture<D, S>
349where
350    D: DigestImplementation + Clone + Unpin,
351    S: AsyncRead + AsyncWrite + Unpin,
352{
353    type Output = std::io::Result<RustlsStream<D, S>>;
354
355    fn poll(
356        self: Pin<&mut Self>,
357        cx: &mut std::task::Context<'_>,
358    ) -> std::task::Poll<Self::Output> {
359        let this = self.get_mut();
360
361        let res = std::task::ready!(Pin::new(&mut this.connect).poll(cx));
362
363        std::task::Poll::Ready(res.map(|io| RustlsStream {
364            io,
365            digest_impl: this.digest_impl.clone(),
366        }))
367    }
368}
369
370impl<D, S> TlsConnect<S> for RustlsConnect<D>
371where
372    D: DigestImplementation + Clone + Unpin,
373    S: AsyncRead + AsyncWrite + Unpin + Send,
374{
375    type Stream = RustlsStream<D, S>;
376    type Error = std::io::Error;
377    type Future = ConnectFuture<D, S>;
378
379    fn connect(self, stream: S) -> Self::Future {
380        ConnectFuture {
381            connect: self.connector.connect(self.dns_name, stream),
382            digest_impl: self.digest_impl.clone(),
383        }
384    }
385}
386
387enum SignatureAlgorithm {
388    // 1.2.840.113549.1.1.5
389    Sha1Rsa,
390
391    // 1.2.840.113549.1.1.11
392    Sha256Rsa,
393
394    // 1.2.840.113549.1.1.12
395    Sha384Rsa,
396
397    // 1.2.840.113549.1.1.13
398    Sha512Rsa,
399
400    // 1.2.840.10045.4.3.2
401    EcdsaSha256,
402
403    // 1.2.840.10045.4.3.3
404    EcdsaSha384,
405}
406
407impl SignatureAlgorithm {
408    fn try_from_identifier(oid: &ObjectIdentifier) -> Option<Self> {
409        if oid == &SHA_1_WITH_RSA_ENCRYPTION {
410            Some(Self::Sha1Rsa)
411        } else if oid == &SHA_256_WITH_RSA_ENCRYPTION {
412            Some(Self::Sha256Rsa)
413        } else if oid == &SHA_384_WITH_RSA_ENCRYPTION {
414            Some(Self::Sha384Rsa)
415        } else if oid == &SHA_512_WITH_RSA_ENCRYPTION {
416            Some(Self::Sha512Rsa)
417        } else if oid == &ECDSA_WITH_SHA_256 {
418            Some(Self::EcdsaSha256)
419        } else if oid == &ECDSA_WITH_SHA_384 {
420            Some(Self::EcdsaSha384)
421        } else {
422            None
423        }
424    }
425
426    fn digest_algorithm(self) -> DigestAlgorithm {
427        match self {
428            Self::Sha1Rsa => DigestAlgorithm::Sha1,
429            Self::Sha256Rsa | Self::EcdsaSha256 => DigestAlgorithm::Sha256,
430            Self::Sha384Rsa | Self::EcdsaSha384 => DigestAlgorithm::Sha384,
431            Self::Sha512Rsa => DigestAlgorithm::Sha512,
432        }
433    }
434}
435
/// Digest algorithms that can be used in tls-server-end-point channel bindings.
///
/// This type is only useful in the context of defining a custom digest implementation, and
/// otherwise can be ignored.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum DigestAlgorithm {
    /// The provided certificate requests SHA-1 hashing. As per
    /// <https://datatracker.ietf.org/doc/html/rfc5929#section-4.1>, SHA-256 hashing should be
    /// used instead.
    Sha1,

    /// The provided certificate requests SHA-256 hashing.
    Sha256,

    /// The provided certificate requests SHA-384 hashing.
    Sha384,

    /// The provided certificate requests SHA-512 hashing.
    Sha512,
}
455
456#[doc(hidden)]
457pub struct RustlsStream<D, S> {
458    io: TlsStream<S>,
459    digest_impl: D,
460}
461
462impl<D, S> tokio_postgres::tls::TlsStream for RustlsStream<D, S>
463where
464    D: DigestImplementation + Unpin,
465    S: AsyncRead + AsyncWrite + Unpin,
466{
467    fn channel_binding(&self) -> tokio_postgres::tls::ChannelBinding {
468        let (_, session) = self.io.get_ref();
469
470        match session.peer_certificates() {
471            Some(certs) if !certs.is_empty() => Certificate::from_der(&certs[0])
472                .ok()
473                .and_then(|cert| {
474                    SignatureAlgorithm::try_from_identifier(&cert.signature_algorithm.oid)
475                })
476                .map(|signature_algorithm| {
477                    let digest = self
478                        .digest_impl
479                        .digest(signature_algorithm.digest_algorithm(), &certs[0]);
480
481                    ChannelBinding::tls_server_end_point(digest)
482                })
483                .unwrap_or_else(ChannelBinding::none),
484            _ => ChannelBinding::none(),
485        }
486    }
487}
488
489impl<D, S> AsyncRead for RustlsStream<D, S>
490where
491    D: Unpin,
492    S: AsyncRead + AsyncWrite + Unpin,
493{
494    fn poll_read(
495        self: Pin<&mut Self>,
496        cx: &mut std::task::Context<'_>,
497        buf: &mut tokio::io::ReadBuf<'_>,
498    ) -> std::task::Poll<std::io::Result<()>> {
499        Pin::new(&mut self.get_mut().io).poll_read(cx, buf)
500    }
501}
502
503impl<D, S> AsyncWrite for RustlsStream<D, S>
504where
505    D: Unpin,
506    S: AsyncRead + AsyncWrite + Unpin,
507{
508    fn poll_write(
509        self: Pin<&mut Self>,
510        cx: &mut std::task::Context<'_>,
511        buf: &[u8],
512    ) -> std::task::Poll<Result<usize, std::io::Error>> {
513        Pin::new(&mut self.get_mut().io).poll_write(cx, buf)
514    }
515
516    fn poll_flush(
517        self: Pin<&mut Self>,
518        cx: &mut std::task::Context<'_>,
519    ) -> std::task::Poll<Result<(), std::io::Error>> {
520        Pin::new(&mut self.get_mut().io).poll_flush(cx)
521    }
522
523    fn poll_shutdown(
524        self: Pin<&mut Self>,
525        cx: &mut std::task::Context<'_>,
526    ) -> std::task::Poll<Result<(), std::io::Error>> {
527        Pin::new(&mut self.get_mut().io).poll_shutdown(cx)
528    }
529
530    fn is_write_vectored(&self) -> bool {
531        self.io.is_write_vectored()
532    }
533
534    fn poll_write_vectored(
535        self: Pin<&mut Self>,
536        cx: &mut std::task::Context<'_>,
537        bufs: &[std::io::IoSlice<'_>],
538    ) -> std::task::Poll<Result<usize, std::io::Error>> {
539        Pin::new(&mut self.get_mut().io).poll_write_vectored(cx, bufs)
540    }
541}