yb_postgres_native_tls/
lib.rs

//! TLS support for `yb_tokio_postgres` and `yb_postgres` via `native-tls`.
//!
//! # Examples
//!
//! Connecting asynchronously with `yb_tokio_postgres`, trusting a custom root certificate:
//!
//! ```no_run
//! use native_tls::{Certificate, TlsConnector};
//! # #[cfg(feature = "runtime")]
//! use yb_postgres_native_tls::MakeTlsConnector;
//! use std::fs;
//!
//! # fn main() -> Result<(), Box<dyn std::error::Error>> {
//! # #[cfg(feature = "runtime")] {
//! let cert = fs::read("database_cert.pem")?;
//! let cert = Certificate::from_pem(&cert)?;
//! let connector = TlsConnector::builder()
//!     .add_root_certificate(cert)
//!     .build()?;
//! let connector = MakeTlsConnector::new(connector);
//!
//! let connect_future = yb_tokio_postgres::connect(
//!     "host=localhost user=postgres sslmode=require",
//!     connector,
//! );
//! # }
//!
//! // ...
//! # Ok(())
//! # }
//! ```
//!
//! A `MakeTlsConnector` can also be passed to the synchronous `yb_postgres` client:
//!
//! ```no_run
//! use native_tls::{Certificate, TlsConnector};
//! # #[cfg(feature = "runtime")]
//! use yb_postgres_native_tls::MakeTlsConnector;
//! use std::fs;
//!
//! # fn main() -> Result<(), Box<dyn std::error::Error>> {
//! # #[cfg(feature = "runtime")] {
//! let cert = fs::read("database_cert.pem")?;
//! let cert = Certificate::from_pem(&cert)?;
//! let connector = TlsConnector::builder()
//!     .add_root_certificate(cert)
//!     .build()?;
//! let connector = MakeTlsConnector::new(connector);
//!
//! let client = yb_postgres::Client::connect(
//!     "host=localhost user=postgres sslmode=require",
//!     connector,
//! )?;
//! # }
//! # Ok(())
//! # }
//! ```
54#![warn(rust_2018_idioms, clippy::all, missing_docs)]
55
56use std::future::Future;
57use std::io;
58use std::pin::Pin;
59use std::task::{Context, Poll};
60use tokio::io::{AsyncRead, AsyncWrite, BufReader, ReadBuf};
61use yb_tokio_postgres::tls;
62#[cfg(feature = "runtime")]
63use yb_tokio_postgres::tls::MakeTlsConnect;
64use yb_tokio_postgres::tls::{ChannelBinding, TlsConnect};
65
66#[cfg(test)]
67mod test;
68
/// A [`MakeTlsConnect`] implementation backed by the `native-tls` crate.
///
/// Holds a configured `native_tls::TlsConnector`. For every connection it is cloned into a
/// per-domain [`TlsConnector`].
///
/// Requires the `runtime` Cargo feature (enabled by default).
#[cfg(feature = "runtime")]
#[derive(Clone)]
pub struct MakeTlsConnector(native_tls::TlsConnector);

#[cfg(feature = "runtime")]
impl MakeTlsConnector {
    /// Creates a new connector wrapping an already-configured `native_tls::TlsConnector`.
    ///
    /// Whatever settings `connector` carries apply to every connection made through the
    /// returned value, because the same connector is cloned for each one.
    pub fn new(connector: native_tls::TlsConnector) -> Self {
        Self(connector)
    }
}
83
#[cfg(feature = "runtime")]
impl<S> MakeTlsConnect<S> for MakeTlsConnector
where
    S: AsyncRead + AsyncWrite + Unpin + 'static + Send,
{
    type Stream = TlsStream<S>;
    type TlsConnect = TlsConnector;
    type Error = native_tls::Error;

    /// Builds a [`TlsConnector`] bound to `domain` from a clone of the wrapped connector.
    ///
    /// This never returns `Err`; the `Result` is only required by the trait signature.
    fn make_tls_connect(&mut self, domain: &str) -> Result<Self::TlsConnect, Self::Error> {
        let base = self.0.clone();
        Ok(TlsConnector::new(base, domain))
    }
}
97
98/// A `TlsConnect` implementation using the `native-tls` crate.
99pub struct TlsConnector {
100    connector: tokio_native_tls::TlsConnector,
101    domain: String,
102}
103
104impl TlsConnector {
105    /// Creates a new connector configured to connect to the specified domain.
106    pub fn new(connector: native_tls::TlsConnector, domain: &str) -> TlsConnector {
107        TlsConnector {
108            connector: tokio_native_tls::TlsConnector::from(connector),
109            domain: domain.to_string(),
110        }
111    }
112}
113
114impl<S> TlsConnect<S> for TlsConnector
115where
116    S: AsyncRead + AsyncWrite + Unpin + 'static + Send,
117{
118    type Stream = TlsStream<S>;
119    type Error = native_tls::Error;
120    #[allow(clippy::type_complexity)]
121    type Future = Pin<Box<dyn Future<Output = Result<TlsStream<S>, native_tls::Error>> + Send>>;
122
123    fn connect(self, stream: S) -> Self::Future {
124        let stream = BufReader::with_capacity(8192, stream);
125        let future = async move {
126            let stream = self.connector.connect(&self.domain, stream).await?;
127
128            Ok(TlsStream(stream))
129        };
130
131        Box::pin(future)
132    }
133}
134
135/// The stream returned by `TlsConnector`.
136pub struct TlsStream<S>(tokio_native_tls::TlsStream<BufReader<S>>);
137
138impl<S> AsyncRead for TlsStream<S>
139where
140    S: AsyncRead + AsyncWrite + Unpin,
141{
142    fn poll_read(
143        mut self: Pin<&mut Self>,
144        cx: &mut Context<'_>,
145        buf: &mut ReadBuf<'_>,
146    ) -> Poll<io::Result<()>> {
147        Pin::new(&mut self.0).poll_read(cx, buf)
148    }
149}
150
151impl<S> AsyncWrite for TlsStream<S>
152where
153    S: AsyncRead + AsyncWrite + Unpin,
154{
155    fn poll_write(
156        mut self: Pin<&mut Self>,
157        cx: &mut Context<'_>,
158        buf: &[u8],
159    ) -> Poll<io::Result<usize>> {
160        Pin::new(&mut self.0).poll_write(cx, buf)
161    }
162
163    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
164        Pin::new(&mut self.0).poll_flush(cx)
165    }
166
167    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
168        Pin::new(&mut self.0).poll_shutdown(cx)
169    }
170}
171
172impl<S> tls::TlsStream for TlsStream<S>
173where
174    S: AsyncRead + AsyncWrite + Unpin,
175{
176    fn channel_binding(&self) -> ChannelBinding {
177        match self.0.get_ref().tls_server_end_point().ok().flatten() {
178            Some(buf) => ChannelBinding::tls_server_end_point(buf),
179            None => ChannelBinding::none(),
180        }
181    }
182}