tokio_postgres_generic_rustls/
lib.rs

1//! # tokio-postgres-generic-rustls
//! _An implementation of TLS based on rustls for tokio-postgres_
3//!
//! This crate allows users to select a crypto backend, or bring their own, rather than relying on
//! primitives provided by `ring` directly. This is done by using the `x509-cert` crate for
//! certificate parsing instead of `x509-certificate`, and by adding an abstraction for computing
//! digests.
8//!
9//! By default, tokio-postgres-generic-rustls does not provide a digest implementation, but one or
10//! more are provided behind crate features.
11//!
//! | Feature      | Implementation     |
13//! | ------------ | ------------------ |
14//! | `aws-lc-rs`  | `AwsLcRsDigest`    |
15//! | `ring`       | `RingDigest`       |
16//! | `rustcrypto` | `RustcryptoDigest` |
17//!
18//! ## Usage
//! Using this crate is fairly straightforward. First, select your digest implementation via crate
//! features (or provide your own), then construct a rustls connector for tokio-postgres with your
//! rustls client configuration.
22//!
23//! The following example demonstrates providing a custom digest backend.
24//!
25//! ```rust
26//! use tokio_postgres_generic_rustls::{DigestImplementation, DigestAlgorithm, MakeRustlsConnect};
27//!
28//! #[derive(Clone)]
29//! struct DemoDigest;
30//!
31//! impl DigestImplementation for DemoDigest {
32//!     fn digest(&self, algorithm: DigestAlgorithm, bytes: &[u8]) -> Vec<u8> {
33//!         todo!("digest it")
34//!     }
35//! }
36//!
37//! let cert_store = rustls::RootCertStore::empty();
38//!
39//! let config = rustls::ClientConfig::builder()
40//!     .with_root_certificates(cert_store)
41//!     .with_no_client_auth();
42//!
43//! let tls = MakeRustlsConnect::new(config, DemoDigest);
44//!
45//! let connect_future = tokio_postgres::connect("postgres://username:password@localhost:5432/db", tls);
46//!
47//! // connect_future.await;
48//! ```
49//!
50//! ## License
51//! This project is licensed under either of
52//!
53//! - Apache License, Version 2.0, (LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0)
54//! - MIT license (LICENSE-MIT or http://opensource.org/licenses/MIT)
55//!
56//! at your option.
57
58#![deny(unsafe_code)]
59#![deny(missing_docs)]
60
61use std::{future::Future, pin::Pin, sync::Arc};
62
63use rustls::{pki_types::ServerName, ClientConfig};
64use tokio::io::{AsyncRead, AsyncWrite};
65use tokio_postgres::tls::{ChannelBinding, MakeTlsConnect, TlsConnect};
66use tokio_rustls::{client::TlsStream, Connect, TlsConnector};
67use x509_cert::{
68    der::{
69        oid::db::rfc5912::{
70            ECDSA_WITH_SHA_256, ECDSA_WITH_SHA_384, SHA_1_WITH_RSA_ENCRYPTION,
71            SHA_256_WITH_RSA_ENCRYPTION, SHA_384_WITH_RSA_ENCRYPTION, SHA_512_WITH_RSA_ENCRYPTION,
72        },
73        Decode,
74    },
75    spki::ObjectIdentifier,
76    Certificate,
77};
78
79/// Trait used to provide a custom digest backend to tokio_postgres_generic_rustls
80///
81/// This trait is implementated for three types by default, each behind it's own feature flag. The
82/// provided backends are aws-lc-rs, ring, and rustcrypto.
83pub trait DigestImplementation {
84    /// Hash the provided bytes with the provided algorithm
85    ///
86    /// ```rust
87    /// use tokio_postgres_generic_rustls::{DigestImplementation, DigestAlgorithm};
88    ///
89    /// struct CustomAwsLcRsDigest;
90    ///
91    /// impl DigestImplementation for CustomAwsLcRsDigest {
92    ///     fn digest(&self, digest_algorithm: DigestAlgorithm, bytes: &[u8]) -> Vec<u8> {
93    ///         let digest_alg = match digest_algorithm {
94    ///             // Note: SHA1 is upgraded to SHA256 as per https://datatracker.ietf.org/doc/html/rfc5929#section-4.1
95    ///             DigestAlgorithm::Sha1 | DigestAlgorithm::Sha256 => &aws_lc_rs::digest::SHA256,
96    ///             DigestAlgorithm::Sha384 => &aws_lc_rs::digest::SHA384,
97    ///             DigestAlgorithm::Sha512 => &aws_lc_rs::digest::SHA512,
98    ///         };
99    ///
100    ///         aws_lc_rs::digest::digest(digest_alg, bytes).as_ref().into()
101    ///     }
102    /// }
103    /// ```
104    fn digest(&self, algorithm: DigestAlgorithm, bytes: &[u8]) -> Vec<u8>;
105}
106
107#[cfg(feature = "aws-lc-rs")]
108pub use aws_lc_rs_backend::AwsLcRsDigest;
109
110#[cfg(feature = "ring")]
111pub use ring_backend::RingDigest;
112
113#[cfg(feature = "rustcrypto")]
114pub use rustcrypto_backend::RustcryptoDigest;
115
#[cfg(feature = "aws-lc-rs")]
mod aws_lc_rs_backend {
    use super::{DigestAlgorithm, DigestImplementation};

