1use std::error::Error;
8use std::path::{Path, PathBuf};
9use std::sync::Arc;
10
11use arc_swap::ArcSwap;
12use tokio_rustls::TlsAcceptor;
13use tokio_rustls::rustls::ServerConfig;
14use tokio_rustls::rustls::pki_types::{CertificateDer, PrivateKeyDer};
15
16pub fn load_tls_config(
24 cert_path: &Path,
25 key_path: &Path,
26) -> Result<Arc<ServerConfig>, Box<dyn Error + Send + Sync + 'static>> {
27 use std::fs::File;
28 use std::io::BufReader;
29
30 let mut cert_reader = BufReader::new(File::open(cert_path)?);
31 let certs: Vec<CertificateDer<'static>> =
32 rustls_pemfile::certs(&mut cert_reader).collect::<Result<Vec<_>, _>>()?;
33 if certs.is_empty() {
34 return Err(format!("no certificates found in {}", cert_path.display()).into());
35 }
36
37 let mut key_reader = BufReader::new(File::open(key_path)?);
38 let key: PrivateKeyDer<'static> = rustls_pemfile::private_key(&mut key_reader)?
39 .ok_or_else(|| format!("no private key found in {}", key_path.display()))?;
40
41 let mut config = ServerConfig::builder()
42 .with_no_client_auth()
43 .with_single_cert(certs, key)?;
44 config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
45 Ok(Arc::new(config))
46}
47
48pub fn load_tls_config_tls13_only(
54 cert_path: &Path,
55 key_path: &Path,
56) -> Result<Arc<ServerConfig>, Box<dyn Error + Send + Sync + 'static>> {
57 use std::fs::File;
58 use std::io::BufReader;
59 use tokio_rustls::rustls::version::TLS13;
60
61 let mut cert_reader = BufReader::new(File::open(cert_path)?);
62 let certs: Vec<CertificateDer<'static>> =
63 rustls_pemfile::certs(&mut cert_reader).collect::<Result<Vec<_>, _>>()?;
64 if certs.is_empty() {
65 return Err(format!("no certificates found in {}", cert_path.display()).into());
66 }
67 let mut key_reader = BufReader::new(File::open(key_path)?);
68 let key: PrivateKeyDer<'static> = rustls_pemfile::private_key(&mut key_reader)?
69 .ok_or_else(|| format!("no private key found in {}", key_path.display()))?;
70 let mut config = ServerConfig::builder_with_protocol_versions(&[&TLS13])
71 .with_no_client_auth()
72 .with_single_cert(certs, key)?;
73 config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
74 Ok(Arc::new(config))
75}
76
77pub fn install_default_crypto_provider() {
80 let _ = tokio_rustls::rustls::crypto::ring::default_provider().install_default();
81}
82
83pub struct TlsState {
88 cfg: ArcSwap<ServerConfig>,
89 cert_path: PathBuf,
90 key_path: PathBuf,
91 tls13_only: bool,
95}
96
97impl TlsState {
98 pub fn load(
100 cert_path: impl Into<PathBuf>,
101 key_path: impl Into<PathBuf>,
102 ) -> Result<Self, Box<dyn Error + Send + Sync + 'static>> {
103 let cert_path = cert_path.into();
104 let key_path = key_path.into();
105 let cfg = load_tls_config(&cert_path, &key_path)?;
106 Ok(Self {
107 cfg: ArcSwap::from(cfg),
108 cert_path,
109 key_path,
110 tls13_only: false,
111 })
112 }
113
114 pub fn load_tls13_only(
118 cert_path: impl Into<PathBuf>,
119 key_path: impl Into<PathBuf>,
120 ) -> Result<Self, Box<dyn Error + Send + Sync + 'static>> {
121 let cert_path = cert_path.into();
122 let key_path = key_path.into();
123 let cfg = load_tls_config_tls13_only(&cert_path, &key_path)?;
124 Ok(Self {
125 cfg: ArcSwap::from(cfg),
126 cert_path,
127 key_path,
128 tls13_only: true,
129 })
130 }
131
132 pub fn acceptor(&self) -> TlsAcceptor {
135 TlsAcceptor::from(self.cfg.load_full())
136 }
137
138 pub fn reload(&self) -> Result<(), Box<dyn Error + Send + Sync + 'static>> {
143 let new_cfg = if self.tls13_only {
144 load_tls_config_tls13_only(&self.cert_path, &self.key_path)?
145 } else {
146 load_tls_config(&self.cert_path, &self.key_path)?
147 };
148 self.cfg.store(new_cfg);
149 Ok(())
150 }
151}
152
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// Generates a self-signed certificate for `localhost` and writes the PEM
    /// certificate and PEM private key to two separate temporary files.
    fn write_self_signed_pair() -> (NamedTempFile, NamedTempFile) {
        let cert = rcgen::generate_simple_self_signed(vec!["localhost".into()]).unwrap();
        let mut cert_file = NamedTempFile::new().unwrap();
        cert_file.write_all(cert.cert.pem().as_bytes()).unwrap();
        cert_file.flush().unwrap();
        let mut key_file = NamedTempFile::new().unwrap();
        key_file
            .write_all(cert.key_pair.serialize_pem().as_bytes())
            .unwrap();
        key_file.flush().unwrap();
        (cert_file, key_file)
    }

    /// Copies a generated pair to `tls.crt` / `tls.key` inside `dir`
    /// (overwriting existing files) and returns the two destination paths.
    fn install_pair(dir: &Path, cert: &NamedTempFile, key: &NamedTempFile) -> (PathBuf, PathBuf) {
        let cert_path = dir.join("tls.crt");
        let key_path = dir.join("tls.key");
        std::fs::copy(cert.path(), &cert_path).unwrap();
        std::fs::copy(key.path(), &key_path).unwrap();
        (cert_path, key_path)
    }

    #[test]
    fn loads_pkcs8_cert_and_key() {
        install_default_crypto_provider();
        let (cert, key) = write_self_signed_pair();
        let cfg = load_tls_config(cert.path(), key.path()).expect("config should load");
        assert_eq!(
            cfg.alpn_protocols,
            vec![b"h2".to_vec(), b"http/1.1".to_vec()]
        );
    }

    #[test]
    fn rejects_missing_cert_file() {
        let (_cert, key) = write_self_signed_pair();
        let err = load_tls_config(std::path::Path::new("/nonexistent/cert.pem"), key.path())
            .expect_err("should fail on missing cert");
        // Check the error kind instead of the OS error text, which differs
        // between platforms.
        let io_err = err
            .downcast_ref::<std::io::Error>()
            .expect("a missing file should surface the underlying io::Error");
        assert_eq!(io_err.kind(), std::io::ErrorKind::NotFound);
    }

    #[test]
    fn rejects_empty_cert_file() {
        let cert = NamedTempFile::new().unwrap();
        let (_, key) = write_self_signed_pair();
        let err =
            load_tls_config(cert.path(), key.path()).expect_err("should fail on empty cert PEM");
        assert!(err.to_string().contains("no certificates found"));
    }

    #[test]
    fn reload_swaps_active_config() {
        install_default_crypto_provider();
        let (cert_a, key_a) = write_self_signed_pair();

        let dir = tempfile::tempdir().unwrap();
        let (cert_path, key_path) = install_pair(dir.path(), &cert_a, &key_a);

        let state = TlsState::load(&cert_path, &key_path).expect("initial load");
        let cfg_v1: Arc<ServerConfig> = state.cfg.load_full();

        // Replace both files on disk with a freshly generated pair.
        let (cert_b, key_b) = write_self_signed_pair();
        install_pair(dir.path(), &cert_b, &key_b);

        state.reload().expect("reload should succeed");
        let cfg_v2: Arc<ServerConfig> = state.cfg.load_full();

        assert!(!Arc::ptr_eq(&cfg_v1, &cfg_v2));
    }

    #[test]
    fn reload_failure_keeps_previous_config() {
        install_default_crypto_provider();
        let (cert_a, key_a) = write_self_signed_pair();
        let dir = tempfile::tempdir().unwrap();
        let (cert_path, key_path) = install_pair(dir.path(), &cert_a, &key_a);

        let state = TlsState::load(&cert_path, &key_path).expect("initial load");
        let cfg_before: Arc<ServerConfig> = state.cfg.load_full();

        // Corrupt the certificate so the next reload cannot succeed.
        std::fs::write(&cert_path, b"not a pem certificate").unwrap();
        let err = state.reload().expect_err("reload should fail");
        assert!(err.to_string().contains("no certificates found"));

        let cfg_after: Arc<ServerConfig> = state.cfg.load_full();
        assert!(Arc::ptr_eq(&cfg_before, &cfg_after));
    }

    #[test]
    fn tls13_only_state_loads_and_reloads() {
        install_default_crypto_provider();
        let (cert, key) = write_self_signed_pair();
        let dir = tempfile::tempdir().unwrap();
        let (cert_path, key_path) = install_pair(dir.path(), &cert, &key);

        let state = TlsState::load_tls13_only(&cert_path, &key_path).expect("initial load");
        assert!(state.tls13_only);
        let cfg_v1: Arc<ServerConfig> = state.cfg.load_full();
        assert_eq!(
            cfg_v1.alpn_protocols,
            vec![b"h2".to_vec(), b"http/1.1".to_vec()]
        );

        // Reloading must go through the TLS 1.3-only loader and still swap.
        state.reload().expect("reload should succeed");
        let cfg_v2: Arc<ServerConfig> = state.cfg.load_full();
        assert!(!Arc::ptr_eq(&cfg_v1, &cfg_v2));
        assert!(state.tls13_only);
    }
}