Skip to main content

s4_server/
tls.rs

1//! TLS termination helpers.
2//!
3//! Used by the binary's listener wiring. Kept as a separate library module so
4//! parsing logic (`load_tls_config`) is unit-testable and the `tokio-rustls`
5//! dependency is centralised here.
6
7use std::error::Error;
8use std::path::{Path, PathBuf};
9use std::sync::Arc;
10
11use arc_swap::ArcSwap;
12use tokio_rustls::TlsAcceptor;
13use tokio_rustls::rustls::ServerConfig;
14use tokio_rustls::rustls::pki_types::{CertificateDer, PrivateKeyDer};
15
16/// Loads PEM cert + key files into a rustls `ServerConfig` ready for
17/// `TlsAcceptor::from`. Supports PKCS#8 and RSA private keys via
18/// `rustls_pemfile::private_key`.
19///
20/// ALPN protocols default to `h2` then `http/1.1` — matching the
21/// `hyper_util::server::conn::auto::Builder` upstream so HTTP/2 is negotiated
22/// when the client offers it.
23pub fn load_tls_config(
24    cert_path: &Path,
25    key_path: &Path,
26) -> Result<Arc<ServerConfig>, Box<dyn Error + Send + Sync + 'static>> {
27    use std::fs::File;
28    use std::io::BufReader;
29
30    let mut cert_reader = BufReader::new(File::open(cert_path)?);
31    let certs: Vec<CertificateDer<'static>> =
32        rustls_pemfile::certs(&mut cert_reader).collect::<Result<Vec<_>, _>>()?;
33    if certs.is_empty() {
34        return Err(format!("no certificates found in {}", cert_path.display()).into());
35    }
36
37    let mut key_reader = BufReader::new(File::open(key_path)?);
38    let key: PrivateKeyDer<'static> = rustls_pemfile::private_key(&mut key_reader)?
39        .ok_or_else(|| format!("no private key found in {}", key_path.display()))?;
40
41    let mut config = ServerConfig::builder()
42        .with_no_client_auth()
43        .with_single_cert(certs, key)?;
44    config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
45    Ok(Arc::new(config))
46}
47
48/// v0.5 #32: load a TLS config restricted to TLS 1.3 only — used by
49/// `--compliance-mode strict` (regulated-industry deployments where
50/// TLS 1.2's CBC + RSA-key-exchange ciphers are off-limits). Failures
51/// to negotiate fall back to a clean handshake-failure rather than a
52/// downgrade.
53pub fn load_tls_config_tls13_only(
54    cert_path: &Path,
55    key_path: &Path,
56) -> Result<Arc<ServerConfig>, Box<dyn Error + Send + Sync + 'static>> {
57    use std::fs::File;
58    use std::io::BufReader;
59    use tokio_rustls::rustls::version::TLS13;
60
61    let mut cert_reader = BufReader::new(File::open(cert_path)?);
62    let certs: Vec<CertificateDer<'static>> =
63        rustls_pemfile::certs(&mut cert_reader).collect::<Result<Vec<_>, _>>()?;
64    if certs.is_empty() {
65        return Err(format!("no certificates found in {}", cert_path.display()).into());
66    }
67    let mut key_reader = BufReader::new(File::open(key_path)?);
68    let key: PrivateKeyDer<'static> = rustls_pemfile::private_key(&mut key_reader)?
69        .ok_or_else(|| format!("no private key found in {}", key_path.display()))?;
70    let mut config = ServerConfig::builder_with_protocol_versions(&[&TLS13])
71        .with_no_client_auth()
72        .with_single_cert(certs, key)?;
73    config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
74    Ok(Arc::new(config))
75}
76
77/// Installs the `ring` crypto provider as the process-wide default. rustls
78/// 0.23+ requires this before any `ServerConfig::builder()` call. Idempotent.
79pub fn install_default_crypto_provider() {
80    let _ = tokio_rustls::rustls::crypto::ring::default_provider().install_default();
81}
82
83/// Reloadable TLS state (v0.3 #10). Wraps an `ArcSwap<ServerConfig>` so the
84/// listener can swap the cert/key pair atomically on SIGHUP without dropping
85/// any in-flight connections. Construct via [`TlsState::load`] and pass
86/// `Arc<TlsState>` to both the accept loop and the SIGHUP handler.
87pub struct TlsState {
88    cfg: ArcSwap<ServerConfig>,
89    cert_path: PathBuf,
90    key_path: PathBuf,
91    /// v0.5 #32: when set, reload uses the TLS 1.3-only loader so a
92    /// hot-reload preserves the compliance posture instead of silently
93    /// dropping back to the default 1.2+1.3 acceptance.
94    tls13_only: bool,
95}
96
97impl TlsState {
98    /// Initial load — fails on parse error. Accepts TLS 1.2 + 1.3.
99    pub fn load(
100        cert_path: impl Into<PathBuf>,
101        key_path: impl Into<PathBuf>,
102    ) -> Result<Self, Box<dyn Error + Send + Sync + 'static>> {
103        let cert_path = cert_path.into();
104        let key_path = key_path.into();
105        let cfg = load_tls_config(&cert_path, &key_path)?;
106        Ok(Self {
107            cfg: ArcSwap::from(cfg),
108            cert_path,
109            key_path,
110            tls13_only: false,
111        })
112    }
113
114    /// v0.5 #32: TLS 1.3-only initial load — fails on parse error.
115    /// Reloads via [`Self::reload`] also use the 1.3-only loader, so a
116    /// SIGHUP hot-swap can't accidentally re-enable 1.2.
117    pub fn load_tls13_only(
118        cert_path: impl Into<PathBuf>,
119        key_path: impl Into<PathBuf>,
120    ) -> Result<Self, Box<dyn Error + Send + Sync + 'static>> {
121        let cert_path = cert_path.into();
122        let key_path = key_path.into();
123        let cfg = load_tls_config_tls13_only(&cert_path, &key_path)?;
124        Ok(Self {
125            cfg: ArcSwap::from(cfg),
126            cert_path,
127            key_path,
128            tls13_only: true,
129        })
130    }
131
132    /// Build a fresh `TlsAcceptor` from the current config. Cheap (one
133    /// atomic load + Arc clone). Call this once per accepted connection.
134    pub fn acceptor(&self) -> TlsAcceptor {
135        TlsAcceptor::from(self.cfg.load_full())
136    }
137
138    /// Re-read the cert + key from disk and atomically swap the active
139    /// config. Returns `Ok(())` on success and `Err(...)` if the new pair
140    /// failed to parse — the previous config remains in effect either way,
141    /// so a bad reload never causes a listener outage.
142    pub fn reload(&self) -> Result<(), Box<dyn Error + Send + Sync + 'static>> {
143        let new_cfg = if self.tls13_only {
144            load_tls_config_tls13_only(&self.cert_path, &self.key_path)?
145        } else {
146            load_tls_config(&self.cert_path, &self.key_path)?
147        };
148        self.cfg.store(new_cfg);
149        Ok(())
150    }
151}
152
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// Helper: write a self-signed cert+key pair to two NamedTempFiles using
    /// rcgen and return them so the test can pass paths to load_tls_config.
    fn write_self_signed_pair() -> (NamedTempFile, NamedTempFile) {
        let cert = rcgen::generate_simple_self_signed(vec!["localhost".into()]).unwrap();
        let mut cert_file = NamedTempFile::new().unwrap();
        cert_file.write_all(cert.cert.pem().as_bytes()).unwrap();
        cert_file.flush().unwrap();
        let mut key_file = NamedTempFile::new().unwrap();
        key_file
            .write_all(cert.key_pair.serialize_pem().as_bytes())
            .unwrap();
        key_file.flush().unwrap();
        (cert_file, key_file)
    }

    /// Helper: generate a fresh self-signed pair and copy it over the given
    /// stable paths, mimicking an operator rotating the files on disk.
    fn install_fresh_pair_at(cert_path: &std::path::Path, key_path: &std::path::Path) {
        let (cert, key) = write_self_signed_pair();
        std::fs::copy(cert.path(), cert_path).unwrap();
        std::fs::copy(key.path(), key_path).unwrap();
    }

    #[test]
    fn loads_pkcs8_cert_and_key() {
        install_default_crypto_provider();
        let (cert, key) = write_self_signed_pair();
        let cfg = load_tls_config(cert.path(), key.path()).expect("config should load");
        assert_eq!(
            cfg.alpn_protocols,
            vec![b"h2".to_vec(), b"http/1.1".to_vec()]
        );
    }

    #[test]
    fn rejects_missing_cert_file() {
        let (_cert, key) = write_self_signed_pair();
        let err = load_tls_config(std::path::Path::new("/nonexistent/cert.pem"), key.path())
            .expect_err("should fail on missing cert");
        // Check the error kind rather than the message text: the wording of
        // "file not found" differs between operating systems and locales.
        let io_err = err
            .downcast_ref::<std::io::Error>()
            .expect("a missing file should surface as std::io::Error");
        assert_eq!(io_err.kind(), std::io::ErrorKind::NotFound);
    }

    #[test]
    fn rejects_empty_cert_file() {
        let cert = NamedTempFile::new().unwrap();
        let (_, key) = write_self_signed_pair();
        let err =
            load_tls_config(cert.path(), key.path()).expect_err("should fail on empty cert PEM");
        assert!(err.to_string().contains("no certificates found"));
    }

    /// v0.3 #10: reload swaps the config atomically. Verified by grabbing
    /// the active `Arc<ServerConfig>` before and after the reload and
    /// confirming they are different allocations.
    #[test]
    fn reload_swaps_active_config() {
        install_default_crypto_provider();

        // Start with cert A at a stable path the TlsState owns.
        let dir = tempfile::tempdir().unwrap();
        let cert_path = dir.path().join("tls.crt");
        let key_path = dir.path().join("tls.key");
        install_fresh_pair_at(&cert_path, &key_path);

        let state = TlsState::load(&cert_path, &key_path).expect("initial load");
        let cfg_v1: Arc<ServerConfig> = state.cfg.load_full();

        // Swap the on-disk files to a freshly-generated cert B.
        install_fresh_pair_at(&cert_path, &key_path);

        state.reload().expect("reload should succeed");
        let cfg_v2: Arc<ServerConfig> = state.cfg.load_full();

        // Pointer identity is the cleanest check: ArcSwap::store replaces
        // the inner Arc, so cfg_v1 and cfg_v2 must NOT be the same Arc.
        assert!(!Arc::ptr_eq(&cfg_v1, &cfg_v2));
    }

    /// Reload failure (bad PEM) must not break the active config.
    #[test]
    fn reload_failure_keeps_previous_config() {
        install_default_crypto_provider();
        let dir = tempfile::tempdir().unwrap();
        let cert_path = dir.path().join("tls.crt");
        let key_path = dir.path().join("tls.key");
        install_fresh_pair_at(&cert_path, &key_path);

        let state = TlsState::load(&cert_path, &key_path).expect("initial load");
        let cfg_before: Arc<ServerConfig> = state.cfg.load_full();

        // Corrupt the cert file in-place — reload should fail.
        std::fs::write(&cert_path, b"not a pem certificate").unwrap();
        let err = state.reload().expect_err("reload should fail");
        assert!(err.to_string().contains("no certificates found"));

        let cfg_after: Arc<ServerConfig> = state.cfg.load_full();
        // Same Arc → previous config preserved.
        assert!(Arc::ptr_eq(&cfg_before, &cfg_after));
    }

    /// v0.5 #32: the TLS 1.3-only state loads, keeps its mode flag, and a
    /// reload installs a new config with the same ALPN list. The negotiated
    /// protocol version itself is not observable without a handshake, so
    /// this only covers the load/reload plumbing of the 1.3-only path.
    #[test]
    fn tls13_only_state_loads_and_reloads() {
        install_default_crypto_provider();
        let dir = tempfile::tempdir().unwrap();
        let cert_path = dir.path().join("tls.crt");
        let key_path = dir.path().join("tls.key");
        install_fresh_pair_at(&cert_path, &key_path);

        let state = TlsState::load_tls13_only(&cert_path, &key_path).expect("initial load");
        assert!(state.tls13_only);
        let cfg_before: Arc<ServerConfig> = state.cfg.load_full();

        install_fresh_pair_at(&cert_path, &key_path);
        state.reload().expect("reload should succeed");

        let cfg_after: Arc<ServerConfig> = state.cfg.load_full();
        assert!(state.tls13_only);
        assert!(!Arc::ptr_eq(&cfg_before, &cfg_after));
        assert_eq!(
            cfg_after.alpn_protocols,
            vec![b"h2".to_vec(), b"http/1.1".to_vec()]
        );
    }
}