yb_postgres_native_tls/
lib.rs1#![warn(rust_2018_idioms, clippy::all, missing_docs)]
55
56use std::future::Future;
57use std::io;
58use std::pin::Pin;
59use std::task::{Context, Poll};
60use tokio::io::{AsyncRead, AsyncWrite, BufReader, ReadBuf};
61use yb_tokio_postgres::tls;
62#[cfg(feature = "runtime")]
63use yb_tokio_postgres::tls::MakeTlsConnect;
64use yb_tokio_postgres::tls::{ChannelBinding, TlsConnect};
65
66#[cfg(test)]
67mod test;
68
/// A [`MakeTlsConnect`] implementation backed by the `native-tls` crate.
///
/// Wraps an already-configured `native_tls::TlsConnector`. Every call to
/// `make_tls_connect` clones that connector into a fresh [`TlsConnector`]
/// bound to the requested domain.
///
/// Only compiled when the `runtime` Cargo feature is enabled.
#[cfg(feature = "runtime")]
#[derive(Clone)]
pub struct MakeTlsConnector(native_tls::TlsConnector);
75
#[cfg(feature = "runtime")]
impl MakeTlsConnector {
    /// Creates a new `MakeTlsConnector` from a configured `native_tls::TlsConnector`.
    ///
    /// The connector is stored as-is, so whatever certificate and protocol
    /// settings the caller configured on it apply to every connection made
    /// through this value.
    pub fn new(connector: native_tls::TlsConnector) -> MakeTlsConnector {
        Self(connector)
    }
}
83
#[cfg(feature = "runtime")]
impl<S> MakeTlsConnect<S> for MakeTlsConnector
where
    S: AsyncRead + AsyncWrite + Unpin + 'static + Send,
{
    type Stream = TlsStream<S>;
    type TlsConnect = TlsConnector;
    type Error = native_tls::Error;

    /// Produces a per-connection [`TlsConnector`] for `domain`.
    ///
    /// The wrapped `native_tls::TlsConnector` is cloned, so this value stays
    /// reusable for later connections. This implementation never returns `Err`.
    fn make_tls_connect(&mut self, domain: &str) -> Result<Self::TlsConnect, Self::Error> {
        let per_connection = self.0.clone();
        Ok(TlsConnector::new(per_connection, domain))
    }
}
97
98pub struct TlsConnector {
100 connector: tokio_native_tls::TlsConnector,
101 domain: String,
102}
103
104impl TlsConnector {
105 pub fn new(connector: native_tls::TlsConnector, domain: &str) -> TlsConnector {
107 TlsConnector {
108 connector: tokio_native_tls::TlsConnector::from(connector),
109 domain: domain.to_string(),
110 }
111 }
112}
113
114impl<S> TlsConnect<S> for TlsConnector
115where
116 S: AsyncRead + AsyncWrite + Unpin + 'static + Send,
117{
118 type Stream = TlsStream<S>;
119 type Error = native_tls::Error;
120 #[allow(clippy::type_complexity)]
121 type Future = Pin<Box<dyn Future<Output = Result<TlsStream<S>, native_tls::Error>> + Send>>;
122
123 fn connect(self, stream: S) -> Self::Future {
124 let stream = BufReader::with_capacity(8192, stream);
125 let future = async move {
126 let stream = self.connector.connect(&self.domain, stream).await?;
127
128 Ok(TlsStream(stream))
129 };
130
131 Box::pin(future)
132 }
133}
134
135pub struct TlsStream<S>(tokio_native_tls::TlsStream<BufReader<S>>);
137
138impl<S> AsyncRead for TlsStream<S>
139where
140 S: AsyncRead + AsyncWrite + Unpin,
141{
142 fn poll_read(
143 mut self: Pin<&mut Self>,
144 cx: &mut Context<'_>,
145 buf: &mut ReadBuf<'_>,
146 ) -> Poll<io::Result<()>> {
147 Pin::new(&mut self.0).poll_read(cx, buf)
148 }
149}
150
151impl<S> AsyncWrite for TlsStream<S>
152where
153 S: AsyncRead + AsyncWrite + Unpin,
154{
155 fn poll_write(
156 mut self: Pin<&mut Self>,
157 cx: &mut Context<'_>,
158 buf: &[u8],
159 ) -> Poll<io::Result<usize>> {
160 Pin::new(&mut self.0).poll_write(cx, buf)
161 }
162
163 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
164 Pin::new(&mut self.0).poll_flush(cx)
165 }
166
167 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
168 Pin::new(&mut self.0).poll_shutdown(cx)
169 }
170}
171
172impl<S> tls::TlsStream for TlsStream<S>
173where
174 S: AsyncRead + AsyncWrite + Unpin,
175{
176 fn channel_binding(&self) -> ChannelBinding {
177 match self.0.get_ref().tls_server_end_point().ok().flatten() {
178 Some(buf) => ChannelBinding::tls_server_end_point(buf),
179 None => ChannelBinding::none(),
180 }
181 }
182}