tokio_postgres_tls/
lib.rs

1use std::{
2  convert::TryFrom,
3  future::Future,
4  io,
5  pin::Pin,
6  sync::Arc,
7  task::{Context, Poll},
8};
9
10use ring::digest;
11use rustls::{pki_types::ServerName, ClientConfig};
12use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
13use tokio_postgres::tls::{ChannelBinding, MakeTlsConnect, TlsConnect};
14use tokio_rustls::{client::TlsStream, TlsConnector};
15use x509_certificate::{DigestAlgorithm, SignatureAlgorithm, X509Certificate};
16use DigestAlgorithm::{Sha1, Sha256, Sha384, Sha512};
17use SignatureAlgorithm::{
18  EcdsaSha256, EcdsaSha384, Ed25519, NoSignature, RsaSha1, RsaSha256, RsaSha384, RsaSha512,
19};
20
21#[derive(Clone)]
22pub struct MakeRustlsConnect {
23  pub config: Arc<ClientConfig>,
24}
25
26impl MakeRustlsConnect {
27  pub fn new(config: ClientConfig) -> Self {
28    Self {
29      config: Arc::new(config),
30    }
31  }
32}
33
34impl<S> MakeTlsConnect<S> for MakeRustlsConnect
35where
36  S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
37{
38  type Stream = RustlsStream<S>;
39  type TlsConnect = RustlsConnect;
40  type Error = rustls::pki_types::InvalidDnsNameError;
41
42  fn make_tls_connect(&mut self, hostname: &str) -> Result<RustlsConnect, Self::Error> {
43    ServerName::try_from(hostname).map(|dns_name| {
44      RustlsConnect(RustlsConnectData {
45        hostname: dns_name.to_owned(),
46        connector: Arc::clone(&self.config).into(),
47      })
48    })
49  }
50}
51
52pub struct RustlsConnect(RustlsConnectData);
53
54struct RustlsConnectData {
55  hostname: ServerName<'static>,
56  connector: TlsConnector,
57}
58
59impl<S> TlsConnect<S> for RustlsConnect
60where
61  S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
62{
63  type Stream = RustlsStream<S>;
64  type Error = io::Error;
65  type Future = Pin<Box<dyn Future<Output = io::Result<RustlsStream<S>>> + Send>>;
66
67  fn connect(self, stream: S) -> Self::Future {
68    Box::pin(async move {
69      self
70        .0
71        .connector
72        .connect(self.0.hostname, stream)
73        .await
74        .map(|s| RustlsStream(Box::pin(s)))
75    })
76  }
77}
78
79pub struct RustlsStream<S>(Pin<Box<TlsStream<S>>>);
80
81impl<S> tokio_postgres::tls::TlsStream for RustlsStream<S>
82where
83  S: AsyncRead + AsyncWrite + Unpin,
84{
85  fn channel_binding(&self) -> ChannelBinding {
86    let (_, session) = self.0.get_ref();
87    match session.peer_certificates() {
88      Some(certs) if !certs.is_empty() => X509Certificate::from_der(&certs[0])
89        .ok()
90        .and_then(|cert| cert.signature_algorithm())
91        .map(|algorithm| match algorithm {
92          // Note: SHA1 is upgraded to SHA256 as per https://datatracker.ietf.org/doc/html/rfc5929#section-4.1
93          RsaSha1 | RsaSha256 | EcdsaSha256 => &digest::SHA256,
94          RsaSha384 | EcdsaSha384 => &digest::SHA384,
95          RsaSha512 => &digest::SHA512,
96          Ed25519 => &digest::SHA512,
97          NoSignature(algo) => match algo {
98            Sha1 | Sha256 => &digest::SHA256,
99            Sha384 => &digest::SHA384,
100            Sha512 => &digest::SHA512,
101          },
102        })
103        .map(|algorithm| {
104          let hash = digest::digest(algorithm, certs[0].as_ref());
105          ChannelBinding::tls_server_end_point(hash.as_ref().into())
106        })
107        .unwrap_or(ChannelBinding::none()),
108      _ => ChannelBinding::none(),
109    }
110  }
111}
112
113impl<S> AsyncRead for RustlsStream<S>
114where
115  S: AsyncRead + AsyncWrite + Unpin,
116{
117  fn poll_read(
118    mut self: Pin<&mut Self>,
119    cx: &mut Context,
120    buf: &mut ReadBuf<'_>,
121  ) -> Poll<tokio::io::Result<()>> {
122    self.0.as_mut().poll_read(cx, buf)
123  }
124}
125
126impl<S> AsyncWrite for RustlsStream<S>
127where
128  S: AsyncRead + AsyncWrite + Unpin,
129{
130  fn poll_write(
131    mut self: Pin<&mut Self>,
132    cx: &mut Context,
133    buf: &[u8],
134  ) -> Poll<tokio::io::Result<usize>> {
135    self.0.as_mut().poll_write(cx, buf)
136  }
137
138  fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<tokio::io::Result<()>> {
139    self.0.as_mut().poll_flush(cx)
140  }
141
142  fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<tokio::io::Result<()>> {
143    self.0.as_mut().poll_shutdown(cx)
144  }
145}
146
#[cfg(test)]
mod tests {
  use rustls::{
    client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier},
    pki_types::{CertificateDer, UnixTime},
    DigitallySignedStruct, Error, SignatureScheme,
  };

  use super::*;

  /// Test-only verifier that accepts every server certificate and handshake
  /// signature, i.e. it performs no server authentication at all.
  #[derive(Debug)]
  struct AcceptAllVerifier {}

  impl ServerCertVerifier for AcceptAllVerifier {
    fn verify_server_cert(
      &self,
      _end_entity: &CertificateDer<'_>,
      _intermediates: &[CertificateDer<'_>],
      _server_name: &ServerName<'_>,
      _ocsp_response: &[u8],
      _now: UnixTime,
    ) -> Result<ServerCertVerified, Error> {
      Ok(ServerCertVerified::assertion())
    }

    fn verify_tls12_signature(
      &self,
      _message: &[u8],
      _cert: &CertificateDer<'_>,
      _dss: &DigitallySignedStruct,
    ) -> Result<HandshakeSignatureValid, Error> {
      Ok(HandshakeSignatureValid::assertion())
    }

    fn verify_tls13_signature(
      &self,
      _message: &[u8],
      _cert: &CertificateDer<'_>,
      _dss: &DigitallySignedStruct,
    ) -> Result<HandshakeSignatureValid, Error> {
      Ok(HandshakeSignatureValid::assertion())
    }

    fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
      vec![
        SignatureScheme::ECDSA_NISTP384_SHA384,
        SignatureScheme::ECDSA_NISTP256_SHA256,
        SignatureScheme::RSA_PSS_SHA512,
        SignatureScheme::RSA_PSS_SHA384,
        SignatureScheme::RSA_PSS_SHA256,
        SignatureScheme::ED25519,
        // PKCS#1 v1.5 schemes let TLS 1.2 servers with RSA keys that do not
        // negotiate PSS still complete the handshake.
        SignatureScheme::RSA_PKCS1_SHA512,
        SignatureScheme::RSA_PKCS1_SHA384,
        SignatureScheme::RSA_PKCS1_SHA256,
      ]
    }
  }

  /// End-to-end smoke test over TLS.
  ///
  /// Requires a TLS-enabled PostgreSQL server on `localhost:5432` that accepts
  /// the `postgres` user with the connection string below.
  #[tokio::test]
  async fn it_works() {
    // `try_init` errors if a logger is already installed (e.g. by another test
    // in this binary); that is harmless, so the result is ignored.
    let _ = env_logger::builder().is_test(true).try_init();

    let mut config = rustls::ClientConfig::builder()
      .with_root_certificates(rustls::RootCertStore::empty())
      .with_no_client_auth();
    config
      .dangerous()
      .set_certificate_verifier(Arc::new(AcceptAllVerifier {}));
    let tls = super::MakeRustlsConnect::new(config);
    let (client, conn) = tokio_postgres::connect(
      "sslmode=require host=localhost port=5432 user=postgres",
      tls,
    )
    .await
    .expect("connect");
    tokio::spawn(async move {
      if let Err(e) = conn.await {
        panic!("{:?}", e);
      }
    });

    let stmt = client.prepare("SELECT 1").await.expect("prepare");
    let rows = client.query(&stmt, &[]).await.expect("query");
    assert_eq!(rows.len(), 1);
    let value: i32 = rows[0].get(0);
    assert_eq!(value, 1);
  }
}