    /// A digest backend provided by aws-lc-rs.
    ///
    /// Usage:
    /// ```rust
    /// # let cert_store = rustls::RootCertStore::empty();
    /// #
    /// # let rustls_config = rustls::ClientConfig::builder()
    /// #     .with_root_certificates(cert_store)
    /// #     .with_no_client_auth();
    /// #
    /// use tokio_postgres_generic_rustls::{MakeRustlsConnect, AwsLcRsDigest};
    ///
    /// let tls = MakeRustlsConnect::new(rustls_config, AwsLcRsDigest);
    ///
    /// let connect_future = tokio_postgres::connect("postgres://username:password@localhost:5432/db", tls);
    /// ```
    #[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
    pub struct AwsLcRsDigest;

    impl DigestImplementation for AwsLcRsDigest {
        fn digest(&self, algorithm: DigestAlgorithm, bytes: &[u8]) -> Vec<u8> {
            use aws_lc_rs::digest::{digest, SHA256, SHA384, SHA512};

            // RFC 5929 section 4.1: a SHA-1 signature is bound with SHA-256 instead.
            let hash = match algorithm {
                DigestAlgorithm::Sha1 | DigestAlgorithm::Sha256 => &SHA256,
                DigestAlgorithm::Sha384 => &SHA384,
                DigestAlgorithm::Sha512 => &SHA512,
            };

            digest(hash, bytes).as_ref().to_vec()
        }
    }
}
152
#[cfg(feature = "ring")]
mod ring_backend {
    use super::{DigestAlgorithm, DigestImplementation};

    /// A digest backend provided by ring.
    ///
    /// Usage:
    /// ```rust
    /// # let cert_store = rustls::RootCertStore::empty();
    /// #
    /// # let rustls_config = rustls::ClientConfig::builder()
    /// #     .with_root_certificates(cert_store)
    /// #     .with_no_client_auth();
    /// #
    /// use tokio_postgres_generic_rustls::{MakeRustlsConnect, RingDigest};
    ///
    /// let tls = MakeRustlsConnect::new(rustls_config, RingDigest);
    ///
    /// let connect_future = tokio_postgres::connect("postgres://username:password@localhost:5432/db", tls);
    /// ```
    #[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
    pub struct RingDigest;

