tokio_postgres_rustls_improved/
lib.rs

1#![doc = include_str!("../README.md")]
2#![forbid(rust_2018_idioms)]
3#![forbid(missing_docs, unsafe_code, unused)]
4#![deny(
5    clippy::all,
6    clippy::pedantic,
7    clippy::unwrap_used,
8    clippy::expect_used,
9    clippy::nursery,
10    clippy::dbg_macro,
11    clippy::todo
12)]
13
14use std::{convert::TryFrom, sync::Arc};
15
16use rustls::{ClientConfig, pki_types::ServerName};
17use tokio::io::{AsyncRead, AsyncWrite};
18use tokio_postgres::tls::MakeTlsConnect;
19
20mod private {
21    use std::{
22        future::Future,
23        io,
24        pin::Pin,
25        task::{Context, Poll},
26    };
27
28    use rustls::pki_types::ServerName;
29    use sha2::digest::const_oid::db::rfc5912::{
30        ECDSA_WITH_SHA_256, ECDSA_WITH_SHA_384, ECDSA_WITH_SHA_512, ID_MD_5, ID_SHA_1, ID_SHA_256,
31        ID_SHA_384, ID_SHA_512, SHA_1_WITH_RSA_ENCRYPTION, SHA_256_WITH_RSA_ENCRYPTION,
32        SHA_384_WITH_RSA_ENCRYPTION, SHA_512_WITH_RSA_ENCRYPTION,
33    };
34    use sha2::{Digest, Sha256, Sha384, Sha512};
35    use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
36    use tokio_postgres::tls::{ChannelBinding, TlsConnect};
37    use tokio_rustls::{TlsConnector, client::TlsStream};
38    use x509_cert::{Certificate, der::Decode};
39
40    pub struct TlsConnectFuture<S> {
41        inner: tokio_rustls::Connect<S>,
42    }
43
44    impl<S> Future for TlsConnectFuture<S>
45    where
46        S: AsyncRead + AsyncWrite + Unpin,
47    {
48        type Output = io::Result<RustlsStream<S>>;
49
50        fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
51            Pin::new(&mut self.inner).poll(cx).map_ok(RustlsStream)
52        }
53    }
54
55    pub struct RustlsConnect(pub RustlsConnectData);
56
57    pub struct RustlsConnectData {
58        pub hostname: ServerName<'static>,
59        pub connector: TlsConnector,
60    }
61
62    impl<S> TlsConnect<S> for RustlsConnect
63    where
64        S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
65    {
66        type Stream = RustlsStream<S>;
67        type Error = io::Error;
68        type Future = TlsConnectFuture<S>;
69
70        fn connect(self, stream: S) -> Self::Future {
71            TlsConnectFuture {
72                inner: self.0.connector.connect(self.0.hostname, stream),
73            }
74        }
75    }
76
77    pub struct RustlsStream<S>(TlsStream<S>);
78
79    impl<S> tokio_postgres::tls::TlsStream for RustlsStream<S>
80    where
81        S: AsyncRead + AsyncWrite + Unpin,
82    {
83        fn channel_binding(&self) -> ChannelBinding {
84            let (_, session) = self.0.get_ref();
85            match session.peer_certificates() {
86                Some(certs) if !certs.is_empty() => Certificate::from_der(&certs[0]).map_or_else(
87                    |_| ChannelBinding::none(),
88                    |cert| {
89                        match cert.signature_algorithm.oid {
90                            // Note: MD5 and SHA1 are upgraded to SHA256 as per https://datatracker.ietf.org/doc/html/rfc5929#section-4.1
91                            ID_MD_5
92                            | ID_SHA_1
93                            | ID_SHA_256
94                            | SHA_1_WITH_RSA_ENCRYPTION
95                            | SHA_256_WITH_RSA_ENCRYPTION
96                            | ECDSA_WITH_SHA_256 => ChannelBinding::tls_server_end_point(
97                                Sha256::digest(certs[0].as_ref()).to_vec(),
98                            ),
99                            ID_SHA_384 | SHA_384_WITH_RSA_ENCRYPTION | ECDSA_WITH_SHA_384 => {
100                                ChannelBinding::tls_server_end_point(
101                                    Sha384::digest(certs[0].as_ref()).to_vec(),
102                                )
103                            }
104                            ID_SHA_512 | SHA_512_WITH_RSA_ENCRYPTION | ECDSA_WITH_SHA_512 => {
105                                ChannelBinding::tls_server_end_point(
106                                    Sha512::digest(certs[0].as_ref()).to_vec(),
107                                )
108                            }
109                            _ => ChannelBinding::none(),
110                        }
111                    },
112                ),
113                _ => ChannelBinding::none(),
114            }
115        }
116    }
117
118    impl<S> AsyncRead for RustlsStream<S>
119    where
120        S: AsyncRead + AsyncWrite + Unpin,
121    {
122        fn poll_read(
123            mut self: Pin<&mut Self>,
124            cx: &mut Context<'_>,
125            buf: &mut ReadBuf<'_>,
126        ) -> Poll<tokio::io::Result<()>> {
127            Pin::new(&mut self.0).poll_read(cx, buf)
128        }
129    }
130
131    impl<S> AsyncWrite for RustlsStream<S>
132    where
133        S: AsyncRead + AsyncWrite + Unpin,
134    {
135        fn poll_write(
136            mut self: Pin<&mut Self>,
137            cx: &mut Context<'_>,
138            buf: &[u8],
139        ) -> Poll<tokio::io::Result<usize>> {
140            Pin::new(&mut self.0).poll_write(cx, buf)
141        }
142
143        fn poll_flush(
144            mut self: Pin<&mut Self>,
145            cx: &mut Context<'_>,
146        ) -> Poll<tokio::io::Result<()>> {
147            Pin::new(&mut self.0).poll_flush(cx)
148        }
149
150        fn poll_shutdown(
151            mut self: Pin<&mut Self>,
152            cx: &mut Context<'_>,
153        ) -> Poll<tokio::io::Result<()>> {
154            Pin::new(&mut self.0).poll_shutdown(cx)
155        }
156    }
157}
158
159/// A [`MakeTlsConnect`](tokio_postgres::tls::MakeTlsConnect) implementation backed by [`rustls`].
160///
161/// This type allows you to establish PostgreSQL connections using `rustls` as the TLS provider,
162/// instead of relying on system-native TLS stacks. It wraps an [`Arc<ClientConfig>`] so that
163/// the TLS configuration can be cheaply cloned and reused across multiple connections.
164#[derive(Clone)]
165pub struct MakeRustlsConnect {
166    config: Arc<ClientConfig>,
167}
168
169impl MakeRustlsConnect {
170    /// Creates a new `MakeRustlsConnect` from the provided `ClientConfig`.
171    #[must_use]
172    pub fn new(config: ClientConfig) -> Self {
173        Self {
174            config: Arc::new(config),
175        }
176    }
177}
178
179impl<S> MakeTlsConnect<S> for MakeRustlsConnect
180where
181    S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
182{
183    type Stream = private::RustlsStream<S>;
184    type TlsConnect = private::RustlsConnect;
185    type Error = rustls::pki_types::InvalidDnsNameError;
186
187    /// Creates a new [`MakeRustlsConnect`] from the given [`ClientConfig`].
188    ///
189    /// The configuration is stored inside an [`Arc`], so the returned
190    /// connector can be cloned and shared across tasks or connections
191    /// without duplicating the underlying TLS state.
192    ///
193    /// # Parameters
194    ///
195    /// - `config`: The `rustls` client configuration that determines how TLS
196    ///   handshakes are performed (e.g. certificates, ciphers, root stores).
197    ///
198    /// # Returns
199    ///
200    /// A ready-to-use [`MakeRustlsConnect`] instance.
201    fn make_tls_connect(&mut self, hostname: &str) -> Result<Self::TlsConnect, Self::Error> {
202        ServerName::try_from(hostname).map(|dns_name| {
203            private::RustlsConnect(private::RustlsConnectData {
204                hostname: dns_name.to_owned(),
205                connector: Arc::clone(&self.config).into(),
206            })
207        })
208    }
209}