sqlx_postgres/connection/
tls.rs1use crate::error::Error;
2use crate::net::tls::{self, TlsConfig};
3use crate::net::{Socket, SocketIntoBox, WithSocket};
4
5use crate::message::SslRequest;
6use crate::{PgConnectOptions, PgSslMode};
7
8pub struct MaybeUpgradeTls<'a>(pub &'a PgConnectOptions);
9
10impl<'a> WithSocket for MaybeUpgradeTls<'a> {
11 type Output = crate::Result<Box<dyn Socket>>;
12
13 async fn with_socket<S: Socket>(self, socket: S) -> Self::Output {
14 maybe_upgrade(socket, self.0).await
15 }
16}
17
18async fn maybe_upgrade<S: Socket>(
19 mut socket: S,
20 options: &PgConnectOptions,
21) -> Result<Box<dyn Socket>, Error> {
22 match options.ssl_mode {
24 PgSslMode::Allow | PgSslMode::Disable => return Ok(Box::new(socket)),
26
27 PgSslMode::Prefer => {
28 if !tls::available() {
29 return Ok(Box::new(socket));
30 }
31
32 if !request_upgrade(&mut socket, options).await? {
34 return Ok(Box::new(socket));
35 }
36 }
37
38 PgSslMode::Require | PgSslMode::VerifyFull | PgSslMode::VerifyCa => {
39 tls::error_if_unavailable()?;
40
41 if !request_upgrade(&mut socket, options).await? {
42 return Err(Error::Tls("server does not support TLS".into()));
44 }
45 }
46 }
47
48 let accept_invalid_certs = !matches!(
49 options.ssl_mode,
50 PgSslMode::VerifyCa | PgSslMode::VerifyFull
51 );
52 let accept_invalid_hostnames = !matches!(options.ssl_mode, PgSslMode::VerifyFull);
53
54 let config = TlsConfig {
55 accept_invalid_certs,
56 accept_invalid_hostnames,
57 hostname: &options.host,
58 root_cert_path: options.ssl_root_cert.as_ref(),
59 client_cert_path: options.ssl_client_cert.as_ref(),
60 client_key_path: options.ssl_client_key.as_ref(),
61 };
62
63 tls::handshake(socket, config, SocketIntoBox).await
64}
65
66async fn request_upgrade(
67 socket: &mut impl Socket,
68 _options: &PgConnectOptions,
69) -> Result<bool, Error> {
70 socket.write(SslRequest::BYTES).await?;
76
77 let mut response = [0u8];
81
82 socket.read(&mut &mut response[..]).await?;
83
84 match response[0] {
85 b'S' => {
86 Ok(true)
88 }
89
90 b'N' => {
91 Ok(false)
93 }
94
95 other => Err(err_protocol!(
96 "unexpected response from SSLRequest: 0x{:02x}",
97 other
98 )),
99 }
100}