1#![allow(dead_code)]
2
3use std::path::PathBuf;
4
5use crate::error::Error;
6use crate::net::socket::WithSocket;
7use crate::net::Socket;
8
9#[cfg(feature = "_tls-rustls")]
10mod tls_rustls;
11
12#[cfg(feature = "_tls-native-tls")]
13mod tls_native_tls;
14
15mod util;
16
/// Where certificate or key material comes from: the bytes themselves, or a path to read them
/// from.
///
/// The same type is used for root certificates, client certificates and client keys in
/// [`TlsConfig`].
#[derive(Clone, Debug)]
pub enum CertificateInput {
    /// The raw bytes, already held in memory.
    Inline(Vec<u8>),
    /// Path of a file whose contents are the bytes; the file is only read when the data is
    /// requested.
    File(PathBuf),
}
25
26impl From<String> for CertificateInput {
27 fn from(value: String) -> Self {
28 let trimmed = value.trim();
30
31 if trimmed.starts_with("-----BEGIN") && trimmed.ends_with("-----") {
34 CertificateInput::Inline(value.as_bytes().to_vec())
35 } else {
36 CertificateInput::File(PathBuf::from(value))
37 }
38 }
39}
40
41impl CertificateInput {
42 async fn data(&self) -> Result<Vec<u8>, std::io::Error> {
43 use crate::fs;
44 match self {
45 CertificateInput::Inline(v) => Ok(v.clone()),
46 CertificateInput::File(path) => fs::read(path).await,
47 }
48 }
49}
50
51impl std::fmt::Display for CertificateInput {
52 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
53 match self {
54 CertificateInput::Inline(v) => write!(f, "{}", String::from_utf8_lossy(v.as_slice())),
55 CertificateInput::File(path) => write!(f, "file: {}", path.display()),
56 }
57 }
58}
59
/// Borrowed settings handed to [`handshake`] for a single TLS upgrade.
///
/// All references share the lifetime `'a`; the struct owns nothing. How each setting is honoured
/// is decided by the backend modules (`tls_native_tls` / `tls_rustls`), which are not part of
/// this file.
pub struct TlsConfig<'a> {
    /// NOTE(review): presumably disables certificate validation entirely; confirm in the
    /// backend modules.
    pub accept_invalid_certs: bool,
    /// NOTE(review): presumably skips checking the certificate against `hostname`; confirm in
    /// the backend modules.
    pub accept_invalid_hostnames: bool,
    /// Host name of the peer being connected to.
    pub hostname: &'a str,
    /// Optional root certificate. Despite the `_path` suffix this may be inline data, since it
    /// is a [`CertificateInput`].
    pub root_cert_path: Option<&'a CertificateInput>,
    /// Optional client certificate (inline data or file path).
    pub client_cert_path: Option<&'a CertificateInput>,
    /// Optional client key (inline data or file path).
    pub client_key_path: Option<&'a CertificateInput>,
}
68
69pub async fn handshake<S, Ws>(
70 socket: S,
71 config: TlsConfig<'_>,
72 with_socket: Ws,
73) -> crate::Result<Ws::Output>
74where
75 S: Socket,
76 Ws: WithSocket,
77{
78 #[cfg(feature = "_tls-native-tls")]
79 return Ok(with_socket
80 .with_socket(tls_native_tls::handshake(socket, config).await?)
81 .await);
82
83 #[cfg(all(feature = "_tls-rustls", not(feature = "_tls-native-tls")))]
84 return Ok(with_socket
85 .with_socket(tls_rustls::handshake(socket, config).await?)
86 .await);
87
88 #[cfg(not(any(feature = "_tls-native-tls", feature = "_tls-rustls")))]
89 {
90 drop((socket, config, with_socket));
91 panic!("one of the `runtime-*-native-tls` or `runtime-*-rustls` features must be enabled")
92 }
93}
94
/// Reports whether this build includes a TLS backend.
///
/// The answer is fixed at compile time: it is `true` when the `_tls-native-tls` feature, the
/// `_tls-rustls` feature, or both are enabled.
pub fn available() -> bool {
    const NATIVE_TLS: bool = cfg!(feature = "_tls-native-tls");
    const RUSTLS: bool = cfg!(feature = "_tls-rustls");

    NATIVE_TLS || RUSTLS
}
98
99pub fn error_if_unavailable() -> crate::Result<()> {
100 if !available() {
101 return Err(Error::tls(
102 "TLS upgrade required by connect options \
103 but SQLx was built without TLS support enabled",
104 ));
105 }
106
107 Ok(())
108}