1use std::net::SocketAddr;
2use std::sync::Arc;
3use tokio::net::TcpStream;
4use tokio_rustls::rustls::pki_types::ServerName;
5use tokio_rustls::rustls::{ClientConfig, RootCertStore};
6use tokio_rustls::TlsConnector;
7
8use crate::error::Error;
9
/// Client-side TLS stream layered over a Tokio [`TcpStream`]; this is the
/// stream type returned by [`tls_connect`].
pub type TlsStream = tokio_rustls::client::TlsStream<TcpStream>;
11
12pub fn make_tls_config(
15 extra_ca_der: Option<&[u8]>,
16) -> std::result::Result<Arc<ClientConfig>, Error> {
17 let _ = tokio_rustls::rustls::crypto::ring::default_provider().install_default();
20
21 let mut root_store = RootCertStore::empty();
22 root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
23 if let Some(ca_bytes) = extra_ca_der {
24 let ca_cert = tokio_rustls::rustls::pki_types::CertificateDer::from(ca_bytes.to_vec());
25 root_store.add(ca_cert).map_err(|e| {
26 Error::Io(std::io::Error::new(
27 std::io::ErrorKind::InvalidData,
28 format!("invalid CA cert: {e}"),
29 ))
30 })?;
31 }
32 let config = ClientConfig::builder()
33 .with_root_certificates(root_store)
34 .with_no_client_auth();
35 Ok(Arc::new(config))
36}
37
38pub async fn tls_connect(
40 addr: SocketAddr,
41 server_name: &str,
42 extra_ca_der: Option<&[u8]>,
43) -> std::result::Result<TlsStream, Error> {
44 let config = make_tls_config(extra_ca_der)?;
45 let connector = TlsConnector::from(config);
46 let tcp = TcpStream::connect(addr).await.map_err(Error::Io)?;
47 let server_name = ServerName::try_from(server_name.to_string()).map_err(|e| {
48 Error::Io(std::io::Error::new(
49 std::io::ErrorKind::InvalidInput,
50 format!("invalid server name: {e}"),
51 ))
52 })?;
53 connector.connect(server_name, tcp).await.map_err(Error::Io)
54}
55
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration (no extra CA) must build successfully.
    /// Note: the roots come from the bundled `webpki_roots` set.
    #[test]
    fn tls_config_builds_with_system_roots() {
        let _cfg = make_tls_config(None).unwrap();
    }

    /// Bytes that are not a DER certificate must be reported as an error
    /// instead of being silently accepted into the trust store.
    #[test]
    fn tls_config_rejects_invalid_extra_ca() {
        assert!(make_tls_config(Some(b"not a certificate")).is_err());
    }

    /// A plain DNS host name parses, using the same owned-`String` conversion
    /// that `tls_connect` performs. `ServerName` comes from the parent module's
    /// `tokio_rustls` import, so no direct `rustls` dependency is needed.
    #[test]
    fn tls_config_server_name_parses() {
        let name = ServerName::try_from("plc.example.com".to_string());
        assert!(name.is_ok());
    }

    /// Names containing characters that are illegal in a host name are rejected.
    #[test]
    fn tls_config_invalid_server_name_is_rejected() {
        let name = ServerName::try_from("not a valid name!".to_string());
        assert!(name.is_err());
    }
}