    impl DigestImplementation for RingDigest {
        fn digest(&self, algorithm: DigestAlgorithm, bytes: &[u8]) -> Vec<u8> {
            use ring::digest::{digest, SHA256, SHA384, SHA512};

            // RFC 5929 section 4.1: a SHA-1 signature is bound with SHA-256 instead.
            let hash = match algorithm {
                DigestAlgorithm::Sha1 | DigestAlgorithm::Sha256 => &SHA256,
                DigestAlgorithm::Sha384 => &SHA384,
                DigestAlgorithm::Sha512 => &SHA512,
            };

            digest(hash, bytes).as_ref().to_vec()
        }
    }
}
189
#[cfg(feature = "rustcrypto")]
mod rustcrypto_backend {
    use super::{DigestAlgorithm, DigestImplementation};
    use sha2::{Digest, Sha256, Sha384, Sha512};

    /// A digest backend provided by rustcrypto (the `sha2` crate).
    ///
    /// Usage:
    /// ```rust
    /// # let cert_store = rustls::RootCertStore::empty();
    /// #
    /// # let rustls_config = rustls::ClientConfig::builder()
    /// #     .with_root_certificates(cert_store)
    /// #     .with_no_client_auth();
    /// #
    /// use tokio_postgres_generic_rustls::{MakeRustlsConnect, RustcryptoDigest};
    ///
    /// let tls = MakeRustlsConnect::new(rustls_config, RustcryptoDigest);
    ///
    /// let connect_future = tokio_postgres::connect("postgres://username:password@localhost:5432/db", tls);
    /// ```
    #[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
    pub struct RustcryptoDigest;

