tokio_postgres_rustls_improved/
lib.rs1#![doc = include_str!("../README.md")]
2#![forbid(rust_2018_idioms)]
3#![forbid(missing_docs, unsafe_code, unused)]
4#![deny(
5 clippy::all,
6 clippy::pedantic,
7 clippy::unwrap_used,
8 clippy::expect_used,
9 clippy::nursery,
10 clippy::dbg_macro,
11 clippy::todo
12)]
13
14use std::{convert::TryFrom, sync::Arc};
15
16use rustls::{ClientConfig, pki_types::ServerName};
17use tokio::io::{AsyncRead, AsyncWrite};
18use tokio_postgres::tls::MakeTlsConnect;
19
20mod private {
21 use std::{
22 future::Future,
23 io,
24 pin::Pin,
25 task::{Context, Poll},
26 };
27
28 use rustls::pki_types::ServerName;
29 use sha2::digest::const_oid::db::rfc5912::{
30 ECDSA_WITH_SHA_256, ECDSA_WITH_SHA_384, ECDSA_WITH_SHA_512, ID_MD_5, ID_SHA_1, ID_SHA_256,
31 ID_SHA_384, ID_SHA_512, SHA_1_WITH_RSA_ENCRYPTION, SHA_256_WITH_RSA_ENCRYPTION,
32 SHA_384_WITH_RSA_ENCRYPTION, SHA_512_WITH_RSA_ENCRYPTION,
33 };
34 use sha2::{Digest, Sha256, Sha384, Sha512};
35 use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
36 use tokio_postgres::tls::{ChannelBinding, TlsConnect};
37 use tokio_rustls::{TlsConnector, client::TlsStream};
38 use x509_cert::{Certificate, der::Decode};
39
40 pub struct TlsConnectFuture<S> {
41 inner: tokio_rustls::Connect<S>,
42 }
43
44 impl<S> Future for TlsConnectFuture<S>
45 where
46 S: AsyncRead + AsyncWrite + Unpin,
47 {
48 type Output = io::Result<RustlsStream<S>>;
49
50 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
51 Pin::new(&mut self.inner).poll(cx).map_ok(RustlsStream)
52 }
53 }
54
55 pub struct RustlsConnect(pub RustlsConnectData);
56
57 pub struct RustlsConnectData {
58 pub hostname: ServerName<'static>,
59 pub connector: TlsConnector,
60 }
61
62 impl<S> TlsConnect<S> for RustlsConnect
63 where
64 S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
65 {
66 type Stream = RustlsStream<S>;
67 type Error = io::Error;
68 type Future = TlsConnectFuture<S>;
69
70 fn connect(self, stream: S) -> Self::Future {
71 TlsConnectFuture {
72 inner: self.0.connector.connect(self.0.hostname, stream),
73 }
74 }
75 }
76
77 pub struct RustlsStream<S>(TlsStream<S>);
78
79 impl<S> tokio_postgres::tls::TlsStream for RustlsStream<S>
80 where
81 S: AsyncRead + AsyncWrite + Unpin,
82 {
83 fn channel_binding(&self) -> ChannelBinding {
84 let (_, session) = self.0.get_ref();
85 match session.peer_certificates() {
86 Some(certs) if !certs.is_empty() => Certificate::from_der(&certs[0]).map_or_else(
87 |_| ChannelBinding::none(),
88 |cert| {
89 match cert.signature_algorithm.oid {
90 ID_MD_5
92 | ID_SHA_1
93 | ID_SHA_256
94 | SHA_1_WITH_RSA_ENCRYPTION
95 | SHA_256_WITH_RSA_ENCRYPTION
96 | ECDSA_WITH_SHA_256 => ChannelBinding::tls_server_end_point(
97 Sha256::digest(certs[0].as_ref()).to_vec(),
98 ),
99 ID_SHA_384 | SHA_384_WITH_RSA_ENCRYPTION | ECDSA_WITH_SHA_384 => {
100 ChannelBinding::tls_server_end_point(
101 Sha384::digest(certs[0].as_ref()).to_vec(),
102 )
103 }
104 ID_SHA_512 | SHA_512_WITH_RSA_ENCRYPTION | ECDSA_WITH_SHA_512 => {
105 ChannelBinding::tls_server_end_point(
106 Sha512::digest(certs[0].as_ref()).to_vec(),
107 )
108 }
109 _ => ChannelBinding::none(),
110 }
111 },
112 ),
113 _ => ChannelBinding::none(),
114 }
115 }
116 }
117
118 impl<S> AsyncRead for RustlsStream<S>
119 where
120 S: AsyncRead + AsyncWrite + Unpin,
121 {
122 fn poll_read(
123 mut self: Pin<&mut Self>,
124 cx: &mut Context<'_>,
125 buf: &mut ReadBuf<'_>,
126 ) -> Poll<tokio::io::Result<()>> {
127 Pin::new(&mut self.0).poll_read(cx, buf)
128 }
129 }
130
131 impl<S> AsyncWrite for RustlsStream<S>
132 where
133 S: AsyncRead + AsyncWrite + Unpin,
134 {
135 fn poll_write(
136 mut self: Pin<&mut Self>,
137 cx: &mut Context<'_>,
138 buf: &[u8],
139 ) -> Poll<tokio::io::Result<usize>> {
140 Pin::new(&mut self.0).poll_write(cx, buf)
141 }
142
143 fn poll_flush(
144 mut self: Pin<&mut Self>,
145 cx: &mut Context<'_>,
146 ) -> Poll<tokio::io::Result<()>> {
147 Pin::new(&mut self.0).poll_flush(cx)
148 }
149
150 fn poll_shutdown(
151 mut self: Pin<&mut Self>,
152 cx: &mut Context<'_>,
153 ) -> Poll<tokio::io::Result<()>> {
154 Pin::new(&mut self.0).poll_shutdown(cx)
155 }
156 }
157}
158
159#[derive(Clone)]
165pub struct MakeRustlsConnect {
166 config: Arc<ClientConfig>,
167}
168
169impl MakeRustlsConnect {
170 #[must_use]
172 pub fn new(config: ClientConfig) -> Self {
173 Self {
174 config: Arc::new(config),
175 }
176 }
177}
178
179impl<S> MakeTlsConnect<S> for MakeRustlsConnect
180where
181 S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
182{
183 type Stream = private::RustlsStream<S>;
184 type TlsConnect = private::RustlsConnect;
185 type Error = rustls::pki_types::InvalidDnsNameError;
186
187 fn make_tls_connect(&mut self, hostname: &str) -> Result<Self::TlsConnect, Self::Error> {
202 ServerName::try_from(hostname).map(|dns_name| {
203 private::RustlsConnect(private::RustlsConnectData {
204 hostname: dns_name.to_owned(),
205 connector: Arc::clone(&self.config).into(),
206 })
207 })
208 }
209}