tokio_postgres_rustls_improved/
lib.rs

1#![doc = include_str!("../README.md")]
2#![forbid(rust_2018_idioms)]
3#![forbid(missing_docs, unsafe_code)]
4#![warn(clippy::all, clippy::pedantic)]
5
6use std::{convert::TryFrom, sync::Arc};
7
8use rustls::{pki_types::ServerName, ClientConfig};
9use tokio::io::{AsyncRead, AsyncWrite};
10use tokio_postgres::tls::MakeTlsConnect;
11
12mod private {
13    use std::{
14        future::Future,
15        io,
16        pin::Pin,
17        task::{Context, Poll},
18    };
19
20    use rustls::pki_types::ServerName;
21    use sha2::digest::const_oid::db::{
22        rfc5912::{
23            ECDSA_WITH_SHA_256, ECDSA_WITH_SHA_384, ID_SHA_1, ID_SHA_256, ID_SHA_384, ID_SHA_512,
24            SHA_1_WITH_RSA_ENCRYPTION, SHA_256_WITH_RSA_ENCRYPTION, SHA_384_WITH_RSA_ENCRYPTION,
25            SHA_512_WITH_RSA_ENCRYPTION,
26        },
27        rfc8410::ID_ED_25519,
28    };
29    use sha2::{Digest, Sha256, Sha384, Sha512};
30    use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
31    use tokio_postgres::tls::{ChannelBinding, TlsConnect};
32    use tokio_rustls::{client::TlsStream, TlsConnector};
33    use x509_cert::{der::Decode, Certificate};
34
35    pub struct TlsConnectFuture<S> {
36        inner: tokio_rustls::Connect<S>,
37    }
38
39    impl<S> Future for TlsConnectFuture<S>
40    where
41        S: AsyncRead + AsyncWrite + Unpin,
42    {
43        type Output = io::Result<RustlsStream<S>>;
44
45        fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
46            Pin::new(&mut self.inner).poll(cx).map_ok(RustlsStream)
47        }
48    }
49
50    pub struct RustlsConnect(pub RustlsConnectData);
51
52    pub struct RustlsConnectData {
53        pub hostname: ServerName<'static>,
54        pub connector: TlsConnector,
55    }
56
57    impl<S> TlsConnect<S> for RustlsConnect
58    where
59        S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
60    {
61        type Stream = RustlsStream<S>;
62        type Error = io::Error;
63        type Future = TlsConnectFuture<S>;
64
65        fn connect(self, stream: S) -> Self::Future {
66            TlsConnectFuture {
67                inner: self.0.connector.connect(self.0.hostname, stream),
68            }
69        }
70    }
71
72    pub struct RustlsStream<S>(TlsStream<S>);
73
74    impl<S> tokio_postgres::tls::TlsStream for RustlsStream<S>
75    where
76        S: AsyncRead + AsyncWrite + Unpin,
77    {
78        fn channel_binding(&self) -> ChannelBinding {
79            let (_, session) = self.0.get_ref();
80            match session.peer_certificates() {
81                Some(certs) if !certs.is_empty() => Certificate::from_der(&certs[0]).map_or_else(
82                    |_| ChannelBinding::none(),
83                    |cert| {
84                        match cert.signature_algorithm.oid {
85                            // Note: SHA1 is upgraded to SHA256 as per https://datatracker.ietf.org/doc/html/rfc5929#section-4.1
86                            ID_SHA_1
87                            | ID_SHA_256
88                            | SHA_1_WITH_RSA_ENCRYPTION
89                            | SHA_256_WITH_RSA_ENCRYPTION
90                            | ECDSA_WITH_SHA_256 => ChannelBinding::tls_server_end_point(
91                                Sha256::digest(certs[0].as_ref()).to_vec(),
92                            ),
93                            ID_SHA_384 | SHA_384_WITH_RSA_ENCRYPTION | ECDSA_WITH_SHA_384 => {
94                                ChannelBinding::tls_server_end_point(
95                                    Sha384::digest(certs[0].as_ref()).to_vec(),
96                                )
97                            }
98                            ID_SHA_512 | SHA_512_WITH_RSA_ENCRYPTION | ID_ED_25519 => {
99                                ChannelBinding::tls_server_end_point(
100                                    Sha512::digest(certs[0].as_ref()).to_vec(),
101                                )
102                            }
103                            _ => ChannelBinding::none(),
104                        }
105                    },
106                ),
107                _ => ChannelBinding::none(),
108            }
109        }
110    }
111
112    impl<S> AsyncRead for RustlsStream<S>
113    where
114        S: AsyncRead + AsyncWrite + Unpin,
115    {
116        fn poll_read(
117            mut self: Pin<&mut Self>,
118            cx: &mut Context<'_>,
119            buf: &mut ReadBuf<'_>,
120        ) -> Poll<tokio::io::Result<()>> {
121            Pin::new(&mut self.0).poll_read(cx, buf)
122        }
123    }
124
125    impl<S> AsyncWrite for RustlsStream<S>
126    where
127        S: AsyncRead + AsyncWrite + Unpin,
128    {
129        fn poll_write(
130            mut self: Pin<&mut Self>,
131            cx: &mut Context<'_>,
132            buf: &[u8],
133        ) -> Poll<tokio::io::Result<usize>> {
134            Pin::new(&mut self.0).poll_write(cx, buf)
135        }
136
137        fn poll_flush(
138            mut self: Pin<&mut Self>,
139            cx: &mut Context<'_>,
140        ) -> Poll<tokio::io::Result<()>> {
141            Pin::new(&mut self.0).poll_flush(cx)
142        }
143
144        fn poll_shutdown(
145            mut self: Pin<&mut Self>,
146            cx: &mut Context<'_>,
147        ) -> Poll<tokio::io::Result<()>> {
148            Pin::new(&mut self.0).poll_shutdown(cx)
149        }
150    }
151}
152
153/// A `MakeTlsConnect` implementation using `rustls`.
154///
155/// That way you can connect to PostgreSQL using `rustls` as the TLS stack.
156#[derive(Clone)]
157pub struct MakeRustlsConnect {
158    config: Arc<ClientConfig>,
159}
160
161impl MakeRustlsConnect {
162    /// Creates a new `MakeRustlsConnect` from the provided `ClientConfig`.
163    #[must_use]
164    pub fn new(config: ClientConfig) -> Self {
165        Self {
166            config: Arc::new(config),
167        }
168    }
169}
170
171impl<S> MakeTlsConnect<S> for MakeRustlsConnect
172where
173    S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
174{
175    type Stream = private::RustlsStream<S>;
176    type TlsConnect = private::RustlsConnect;
177    type Error = rustls::pki_types::InvalidDnsNameError;
178
179    fn make_tls_connect(&mut self, hostname: &str) -> Result<Self::TlsConnect, Self::Error> {
180        ServerName::try_from(hostname).map(|dns_name| {
181            private::RustlsConnect(private::RustlsConnectData {
182                hostname: dns_name.to_owned(),
183                connector: Arc::clone(&self.config).into(),
184            })
185        })
186    }
187}