    impl DigestImplementation for RustcryptoDigest {
        fn digest(&self, algorithm: DigestAlgorithm, bytes: &[u8]) -> Vec<u8> {
            match algorithm {
                // Note: SHA1 is upgraded to SHA256 as per https://datatracker.ietf.org/doc/html/rfc5929#section-4.1
                DigestAlgorithm::Sha1 | DigestAlgorithm::Sha256 => {
                    Sha256::digest(bytes).as_slice().into()
                }
                DigestAlgorithm::Sha384 => Sha384::digest(bytes).as_slice().into(),
                DigestAlgorithm::Sha512 => Sha512::digest(bytes).as_slice().into(),
            }
        }
    }
}
226
227#[derive(Clone)]
228/// The primary interface for consumers of this crate
229///
230/// This type can be provided to tokio-postgres' `connect` method in order to add TLS to a postgres
231/// connection.
232/// ```rust
233/// # use tokio_postgres_generic_rustls::{DigestImplementation, DigestAlgorithm};
234/// #
235/// # #[derive(Clone)]
236/// # struct DemoDigest;
237/// #
238/// # impl DigestImplementation for DemoDigest {
239/// #     fn digest(&self, algorithm: DigestAlgorithm, bytes: &[u8]) -> Vec<u8> {
240/// #         todo!("digest it")
241/// #     }
242/// # }
243/// #
244/// # let cert_store = rustls::RootCertStore::empty();
245/// #
246/// # let rustls_config = rustls::ClientConfig::builder()
247/// #     .with_root_certificates(cert_store)
248/// #     .with_no_client_auth();
249/// #
250/// use tokio_postgres_generic_rustls::MakeRustlsConnect;
251///
252/// let tls = MakeRustlsConnect::new(rustls_config, DemoDigest);
253///
254/// let connect_future = tokio_postgres::connect("postgres://username:password@localhost:5432/db", tls);
255/// ```
256pub struct MakeRustlsConnect<D> {
257    config: Arc<ClientConfig>,
258    digest_impl: D,
259}
260
261impl<D> MakeRustlsConnect<D>
262where
263    D: DigestImplementation,
264{
265    /// Create a new MakeRustlsConnect instance
266    pub fn new(config: ClientConfig, digest_impl: D) -> Self {
267        Self {
268            config: Arc::new(config),
269            digest_impl,
270        }
271    }
272}
273
274impl<D, S> MakeTlsConnect<S> for MakeRustlsConnect<D>
275where
276    D: DigestImplementation + Clone + Unpin,
277    S: AsyncRead + AsyncWrite + Unpin + Send,
278{
279    type Stream = RustlsStream<D, S>;
280    type TlsConnect = RustlsConnect<D>;
281    type Error = rustls::pki_types::InvalidDnsNameError;
282
283    fn make_tls_connect(&mut self, domain: &str) -> Result<Self::TlsConnect, Self::Error> {
284        ServerName::try_from(domain).map(|dns_name| RustlsConnect {
285            dns_name: dns_name.to_owned(),
286            connector: Arc::clone(&self.config).into(),
287            digest_impl: self.digest_impl.clone(),
288        })
289    }
290}
291
292#[doc(hidden)]
293pub struct RustlsConnect<D> {
294    dns_name: ServerName<'static>,
295    connector: TlsConnector,
296    digest_impl: D,
297}
298
299#[doc(hidden)]
300pub struct ConnectFuture<D, S> {
301    connect: Connect<S>,
302    digest_impl: D,
303}
304
305impl<D, S> Future for ConnectFuture<D, S>
306where
307    D: DigestImplementation + Clone + Unpin,
308    S: AsyncRead + AsyncWrite + Unpin,
309{
310    type Output = std::io::Result<RustlsStream<D, S>>;
311
312    fn poll(
313        self: Pin<&mut Self>,
314        cx: &mut std::task::Context<'_>,
315    ) -> std::task::Poll<Self::Output> {
316        let this = self.get_mut();
317
318        let res = std::task::ready!(Pin::new(&mut this.connect).poll(cx));
319
320        std::task::Poll::Ready(res.map(|io| RustlsStream {
321            io,
322            digest_impl: this.digest_impl.clone(),
323        }))
324    }
325}
326
327impl<D, S> TlsConnect<S> for RustlsConnect<D>
328where
329    D: DigestImplementation + Clone + Unpin,
330    S: AsyncRead + AsyncWrite + Unpin + Send,
331{
332    type Stream = RustlsStream<D, S>;
333    type Error = std::io::Error;
334    type Future = ConnectFuture<D, S>;
335
336    fn connect(self, stream: S) -> Self::Future {
337        ConnectFuture {
338            connect: self.connector.connect(self.dns_name, stream),
339            digest_impl: self.digest_impl.clone(),
340        }
341    }
342}
343
344enum SignatureAlgorithm {
345    // 1.2.840.113549.1.1.5
346    Sha1Rsa,
347
348    // 1.2.840.113549.1.1.11
349    Sha256Rsa,
350
351    // 1.2.840.113549.1.1.12
352    Sha384Rsa,
353
354    // 1.2.840.113549.1.1.13
355    Sha512Rsa,
356
357    // 1.2.840.10045.4.3.2
358    EcdsaSha256,
359
360    // 1.2.840.10045.4.3.3
361    EcdsaSha384,
362}
363
364impl SignatureAlgorithm {
365    fn try_from_identifier(oid: &ObjectIdentifier) -> Option<Self> {
366        if oid == &SHA_1_WITH_RSA_ENCRYPTION {
367            Some(Self::Sha1Rsa)
368        } else if oid == &SHA_256_WITH_RSA_ENCRYPTION {
369            Some(Self::Sha256Rsa)
370        } else if oid == &SHA_384_WITH_RSA_ENCRYPTION {
371            Some(Self::Sha384Rsa)
372        } else if oid == &SHA_512_WITH_RSA_ENCRYPTION {
373            Some(Self::Sha512Rsa)
374        } else if oid == &ECDSA_WITH_SHA_256 {
375            Some(Self::EcdsaSha256)
376        } else if oid == &ECDSA_WITH_SHA_384 {
377            Some(Self::EcdsaSha384)
378        } else {
379            None
380        }
381    }
382
383    fn digest_algorithm(self) -> DigestAlgorithm {
384        match self {
385            Self::Sha1Rsa => DigestAlgorithm::Sha1,
386            Self::Sha256Rsa | Self::EcdsaSha256 => DigestAlgorithm::Sha256,
387            Self::Sha384Rsa | Self::EcdsaSha384 => DigestAlgorithm::Sha384,
388            Self::Sha512Rsa => DigestAlgorithm::Sha512,
389        }
390    }
391}
392
/// Digest algorithms that can be used in `tls-server-end-point` channel bindings.
///
/// Each variant names the hash function of the server certificate's signature algorithm. This
/// type is only useful in the context of defining a custom [`DigestImplementation`], and can
/// otherwise be ignored.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum DigestAlgorithm {
    /// The provided certificate requests SHA-1 hashing. As per
    /// <https://datatracker.ietf.org/doc/html/rfc5929#section-4.1>, SHA-256 hashing should be
    /// used instead.
    Sha1,

    /// The provided certificate requests SHA-256 hashing.
    Sha256,

    /// The provided certificate requests SHA-384 hashing.
    Sha384,

    /// The provided certificate requests SHA-512 hashing.
    Sha512,
}
412
413#[doc(hidden)]
414pub struct RustlsStream<D, S> {
415    io: TlsStream<S>,
416    digest_impl: D,
417}
418
419impl<D, S> tokio_postgres::tls::TlsStream for RustlsStream<D, S>
420where
421    D: DigestImplementation + Unpin,
422    S: AsyncRead + AsyncWrite + Unpin,
423{
424    fn channel_binding(&self) -> tokio_postgres::tls::ChannelBinding {
425        let (_, session) = self.io.get_ref();
426
427        match session.peer_certificates() {
428            Some(certs) if !certs.is_empty() => Certificate::from_der(&certs[0])
429                .ok()
430                .and_then(|cert| {
431                    SignatureAlgorithm::try_from_identifier(&cert.signature_algorithm.oid)
432                })
433                .map(|signature_algorithm| {
434                    let digest = self
435                        .digest_impl
436                        .digest(signature_algorithm.digest_algorithm(), &certs[0]);
437
438                    ChannelBinding::tls_server_end_point(digest)
439                })
440                .unwrap_or_else(ChannelBinding::none),
441            _ => ChannelBinding::none(),
442        }
443    }
444}
445
446impl<D, S> AsyncRead for RustlsStream<D, S>
447where
448    D: Unpin,
449    S: AsyncRead + AsyncWrite + Unpin,
450{
451    fn poll_read(
452        self: Pin<&mut Self>,
453        cx: &mut std::task::Context<'_>,
454        buf: &mut tokio::io::ReadBuf<'_>,
455    ) -> std::task::Poll<std::io::Result<()>> {
456        Pin::new(&mut self.get_mut().io).poll_read(cx, buf)
457    }
458}
459
460impl<D, S> AsyncWrite for RustlsStream<D, S>
461where
462    D: Unpin,
463    S: AsyncRead + AsyncWrite + Unpin,
464{
465    fn poll_write(
466        self: Pin<&mut Self>,
467        cx: &mut std::task::Context<'_>,
468        buf: &[u8],
469    ) -> std::task::Poll<Result<usize, std::io::Error>> {
470        Pin::new(&mut self.get_mut().io).poll_write(cx, buf)
471    }
472
473    fn poll_flush(
474        self: Pin<&mut Self>,
475        cx: &mut std::task::Context<'_>,
476    ) -> std::task::Poll<Result<(), std::io::Error>> {
477        Pin::new(&mut self.get_mut().io).poll_flush(cx)
478    }
479
480    fn poll_shutdown(
481        self: Pin<&mut Self>,
482        cx: &mut std::task::Context<'_>,
483    ) -> std::task::Poll<Result<(), std::io::Error>> {
484        Pin::new(&mut self.get_mut().io).poll_shutdown(cx)
485    }
486
487    fn is_write_vectored(&self) -> bool {
488        self.io.is_write_vectored()
489    }
490
491    fn poll_write_vectored(
492        self: Pin<&mut Self>,
493        cx: &mut std::task::Context<'_>,
494        bufs: &[std::io::IoSlice<'_>],
495    ) -> std::task::Poll<Result<usize, std::io::Error>> {
496        Pin::new(&mut self.get_mut().io).poll_write_vectored(cx, bufs)
497    }
498}