1use std::error::Error;
8use std::path::{Path, PathBuf};
9use std::sync::Arc;
10
11use arc_swap::ArcSwap;
12use tokio_rustls::TlsAcceptor;
13use tokio_rustls::rustls::ServerConfig;
14use tokio_rustls::rustls::pki_types::{CertificateDer, PrivateKeyDer};
15
16pub fn load_tls_config(
24 cert_path: &Path,
25 key_path: &Path,
26) -> Result<Arc<ServerConfig>, Box<dyn Error + Send + Sync + 'static>> {
27 use std::fs::File;
28 use std::io::BufReader;
29
30 let mut cert_reader = BufReader::new(File::open(cert_path)?);
31 let certs: Vec<CertificateDer<'static>> =
32 rustls_pemfile::certs(&mut cert_reader).collect::<Result<Vec<_>, _>>()?;
33 if certs.is_empty() {
34 return Err(format!("no certificates found in {}", cert_path.display()).into());
35 }
36
37 let mut key_reader = BufReader::new(File::open(key_path)?);
38 let key: PrivateKeyDer<'static> = rustls_pemfile::private_key(&mut key_reader)?
39 .ok_or_else(|| format!("no private key found in {}", key_path.display()))?;
40
41 let mut config = ServerConfig::builder()
42 .with_no_client_auth()
43 .with_single_cert(certs, key)?;
44 config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
45 Ok(Arc::new(config))
46}
47
48pub fn install_default_crypto_provider() {
51 let _ = tokio_rustls::rustls::crypto::ring::default_provider().install_default();
52}
53
54pub struct TlsState {
59 cfg: ArcSwap<ServerConfig>,
60 cert_path: PathBuf,
61 key_path: PathBuf,
62}
63
64impl TlsState {
65 pub fn load(
67 cert_path: impl Into<PathBuf>,
68 key_path: impl Into<PathBuf>,
69 ) -> Result<Self, Box<dyn Error + Send + Sync + 'static>> {
70 let cert_path = cert_path.into();
71 let key_path = key_path.into();
72 let cfg = load_tls_config(&cert_path, &key_path)?;
73 Ok(Self {
74 cfg: ArcSwap::from(cfg),
75 cert_path,
76 key_path,
77 })
78 }
79
80 pub fn acceptor(&self) -> TlsAcceptor {
83 TlsAcceptor::from(self.cfg.load_full())
84 }
85
86 pub fn reload(&self) -> Result<(), Box<dyn Error + Send + Sync + 'static>> {
91 let new_cfg = load_tls_config(&self.cert_path, &self.key_path)?;
92 self.cfg.store(new_cfg);
93 Ok(())
94 }
95}
96
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// Writes a freshly generated self-signed `localhost` certificate and its
    /// private key, PEM-encoded, to two temp files.
    ///
    /// The files are removed when the returned handles drop, so callers must
    /// keep them alive for as long as the paths are read.
    fn write_self_signed_pair() -> (NamedTempFile, NamedTempFile) {
        let cert = rcgen::generate_simple_self_signed(vec!["localhost".into()]).unwrap();
        let mut cert_file = NamedTempFile::new().unwrap();
        cert_file.write_all(cert.cert.pem().as_bytes()).unwrap();
        cert_file.flush().unwrap();
        let mut key_file = NamedTempFile::new().unwrap();
        key_file
            .write_all(cert.key_pair.serialize_pem().as_bytes())
            .unwrap();
        key_file.flush().unwrap();
        (cert_file, key_file)
    }

    #[test]
    fn loads_pkcs8_cert_and_key() {
        install_default_crypto_provider();
        let (cert, key) = write_self_signed_pair();
        let cfg = load_tls_config(cert.path(), key.path()).expect("config should load");
        assert_eq!(
            cfg.alpn_protocols,
            vec![b"h2".to_vec(), b"http/1.1".to_vec()]
        );
    }

    #[test]
    fn rejects_missing_cert_file() {
        let (_cert, key) = write_self_signed_pair();
        let err = load_tls_config(std::path::Path::new("/nonexistent/cert.pem"), key.path())
            .expect_err("should fail on missing cert");
        // Check the error kind rather than the OS-specific, locale-dependent
        // message text.
        let io_err = err
            .downcast_ref::<std::io::Error>()
            .expect("a missing cert file should surface as std::io::Error");
        assert_eq!(io_err.kind(), std::io::ErrorKind::NotFound);
    }

    #[test]
    fn rejects_empty_cert_file() {
        let cert = NamedTempFile::new().unwrap();
        let (_, key) = write_self_signed_pair();
        let err =
            load_tls_config(cert.path(), key.path()).expect_err("should fail on empty cert PEM");
        assert!(err.to_string().contains("no certificates found"));
    }

    #[test]
    fn rejects_key_file_without_private_key() {
        let (cert, _key) = write_self_signed_pair();
        // Point the key path at the certificate PEM. It parses fine but holds
        // no private key, so this exercises the `None` branch of `private_key`.
        let err = load_tls_config(cert.path(), cert.path())
            .expect_err("should fail when the key PEM has no private key");
        assert!(err.to_string().contains("no private key found"));
    }

    #[test]
    fn reload_swaps_active_config() {
        install_default_crypto_provider();
        let (cert_a, key_a) = write_self_signed_pair();

        // Use stable paths so the second pair can overwrite them in place and
        // `reload` re-reads the same locations.
        let dir = tempfile::tempdir().unwrap();
        let cert_path = dir.path().join("tls.crt");
        let key_path = dir.path().join("tls.key");
        std::fs::copy(cert_a.path(), &cert_path).unwrap();
        std::fs::copy(key_a.path(), &key_path).unwrap();

        let state = TlsState::load(&cert_path, &key_path).expect("initial load");
        let cfg_v1: Arc<ServerConfig> = state.cfg.load_full();

        // Rotate: overwrite both files with a brand-new pair.
        let (cert_b, key_b) = write_self_signed_pair();
        std::fs::copy(cert_b.path(), &cert_path).unwrap();
        std::fs::copy(key_b.path(), &key_path).unwrap();

        state.reload().expect("reload should succeed");
        let cfg_v2: Arc<ServerConfig> = state.cfg.load_full();

        // A different Arc proves the store happened. It does not by itself
        // prove which certificate the new config serves.
        assert!(!Arc::ptr_eq(&cfg_v1, &cfg_v2));
    }

    #[test]
    fn reload_failure_keeps_previous_config() {
        install_default_crypto_provider();
        let (cert_a, key_a) = write_self_signed_pair();
        let dir = tempfile::tempdir().unwrap();
        let cert_path = dir.path().join("tls.crt");
        let key_path = dir.path().join("tls.key");
        std::fs::copy(cert_a.path(), &cert_path).unwrap();
        std::fs::copy(key_a.path(), &key_path).unwrap();

        let state = TlsState::load(&cert_path, &key_path).expect("initial load");
        let cfg_before: Arc<ServerConfig> = state.cfg.load_full();

        // Corrupt the cert in place. With no PEM blocks present, loading
        // reports "no certificates found".
        std::fs::write(&cert_path, b"not a pem certificate").unwrap();
        let err = state.reload().expect_err("reload should fail");
        assert!(err.to_string().contains("no certificates found"));

        // The failed reload must not have replaced the active config.
        let cfg_after: Arc<ServerConfig> = state.cfg.load_full();
        assert!(Arc::ptr_eq(&cfg_before, &cfg_after));
    }
}