// tokio_postgres_native_tls/lib.rs

//! TLS support for `tokio-postgres` via `native-tls`.
//!
//! # Example
//!
//! ```no_run
//! use native_tls::{Certificate, TlsConnector};
//! use tokio_postgres_native_tls::MakeTlsConnector;
//! use std::fs;
//!
//! let cert = fs::read("database_cert.pem").unwrap();
//! let cert = Certificate::from_pem(&cert).unwrap();
//! let connector = TlsConnector::builder()
//!     .add_root_certificate(cert)
//!     .build()
//!     .unwrap();
//! let connector = MakeTlsConnector::new(connector);
//!
//! let connect_future = tokio_postgres::connect(
//!     "host=localhost user=postgres sslmode=require",
//!     connector,
//! );
//!
//! // ...
//! ```
25#![doc(html_root_url = "https://docs.rs/tokio-postgres-native-tls/0.1.0-rc.1")]
26#![warn(rust_2018_idioms, clippy::all, missing_docs)]
27
28use futures::{try_ready, Async, Future, Poll};
29use tokio_io::{AsyncRead, AsyncWrite};
30#[cfg(feature = "runtime")]
31use tokio_postgres::tls::MakeTlsConnect;
32use tokio_postgres::tls::{ChannelBinding, TlsConnect};
33use tokio_tls::{Connect, TlsStream};
34
35#[cfg(test)]
36mod test;
37
/// Produces per-connection `TlsConnector`s from one shared `native-tls`
/// connector; this is the crate's `MakeTlsConnect` implementation.
///
/// Only compiled when the `runtime` Cargo feature is enabled (it is on by
/// default).
#[cfg(feature = "runtime")]
#[derive(Clone)]
pub struct MakeTlsConnector(
    // The configured `native-tls` connector that every produced
    // `TlsConnector` starts from.
    native_tls::TlsConnector,
);
44
#[cfg(feature = "runtime")]
impl MakeTlsConnector {
    /// Wraps an already-configured `native-tls` connector so that it can be
    /// passed to `tokio-postgres` as the TLS maker for new connections.
    pub fn new(connector: native_tls::TlsConnector) -> Self {
        MakeTlsConnector(connector)
    }
}
52
#[cfg(feature = "runtime")]
impl<S: AsyncRead + AsyncWrite> MakeTlsConnect<S> for MakeTlsConnector {
    type Stream = TlsStream<S>;
    type TlsConnect = TlsConnector;
    type Error = native_tls::Error;

    /// Builds a single-use `TlsConnector` aimed at `domain`.
    ///
    /// This implementation never returns `Err`.
    fn make_tls_connect(&mut self, domain: &str) -> Result<Self::TlsConnect, Self::Error> {
        // `TlsConnector::new` takes the native-tls connector by value, so each
        // call hands it its own clone of the shared configuration.
        let per_host = TlsConnector::new(self.0.clone(), domain);
        Ok(per_host)
    }
}
66
67/// A `TlsConnect` implementation using the `native-tls` crate.
68pub struct TlsConnector {
69    connector: tokio_tls::TlsConnector,
70    domain: String,
71}
72
73impl TlsConnector {
74    /// Creates a new connector configured to connect to the specified domain.
75    pub fn new(connector: native_tls::TlsConnector, domain: &str) -> TlsConnector {
76        TlsConnector {
77            connector: tokio_tls::TlsConnector::from(connector),
78            domain: domain.to_string(),
79        }
80    }
81}
82
83impl<S> TlsConnect<S> for TlsConnector
84where
85    S: AsyncRead + AsyncWrite,
86{
87    type Stream = TlsStream<S>;
88    type Error = native_tls::Error;
89    type Future = TlsConnectFuture<S>;
90
91    fn connect(self, stream: S) -> TlsConnectFuture<S> {
92        TlsConnectFuture(self.connector.connect(&self.domain, stream))
93    }
94}
95
96/// The future returned by `TlsConnector`.
97pub struct TlsConnectFuture<S>(Connect<S>);
98
99impl<S> Future for TlsConnectFuture<S>
100where
101    S: AsyncRead + AsyncWrite,
102{
103    type Item = (TlsStream<S>, ChannelBinding);
104    type Error = native_tls::Error;
105
106    fn poll(&mut self) -> Poll<(TlsStream<S>, ChannelBinding), native_tls::Error> {
107        let stream = try_ready!(self.0.poll());
108
109        let channel_binding = match stream.get_ref().tls_server_end_point().unwrap_or(None) {
110            Some(buf) => ChannelBinding::tls_server_end_point(buf),
111            None => ChannelBinding::none(),
112        };
113
114        Ok(Async::Ready((stream, channel_binding)))
115    }
116}