tokio_postgres_rustls_improved/
lib.rs1#![cfg_attr(docsrs, feature(doc_cfg))]
2#![doc = include_str!("../README.md")]
3#![forbid(rust_2018_idioms)]
4#![forbid(missing_docs, unsafe_code, unused)]
5#![deny(
6 clippy::all,
7 clippy::pedantic,
8 clippy::unwrap_used,
9 clippy::expect_used,
10 clippy::nursery,
11 clippy::dbg_macro,
12 clippy::todo
13)]
14
15use std::{convert::TryFrom, sync::Arc};
16
17use rustls::{ClientConfig, pki_types::ServerName};
18use tokio::io::{AsyncRead, AsyncWrite};
19use tokio_postgres::tls::MakeTlsConnect;
20
21#[cfg(feature = "config-stream")]
22mod dynamic_config;
23#[cfg(feature = "config-stream")]
24#[cfg_attr(docsrs, doc(cfg(feature = "config-stream")))]
25pub use dynamic_config::MakeDynamicRustlsConnect;
26
27mod private {
28 use std::{
29 future::Future,
30 io,
31 pin::Pin,
32 task::{Context, Poll},
33 };
34
35 use rustls::pki_types::ServerName;
36 use sha2::digest::const_oid::db::rfc5912::{
37 ECDSA_WITH_SHA_256, ECDSA_WITH_SHA_384, ECDSA_WITH_SHA_512, ID_MD_5, ID_SHA_1, ID_SHA_256,
38 ID_SHA_384, ID_SHA_512, SHA_1_WITH_RSA_ENCRYPTION, SHA_256_WITH_RSA_ENCRYPTION,
39 SHA_384_WITH_RSA_ENCRYPTION, SHA_512_WITH_RSA_ENCRYPTION,
40 };
41 use sha2::{Digest, Sha256, Sha384, Sha512};
42 use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
43 use tokio_postgres::tls::{ChannelBinding, TlsConnect};
44 use tokio_rustls::{TlsConnector, client::TlsStream};
45 use x509_cert::{Certificate, der::Decode};
46
47 pub struct TlsConnectFuture<S> {
48 inner: tokio_rustls::Connect<S>,
49 }
50
51 impl<S> Future for TlsConnectFuture<S>
52 where
53 S: AsyncRead + AsyncWrite + Unpin,
54 {
55 type Output = io::Result<RustlsStream<S>>;
56
57 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
58 Pin::new(&mut self.inner).poll(cx).map_ok(RustlsStream)
59 }
60 }
61
62 pub struct RustlsConnect(pub RustlsConnectData);
63
64 pub struct RustlsConnectData {
65 pub hostname: ServerName<'static>,
66 pub connector: TlsConnector,
67 }
68
69 impl<S> TlsConnect<S> for RustlsConnect
70 where
71 S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
72 {
73 type Stream = RustlsStream<S>;
74 type Error = io::Error;
75 type Future = TlsConnectFuture<S>;
76
77 fn connect(self, stream: S) -> Self::Future {
78 TlsConnectFuture {
79 inner: self.0.connector.connect(self.0.hostname, stream),
80 }
81 }
82 }
83
84 pub struct RustlsStream<S>(TlsStream<S>);
85
86 impl<S> tokio_postgres::tls::TlsStream for RustlsStream<S>
87 where
88 S: AsyncRead + AsyncWrite + Unpin,
89 {
90 fn channel_binding(&self) -> ChannelBinding {
91 let (_, session) = self.0.get_ref();
92 match session.peer_certificates() {
93 Some(certs) if !certs.is_empty() => Certificate::from_der(&certs[0]).map_or_else(
94 |_| ChannelBinding::none(),
95 |cert| {
96 match cert.signature_algorithm.oid {
97 ID_MD_5
99 | ID_SHA_1
100 | ID_SHA_256
101 | SHA_1_WITH_RSA_ENCRYPTION
102 | SHA_256_WITH_RSA_ENCRYPTION
103 | ECDSA_WITH_SHA_256 => ChannelBinding::tls_server_end_point(
104 Sha256::digest(certs[0].as_ref()).to_vec(),
105 ),
106 ID_SHA_384 | SHA_384_WITH_RSA_ENCRYPTION | ECDSA_WITH_SHA_384 => {
107 ChannelBinding::tls_server_end_point(
108 Sha384::digest(certs[0].as_ref()).to_vec(),
109 )
110 }
111 ID_SHA_512 | SHA_512_WITH_RSA_ENCRYPTION | ECDSA_WITH_SHA_512 => {
112 ChannelBinding::tls_server_end_point(
113 Sha512::digest(certs[0].as_ref()).to_vec(),
114 )
115 }
116 _ => ChannelBinding::none(),
117 }
118 },
119 ),
120 _ => ChannelBinding::none(),
121 }
122 }
123 }
124
125 impl<S> AsyncRead for RustlsStream<S>
126 where
127 S: AsyncRead + AsyncWrite + Unpin,
128 {
129 fn poll_read(
130 mut self: Pin<&mut Self>,
131 cx: &mut Context<'_>,
132 buf: &mut ReadBuf<'_>,
133 ) -> Poll<tokio::io::Result<()>> {
134 Pin::new(&mut self.0).poll_read(cx, buf)
135 }
136 }
137
138 impl<S> AsyncWrite for RustlsStream<S>
139 where
140 S: AsyncRead + AsyncWrite + Unpin,
141 {
142 fn poll_write(
143 mut self: Pin<&mut Self>,
144 cx: &mut Context<'_>,
145 buf: &[u8],
146 ) -> Poll<tokio::io::Result<usize>> {
147 Pin::new(&mut self.0).poll_write(cx, buf)
148 }
149
150 fn poll_flush(
151 mut self: Pin<&mut Self>,
152 cx: &mut Context<'_>,
153 ) -> Poll<tokio::io::Result<()>> {
154 Pin::new(&mut self.0).poll_flush(cx)
155 }
156
157 fn poll_shutdown(
158 mut self: Pin<&mut Self>,
159 cx: &mut Context<'_>,
160 ) -> Poll<tokio::io::Result<()>> {
161 Pin::new(&mut self.0).poll_shutdown(cx)
162 }
163 }
164}
165
166#[derive(Clone)]
172pub struct MakeRustlsConnect {
173 config: Arc<ClientConfig>,
174}
175
176impl MakeRustlsConnect {
177 #[must_use]
179 pub fn new(config: ClientConfig) -> Self {
180 Self {
181 config: Arc::new(config),
182 }
183 }
184}
185
186impl<S> MakeTlsConnect<S> for MakeRustlsConnect
187where
188 S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
189{
190 type Stream = private::RustlsStream<S>;
191 type TlsConnect = private::RustlsConnect;
192 type Error = rustls::pki_types::InvalidDnsNameError;
193
194 fn make_tls_connect(&mut self, hostname: &str) -> Result<Self::TlsConnect, Self::Error> {
209 ServerName::try_from(hostname).map(|dns_name| {
210 private::RustlsConnect(private::RustlsConnectData {
211 hostname: dns_name.to_owned(),
212 connector: Arc::clone(&self.config).into(),
213 })
214 })
215 }
216}