tokio_postgres_fork/
connect_tls.rs

1use crate::config::{SslMode, SslNegotiation};
2use crate::maybe_tls_stream::MaybeTlsStream;
3use crate::tls::private::ForcePrivateApi;
4use crate::tls::TlsConnect;
5use crate::Error;
6use bytes::BytesMut;
7use postgres_protocol::message::frontend;
8use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
9
10pub async fn connect_tls<S, T>(
11    mut stream: S,
12    mode: SslMode,
13    negotiation: SslNegotiation,
14    tls: T,
15    has_hostname: bool,
16) -> Result<MaybeTlsStream<S, T::Stream>, Error>
17where
18    S: AsyncRead + AsyncWrite + Unpin,
19    T: TlsConnect<S>,
20{
21    match mode {
22        SslMode::Disable => return Ok(MaybeTlsStream::Raw(stream)),
23        SslMode::Prefer if !tls.can_connect(ForcePrivateApi) => {
24            return Ok(MaybeTlsStream::Raw(stream))
25        }
26        SslMode::Prefer if negotiation == SslNegotiation::Direct => {
27            return Err(Error::tls("weak sslmode \"prefer\" may not be used with sslnegotiation=direct (use \"require\", \"verify-ca\", or \"verify-full\")".into()))
28        }
29        SslMode::Prefer | SslMode::Require => {}
30    }
31
32    if negotiation == SslNegotiation::Postgres {
33        let mut buf = BytesMut::new();
34        frontend::ssl_request(&mut buf);
35        stream.write_all(&buf).await.map_err(Error::io)?;
36
37        let mut buf = [0];
38        stream.read_exact(&mut buf).await.map_err(Error::io)?;
39
40        if buf[0] != b'S' {
41            if SslMode::Require == mode {
42                return Err(Error::tls("server does not support TLS".into()));
43            } else {
44                return Ok(MaybeTlsStream::Raw(stream));
45            }
46        }
47    }
48
49    if !has_hostname {
50        return Err(Error::tls("no hostname provided for TLS handshake".into()));
51    }
52
53    let stream = tls
54        .connect(stream)
55        .await
56        .map_err(|e| Error::tls(e.into()))?;
57
58    Ok(MaybeTlsStream::Tls(stream))
59}