1#![doc = include_str!("../README.md")]
2#![forbid(rust_2018_idioms)]
3#![deny(missing_docs, unsafe_code)]
4#![warn(clippy::all, clippy::pedantic)]
5
6use std::{convert::TryFrom, sync::Arc};
7
8use rustls::{pki_types::ServerName, ClientConfig};
9use tokio::io::{AsyncRead, AsyncWrite};
10use tokio_postgres::tls::MakeTlsConnect;
11
12mod private {
13 use std::{
14 future::Future,
15 io,
16 pin::Pin,
17 task::{Context, Poll},
18 };
19
20 use const_oid::db::{
21 rfc5912::{
22 ECDSA_WITH_SHA_256, ECDSA_WITH_SHA_384, ID_SHA_1, ID_SHA_256, ID_SHA_384, ID_SHA_512,
23 SHA_1_WITH_RSA_ENCRYPTION, SHA_256_WITH_RSA_ENCRYPTION, SHA_384_WITH_RSA_ENCRYPTION,
24 SHA_512_WITH_RSA_ENCRYPTION,
25 },
26 rfc8410::ID_ED_25519,
27 };
28 use ring::digest;
29 use rustls::pki_types::ServerName;
30 use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
31 use tokio_postgres::tls::{ChannelBinding, TlsConnect};
32 use tokio_rustls::{client::TlsStream, TlsConnector};
33 use x509_cert::{der::Decode, TbsCertificate};
34
35 pub struct TlsConnectFuture<S> {
36 pub inner: tokio_rustls::Connect<S>,
37 }
38
39 impl<S> Future for TlsConnectFuture<S>
40 where
41 S: AsyncRead + AsyncWrite + Unpin,
42 {
43 type Output = io::Result<RustlsStream<S>>;
44
45 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
46 #[allow(unsafe_code)]
48 let fut = unsafe { self.map_unchecked_mut(|this| &mut this.inner) };
49 fut.poll(cx).map_ok(RustlsStream)
50 }
51 }
52
53 pub struct RustlsConnect(pub RustlsConnectData);
54
55 pub struct RustlsConnectData {
56 pub hostname: ServerName<'static>,
57 pub connector: TlsConnector,
58 }
59
60 impl<S> TlsConnect<S> for RustlsConnect
61 where
62 S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
63 {
64 type Stream = RustlsStream<S>;
65 type Error = io::Error;
66 type Future = TlsConnectFuture<S>;
67
68 fn connect(self, stream: S) -> Self::Future {
69 TlsConnectFuture {
70 inner: self.0.connector.connect(self.0.hostname, stream),
71 }
72 }
73 }
74
75 pub struct RustlsStream<S>(TlsStream<S>);
76
77 impl<S> RustlsStream<S> {
78 pub fn project_stream(self: Pin<&mut Self>) -> Pin<&mut TlsStream<S>> {
79 #[allow(unsafe_code)]
81 unsafe {
82 self.map_unchecked_mut(|this| &mut this.0)
83 }
84 }
85 }
86
87 impl<S> tokio_postgres::tls::TlsStream for RustlsStream<S>
88 where
89 S: AsyncRead + AsyncWrite + Unpin,
90 {
91 fn channel_binding(&self) -> ChannelBinding {
92 let (_, session) = self.0.get_ref();
93 match session.peer_certificates() {
94 Some(certs) if !certs.is_empty() => TbsCertificate::from_der(&certs[0])
95 .ok()
96 .and_then(|cert| {
97 let digest = match cert.signature.oid {
98 ID_SHA_1
100 | ID_SHA_256
101 | SHA_1_WITH_RSA_ENCRYPTION
102 | SHA_256_WITH_RSA_ENCRYPTION
103 | ECDSA_WITH_SHA_256 => &digest::SHA256,
104 ID_SHA_384 | SHA_384_WITH_RSA_ENCRYPTION | ECDSA_WITH_SHA_384 => {
105 &digest::SHA384
106 }
107 ID_SHA_512 | SHA_512_WITH_RSA_ENCRYPTION | ID_ED_25519 => {
108 &digest::SHA512
109 }
110 _ => return None,
111 };
112
113 Some(digest)
114 })
115 .map_or_else(ChannelBinding::none, |algorithm| {
116 let hash = digest::digest(algorithm, certs[0].as_ref());
117 ChannelBinding::tls_server_end_point(hash.as_ref().into())
118 }),
119 _ => ChannelBinding::none(),
120 }
121 }
122 }
123
124 impl<S> AsyncRead for RustlsStream<S>
125 where
126 S: AsyncRead + AsyncWrite + Unpin,
127 {
128 fn poll_read(
129 self: Pin<&mut Self>,
130 cx: &mut Context<'_>,
131 buf: &mut ReadBuf<'_>,
132 ) -> Poll<tokio::io::Result<()>> {
133 self.project_stream().poll_read(cx, buf)
134 }
135 }
136
137 impl<S> AsyncWrite for RustlsStream<S>
138 where
139 S: AsyncRead + AsyncWrite + Unpin,
140 {
141 fn poll_write(
142 self: Pin<&mut Self>,
143 cx: &mut Context<'_>,
144 buf: &[u8],
145 ) -> Poll<tokio::io::Result<usize>> {
146 self.project_stream().poll_write(cx, buf)
147 }
148
149 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<tokio::io::Result<()>> {
150 self.project_stream().poll_flush(cx)
151 }
152
153 fn poll_shutdown(
154 self: Pin<&mut Self>,
155 cx: &mut Context<'_>,
156 ) -> Poll<tokio::io::Result<()>> {
157 self.project_stream().poll_shutdown(cx)
158 }
159 }
160}
161
162#[derive(Clone)]
166pub struct MakeRustlsConnect {
167 config: Arc<ClientConfig>,
168}
169
170impl MakeRustlsConnect {
171 #[must_use]
173 pub fn new(config: ClientConfig) -> Self {
174 Self {
175 config: Arc::new(config),
176 }
177 }
178}
179
180impl<S> MakeTlsConnect<S> for MakeRustlsConnect
181where
182 S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
183{
184 type Stream = private::RustlsStream<S>;
185 type TlsConnect = private::RustlsConnect;
186 type Error = rustls::pki_types::InvalidDnsNameError;
187
188 fn make_tls_connect(&mut self, hostname: &str) -> Result<Self::TlsConnect, Self::Error> {
189 ServerName::try_from(hostname).map(|dns_name| {
190 private::RustlsConnect(private::RustlsConnectData {
191 hostname: dns_name.to_owned(),
192 connector: Arc::clone(&self.config).into(),
193 })
194 })
195 }
196}
197
#[cfg(test)]
mod tests {
    use super::*;
    use rustls::pki_types::{CertificateDer, UnixTime};
    use rustls::{
        client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier},
        DigitallySignedStruct, Error, SignatureScheme,
    };

    /// Verifier that accepts every server certificate and handshake signature.
    ///
    /// This disables server authentication entirely; it exists only so the
    /// integration test below can reach a local server without a trusted root.
    #[derive(Debug)]
    struct AcceptAllVerifier;

    impl ServerCertVerifier for AcceptAllVerifier {
        fn verify_server_cert(
            &self,
            _end_entity: &CertificateDer<'_>,
            _intermediates: &[CertificateDer<'_>],
            _server_name: &ServerName<'_>,
            _ocsp_response: &[u8],
            _now: UnixTime,
        ) -> Result<ServerCertVerified, Error> {
            Ok(ServerCertVerified::assertion())
        }

        fn verify_tls12_signature(
            &self,
            _message: &[u8],
            _cert: &CertificateDer<'_>,
            _dss: &DigitallySignedStruct,
        ) -> Result<HandshakeSignatureValid, Error> {
            Ok(HandshakeSignatureValid::assertion())
        }

        fn verify_tls13_signature(
            &self,
            _message: &[u8],
            _cert: &CertificateDer<'_>,
            _dss: &DigitallySignedStruct,
        ) -> Result<HandshakeSignatureValid, Error> {
            Ok(HandshakeSignatureValid::assertion())
        }

        fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
            vec![
                SignatureScheme::ECDSA_NISTP384_SHA384,
                SignatureScheme::ECDSA_NISTP256_SHA256,
                SignatureScheme::RSA_PSS_SHA512,
                SignatureScheme::RSA_PSS_SHA384,
                SignatureScheme::RSA_PSS_SHA256,
                SignatureScheme::ED25519,
            ]
        }
    }

    /// End-to-end check: connects over TLS (`sslmode=require`) to a PostgreSQL
    /// server on `localhost:5432` as user `postgres` and runs `SELECT 1`.
    /// Requires such a server to be running.
    #[tokio::test]
    async fn it_works() {
        // `try_init` errors when a logger is already installed (e.g. by another
        // test in the same binary); that is not a failure of this test.
        let _ = env_logger::builder().is_test(true).try_init();

        let mut config = rustls::ClientConfig::builder()
            .with_root_certificates(rustls::RootCertStore::empty())
            .with_no_client_auth();
        config
            .dangerous()
            .set_certificate_verifier(Arc::new(AcceptAllVerifier));
        let tls = MakeRustlsConnect::new(config);
        let (client, conn) = tokio_postgres::connect(
            "sslmode=require host=localhost port=5432 user=postgres",
            tls,
        )
        .await
        .expect("connect");
        // Drive the connection in the background; an error panics that task.
        tokio::spawn(async move { conn.await.map_err(|e| panic!("{e:?}")) });

        let stmt = client.prepare("SELECT 1").await.expect("prepare");
        let rows = client.query(&stmt, &[]).await.expect("query");
        assert_eq!(rows.len(), 1);
        let value: i32 = rows[0].get(0);
        assert_eq!(value, 1);
    }
}