tokio_postgres_openssl/
lib.rs

1//! TLS support for `tokio-postgres` via `openssl`.
2//!
3//! # Example
4//!
5//! ```no_run
6//! use openssl::ssl::{SslConnector, SslMethod};
7//! use tokio_postgres_openssl::MakeTlsConnector;
8//!
9//! let mut builder = SslConnector::builder(SslMethod::tls()).unwrap();
10//! builder.set_ca_file("database_cert.pem").unwrap();
11//! let connector = MakeTlsConnector::new(builder.build());
12//!
13//! let connect_future = tokio_postgres::connect(
14//!     "host=localhost user=postgres sslmode=require",
15//!     connector,
16//! );
17//!
18//! // ...
19//! ```
20#![doc(html_root_url = "https://docs.rs/tokio-postgres-openssl/0.1.0-rc.1")]
21#![warn(rust_2018_idioms, clippy::all, missing_docs)]
22
23use futures::{try_ready, Async, Future, Poll};
24#[cfg(feature = "runtime")]
25use openssl::error::ErrorStack;
26use openssl::hash::MessageDigest;
27use openssl::nid::Nid;
28#[cfg(feature = "runtime")]
29use openssl::ssl::SslConnector;
30use openssl::ssl::{ConnectConfiguration, HandshakeError, SslRef};
31use std::fmt::Debug;
32#[cfg(feature = "runtime")]
33use std::sync::Arc;
34use tokio_io::{AsyncRead, AsyncWrite};
35use tokio_openssl::{ConnectAsync, ConnectConfigurationExt, SslStream};
36#[cfg(feature = "runtime")]
37use tokio_postgres::tls::MakeTlsConnect;
38use tokio_postgres::tls::{ChannelBinding, TlsConnect};
39
40#[cfg(test)]
41mod test;
42
/// An `openssl`-backed implementation of `MakeTlsConnect`.
///
/// Only compiled when the `runtime` Cargo feature is active (it is enabled by default).
#[cfg(feature = "runtime")]
#[derive(Clone)]
pub struct MakeTlsConnector {
    // Source of a fresh `ConnectConfiguration` for every `make_tls_connect` call.
    connector: SslConnector,
    // Hook run on each fresh configuration together with the target domain.
    config: Arc<dyn Fn(&mut ConnectConfiguration, &str) -> Result<(), ErrorStack> + Send + Sync>,
}
52
#[cfg(feature = "runtime")]
impl MakeTlsConnector {
    /// Builds a `MakeTlsConnector` around the given `SslConnector`.
    ///
    /// The per-connection callback starts out as a no-op that always succeeds.
    pub fn new(connector: SslConnector) -> MakeTlsConnector {
        Self {
            config: Arc::new(|_, _| Ok(())),
            connector,
        }
    }

    /// Installs a callback that applies per-connection configuration.
    ///
    /// The callback receives the `ConnectConfiguration` for the connection along with the
    /// domain name. It replaces any previously installed callback.
    pub fn set_callback<F>(&mut self, f: F)
    where
        F: Fn(&mut ConnectConfiguration, &str) -> Result<(), ErrorStack> + Send + Sync + 'static,
    {
        self.config = Arc::new(f);
    }
}
73
#[cfg(feature = "runtime")]
impl<S> MakeTlsConnect<S> for MakeTlsConnector
where
    S: AsyncRead + AsyncWrite + Debug + Send + Sync + 'static,
{
    type Stream = SslStream<S>;
    type TlsConnect = TlsConnector;
    type Error = ErrorStack;

    /// Produces a `TlsConnector` for `domain`.
    ///
    /// A fresh configuration is derived from the stored `SslConnector`, passed through the
    /// user callback, and then wrapped. An error from either step is returned unchanged.
    fn make_tls_connect(&mut self, domain: &str) -> Result<TlsConnector, ErrorStack> {
        self.connector.configure().and_then(|mut configuration| {
            // Let the user callback adjust the configuration before it is wrapped.
            (self.config)(&mut configuration, domain)?;
            Ok(TlsConnector::new(configuration, domain))
        })
    }
}
89
90/// A `TlsConnect` implementation using the `openssl` crate.
91pub struct TlsConnector {
92    ssl: ConnectConfiguration,
93    domain: String,
94}
95
96impl TlsConnector {
97    /// Creates a new connector configured to connect to the specified domain.
98    pub fn new(ssl: ConnectConfiguration, domain: &str) -> TlsConnector {
99        TlsConnector {
100            ssl,
101            domain: domain.to_string(),
102        }
103    }
104}
105
106impl<S> TlsConnect<S> for TlsConnector
107where
108    S: AsyncRead + AsyncWrite + Debug + 'static + Sync + Send,
109{
110    type Stream = SslStream<S>;
111    type Error = HandshakeError<S>;
112    type Future = TlsConnectFuture<S>;
113
114    fn connect(self, stream: S) -> TlsConnectFuture<S> {
115        TlsConnectFuture(self.ssl.connect_async(&self.domain, stream))
116    }
117}
118
119/// The future returned by `TlsConnector`.
120pub struct TlsConnectFuture<S>(ConnectAsync<S>);
121
122impl<S> Future for TlsConnectFuture<S>
123where
124    S: AsyncRead + AsyncWrite + Debug + 'static + Sync + Send,
125{
126    type Item = (SslStream<S>, ChannelBinding);
127    type Error = HandshakeError<S>;
128
129    fn poll(&mut self) -> Poll<(SslStream<S>, ChannelBinding), HandshakeError<S>> {
130        let stream = try_ready!(self.0.poll());
131
132        let channel_binding = match tls_server_end_point(stream.get_ref().ssl()) {
133            Some(buf) => ChannelBinding::tls_server_end_point(buf),
134            None => ChannelBinding::none(),
135        };
136
137        Ok(Async::Ready((stream, channel_binding)))
138    }
139}
140
141fn tls_server_end_point(ssl: &SslRef) -> Option<Vec<u8>> {
142    let cert = ssl.peer_certificate()?;
143    let algo_nid = cert.signature_algorithm().object().nid();
144    let signature_algorithms = algo_nid.signature_algorithms()?;
145    let md = match signature_algorithms.digest {
146        Nid::MD5 | Nid::SHA1 => MessageDigest::sha256(),
147        nid => MessageDigest::from_nid(nid)?,
148    };
149    cert.digest(md).ok().map(|b| b.to_vec())
150}