tokio_postgres_rustls_improved/
lib.rs

1#![doc = include_str!("../README.md")]
2#![forbid(rust_2018_idioms)]
3#![forbid(missing_docs, unsafe_code, unused)]
4#![deny(
5    clippy::all,
6    clippy::pedantic,
7    clippy::unwrap_used,
8    clippy::expect_used,
9    clippy::nursery,
10    clippy::dbg_macro,
11    clippy::todo
12)]
13
14use std::{convert::TryFrom, sync::Arc};
15
16use rustls::{pki_types::ServerName, ClientConfig};
17use tokio::io::{AsyncRead, AsyncWrite};
18use tokio_postgres::tls::MakeTlsConnect;
19
20mod private {
21    use std::{
22        future::Future,
23        io,
24        pin::Pin,
25        task::{Context, Poll},
26    };
27
28    use rustls::pki_types::ServerName;
29    use sha2::digest::const_oid::db::{
30        rfc5912::{
31            ECDSA_WITH_SHA_256, ECDSA_WITH_SHA_384, ID_SHA_1, ID_SHA_256, ID_SHA_384, ID_SHA_512,
32            SHA_1_WITH_RSA_ENCRYPTION, SHA_256_WITH_RSA_ENCRYPTION, SHA_384_WITH_RSA_ENCRYPTION,
33            SHA_512_WITH_RSA_ENCRYPTION,
34        },
35        rfc8410::ID_ED_25519,
36    };
37    use sha2::{Digest, Sha256, Sha384, Sha512};
38    use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
39    use tokio_postgres::tls::{ChannelBinding, TlsConnect};
40    use tokio_rustls::{client::TlsStream, TlsConnector};
41    use x509_cert::{der::Decode, Certificate};
42
43    pub struct TlsConnectFuture<S> {
44        inner: tokio_rustls::Connect<S>,
45    }
46
47    impl<S> Future for TlsConnectFuture<S>
48    where
49        S: AsyncRead + AsyncWrite + Unpin,
50    {
51        type Output = io::Result<RustlsStream<S>>;
52
53        fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
54            Pin::new(&mut self.inner).poll(cx).map_ok(RustlsStream)
55        }
56    }
57
58    pub struct RustlsConnect(pub RustlsConnectData);
59
60    pub struct RustlsConnectData {
61        pub hostname: ServerName<'static>,
62        pub connector: TlsConnector,
63    }
64
65    impl<S> TlsConnect<S> for RustlsConnect
66    where
67        S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
68    {
69        type Stream = RustlsStream<S>;
70        type Error = io::Error;
71        type Future = TlsConnectFuture<S>;
72
73        fn connect(self, stream: S) -> Self::Future {
74            TlsConnectFuture {
75                inner: self.0.connector.connect(self.0.hostname, stream),
76            }
77        }
78    }
79
80    pub struct RustlsStream<S>(TlsStream<S>);
81
82    impl<S> tokio_postgres::tls::TlsStream for RustlsStream<S>
83    where
84        S: AsyncRead + AsyncWrite + Unpin,
85    {
86        fn channel_binding(&self) -> ChannelBinding {
87            let (_, session) = self.0.get_ref();
88            match session.peer_certificates() {
89                Some(certs) if !certs.is_empty() => Certificate::from_der(&certs[0]).map_or_else(
90                    |_| ChannelBinding::none(),
91                    |cert| {
92                        match cert.signature_algorithm.oid {
93                            // Note: SHA1 is upgraded to SHA256 as per https://datatracker.ietf.org/doc/html/rfc5929#section-4.1
94                            ID_SHA_1
95                            | ID_SHA_256
96                            | SHA_1_WITH_RSA_ENCRYPTION
97                            | SHA_256_WITH_RSA_ENCRYPTION
98                            | ECDSA_WITH_SHA_256 => ChannelBinding::tls_server_end_point(
99                                Sha256::digest(certs[0].as_ref()).to_vec(),
100                            ),
101                            ID_SHA_384 | SHA_384_WITH_RSA_ENCRYPTION | ECDSA_WITH_SHA_384 => {
102                                ChannelBinding::tls_server_end_point(
103                                    Sha384::digest(certs[0].as_ref()).to_vec(),
104                                )
105                            }
106                            ID_SHA_512 | SHA_512_WITH_RSA_ENCRYPTION | ID_ED_25519 => {
107                                ChannelBinding::tls_server_end_point(
108                                    Sha512::digest(certs[0].as_ref()).to_vec(),
109                                )
110                            }
111                            _ => ChannelBinding::none(),
112                        }
113                    },
114                ),
115                _ => ChannelBinding::none(),
116            }
117        }
118    }
119
120    impl<S> AsyncRead for RustlsStream<S>
121    where
122        S: AsyncRead + AsyncWrite + Unpin,
123    {
124        fn poll_read(
125            mut self: Pin<&mut Self>,
126            cx: &mut Context<'_>,
127            buf: &mut ReadBuf<'_>,
128        ) -> Poll<tokio::io::Result<()>> {
129            Pin::new(&mut self.0).poll_read(cx, buf)
130        }
131    }
132
133    impl<S> AsyncWrite for RustlsStream<S>
134    where
135        S: AsyncRead + AsyncWrite + Unpin,
136    {
137        fn poll_write(
138            mut self: Pin<&mut Self>,
139            cx: &mut Context<'_>,
140            buf: &[u8],
141        ) -> Poll<tokio::io::Result<usize>> {
142            Pin::new(&mut self.0).poll_write(cx, buf)
143        }
144
145        fn poll_flush(
146            mut self: Pin<&mut Self>,
147            cx: &mut Context<'_>,
148        ) -> Poll<tokio::io::Result<()>> {
149            Pin::new(&mut self.0).poll_flush(cx)
150        }
151
152        fn poll_shutdown(
153            mut self: Pin<&mut Self>,
154            cx: &mut Context<'_>,
155        ) -> Poll<tokio::io::Result<()>> {
156            Pin::new(&mut self.0).poll_shutdown(cx)
157        }
158    }
159}
160
161/// A [`MakeTlsConnect`](tokio_postgres::tls::MakeTlsConnect) implementation backed by [`rustls`].
162///
163/// This type allows you to establish PostgreSQL connections using `rustls` as the TLS provider,
164/// instead of relying on system-native TLS stacks. It wraps an [`Arc<ClientConfig>`] so that
165/// the TLS configuration can be cheaply cloned and reused across multiple connections.
166#[derive(Clone)]
167pub struct MakeRustlsConnect {
168    config: Arc<ClientConfig>,
169}
170
171impl MakeRustlsConnect {
172    /// Creates a new `MakeRustlsConnect` from the provided `ClientConfig`.
173    #[must_use]
174    pub fn new(config: ClientConfig) -> Self {
175        Self {
176            config: Arc::new(config),
177        }
178    }
179}
180
181impl<S> MakeTlsConnect<S> for MakeRustlsConnect
182where
183    S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
184{
185    type Stream = private::RustlsStream<S>;
186    type TlsConnect = private::RustlsConnect;
187    type Error = rustls::pki_types::InvalidDnsNameError;
188
189    /// Creates a new [`MakeRustlsConnect`] from the given [`ClientConfig`].
190    ///
191    /// The configuration is stored inside an [`Arc`], so the returned
192    /// connector can be cloned and shared across tasks or connections
193    /// without duplicating the underlying TLS state.
194    ///
195    /// # Parameters
196    ///
197    /// - `config`: The `rustls` client configuration that determines how TLS
198    ///   handshakes are performed (e.g. certificates, ciphers, root stores).
199    ///
200    /// # Returns
201    ///
202    /// A ready-to-use [`MakeRustlsConnect`] instance.
203    fn make_tls_connect(&mut self, hostname: &str) -> Result<Self::TlsConnect, Self::Error> {
204        ServerName::try_from(hostname).map(|dns_name| {
205            private::RustlsConnect(private::RustlsConnectData {
206                hostname: dns_name.to_owned(),
207                connector: Arc::clone(&self.config).into(),
208            })
209        })
210    }
211}