tokio_postgres_rustls_improved/
lib.rs

1#![cfg_attr(docsrs, feature(doc_cfg))]
2#![doc = include_str!("../README.md")]
3#![forbid(rust_2018_idioms)]
4#![forbid(missing_docs, unsafe_code, unused)]
5#![deny(
6    clippy::all,
7    clippy::pedantic,
8    clippy::unwrap_used,
9    clippy::expect_used,
10    clippy::nursery,
11    clippy::dbg_macro,
12    clippy::todo
13)]
14
15use std::{convert::TryFrom, sync::Arc};
16
17use rustls::{ClientConfig, pki_types::ServerName};
18use tokio::io::{AsyncRead, AsyncWrite};
19use tokio_postgres::tls::MakeTlsConnect;
20
21#[cfg(feature = "config-stream")]
22mod dynamic_config;
23#[cfg(feature = "config-stream")]
24#[cfg_attr(docsrs, doc(cfg(feature = "config-stream")))]
25pub use dynamic_config::MakeDynamicRustlsConnect;
26
27mod private {
28    use std::{
29        future::Future,
30        io,
31        pin::Pin,
32        task::{Context, Poll},
33    };
34
35    use rustls::pki_types::ServerName;
36    use sha2::digest::const_oid::db::rfc5912::{
37        ECDSA_WITH_SHA_256, ECDSA_WITH_SHA_384, ECDSA_WITH_SHA_512, ID_MD_5, ID_SHA_1, ID_SHA_256,
38        ID_SHA_384, ID_SHA_512, SHA_1_WITH_RSA_ENCRYPTION, SHA_256_WITH_RSA_ENCRYPTION,
39        SHA_384_WITH_RSA_ENCRYPTION, SHA_512_WITH_RSA_ENCRYPTION,
40    };
41    use sha2::{Digest, Sha256, Sha384, Sha512};
42    use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
43    use tokio_postgres::tls::{ChannelBinding, TlsConnect};
44    use tokio_rustls::{TlsConnector, client::TlsStream};
45    use x509_cert::{Certificate, der::Decode};
46
47    pub struct TlsConnectFuture<S> {
48        inner: tokio_rustls::Connect<S>,
49    }
50
51    impl<S> Future for TlsConnectFuture<S>
52    where
53        S: AsyncRead + AsyncWrite + Unpin,
54    {
55        type Output = io::Result<RustlsStream<S>>;
56
57        fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
58            Pin::new(&mut self.inner).poll(cx).map_ok(RustlsStream)
59        }
60    }
61
62    pub struct RustlsConnect(pub RustlsConnectData);
63
64    pub struct RustlsConnectData {
65        pub hostname: ServerName<'static>,
66        pub connector: TlsConnector,
67    }
68
69    impl<S> TlsConnect<S> for RustlsConnect
70    where
71        S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
72    {
73        type Stream = RustlsStream<S>;
74        type Error = io::Error;
75        type Future = TlsConnectFuture<S>;
76
77        fn connect(self, stream: S) -> Self::Future {
78            TlsConnectFuture {
79                inner: self.0.connector.connect(self.0.hostname, stream),
80            }
81        }
82    }
83
84    pub struct RustlsStream<S>(TlsStream<S>);
85
86    impl<S> tokio_postgres::tls::TlsStream for RustlsStream<S>
87    where
88        S: AsyncRead + AsyncWrite + Unpin,
89    {
90        fn channel_binding(&self) -> ChannelBinding {
91            let (_, session) = self.0.get_ref();
92            match session.peer_certificates() {
93                Some(certs) if !certs.is_empty() => Certificate::from_der(&certs[0]).map_or_else(
94                    |_| ChannelBinding::none(),
95                    |cert| {
96                        match cert.signature_algorithm.oid {
97                            // Note: MD5 and SHA1 are upgraded to SHA256 as per https://datatracker.ietf.org/doc/html/rfc5929#section-4.1
98                            ID_MD_5
99                            | ID_SHA_1
100                            | ID_SHA_256
101                            | SHA_1_WITH_RSA_ENCRYPTION
102                            | SHA_256_WITH_RSA_ENCRYPTION
103                            | ECDSA_WITH_SHA_256 => ChannelBinding::tls_server_end_point(
104                                Sha256::digest(certs[0].as_ref()).to_vec(),
105                            ),
106                            ID_SHA_384 | SHA_384_WITH_RSA_ENCRYPTION | ECDSA_WITH_SHA_384 => {
107                                ChannelBinding::tls_server_end_point(
108                                    Sha384::digest(certs[0].as_ref()).to_vec(),
109                                )
110                            }
111                            ID_SHA_512 | SHA_512_WITH_RSA_ENCRYPTION | ECDSA_WITH_SHA_512 => {
112                                ChannelBinding::tls_server_end_point(
113                                    Sha512::digest(certs[0].as_ref()).to_vec(),
114                                )
115                            }
116                            _ => ChannelBinding::none(),
117                        }
118                    },
119                ),
120                _ => ChannelBinding::none(),
121            }
122        }
123    }
124
125    impl<S> AsyncRead for RustlsStream<S>
126    where
127        S: AsyncRead + AsyncWrite + Unpin,
128    {
129        fn poll_read(
130            mut self: Pin<&mut Self>,
131            cx: &mut Context<'_>,
132            buf: &mut ReadBuf<'_>,
133        ) -> Poll<tokio::io::Result<()>> {
134            Pin::new(&mut self.0).poll_read(cx, buf)
135        }
136    }
137
138    impl<S> AsyncWrite for RustlsStream<S>
139    where
140        S: AsyncRead + AsyncWrite + Unpin,
141    {
142        fn poll_write(
143            mut self: Pin<&mut Self>,
144            cx: &mut Context<'_>,
145            buf: &[u8],
146        ) -> Poll<tokio::io::Result<usize>> {
147            Pin::new(&mut self.0).poll_write(cx, buf)
148        }
149
150        fn poll_flush(
151            mut self: Pin<&mut Self>,
152            cx: &mut Context<'_>,
153        ) -> Poll<tokio::io::Result<()>> {
154            Pin::new(&mut self.0).poll_flush(cx)
155        }
156
157        fn poll_shutdown(
158            mut self: Pin<&mut Self>,
159            cx: &mut Context<'_>,
160        ) -> Poll<tokio::io::Result<()>> {
161            Pin::new(&mut self.0).poll_shutdown(cx)
162        }
163    }
164}
165
166/// A [`MakeTlsConnect`] implementation backed by [`rustls`].
167///
168/// This type allows you to establish PostgreSQL connections using `rustls` as the TLS provider,
169/// instead of relying on system-native TLS stacks. It wraps an [`Arc<ClientConfig>`] so that
170/// the TLS configuration can be cheaply cloned and reused across multiple connections.
171#[derive(Clone)]
172pub struct MakeRustlsConnect {
173    config: Arc<ClientConfig>,
174}
175
176impl MakeRustlsConnect {
177    /// Creates a new `MakeRustlsConnect` from the provided `ClientConfig`.
178    #[must_use]
179    pub fn new(config: ClientConfig) -> Self {
180        Self {
181            config: Arc::new(config),
182        }
183    }
184}
185
186impl<S> MakeTlsConnect<S> for MakeRustlsConnect
187where
188    S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
189{
190    type Stream = private::RustlsStream<S>;
191    type TlsConnect = private::RustlsConnect;
192    type Error = rustls::pki_types::InvalidDnsNameError;
193
194    /// Creates a new [`MakeRustlsConnect`] from the given [`ClientConfig`].
195    ///
196    /// The configuration is stored inside an [`Arc`], so the returned
197    /// connector can be cloned and shared across tasks or connections
198    /// without duplicating the underlying TLS state.
199    ///
200    /// # Parameters
201    ///
202    /// - `config`: The `rustls` client configuration that determines how TLS
203    ///   handshakes are performed (e.g. certificates, ciphers, root stores).
204    ///
205    /// # Returns
206    ///
207    /// A ready-to-use [`MakeRustlsConnect`] instance.
208    fn make_tls_connect(&mut self, hostname: &str) -> Result<Self::TlsConnect, Self::Error> {
209        ServerName::try_from(hostname).map(|dns_name| {
210            private::RustlsConnect(private::RustlsConnectData {
211                hostname: dns_name.to_owned(),
212                connector: Arc::clone(&self.config).into(),
213            })
214        })
215    }
216}