1use std::io::{Read, Write};
24use std::net::TcpStream;
25use std::path::PathBuf;
26use std::sync::{Arc, RwLock};
27use std::time::Duration;
28
29use rustls::{ClientConfig, ServerConfig, ServerConnection};
30use rustls_pki_types::ServerName;
31
32use crate::auth::AuthSubject;
33use crate::ctx::extract_mtls_subject;
34use crate::tls::{TlsConfigError, load_server_config, load_server_config_with_client_auth};
35
/// Hot-swappable rustls server configuration that can be rebuilt from files on disk.
///
/// The active [`ServerConfig`] lives behind an `RwLock<Arc<_>>`. `current()` hands out a
/// cheap `Arc` clone of it, and `reload()` rebuilds a fresh config from the stored paths
/// and swaps it into the slot.
///
/// The lock is itself wrapped in an `Arc`, so values produced by `Clone` share the same
/// slot. A reload through any clone is therefore visible to all of them.
#[derive(Debug, Clone)]
pub struct RotatingTlsConfig {
    // Shared, swappable slot holding the currently active server configuration.
    inner: Arc<RwLock<Arc<ServerConfig>>>,
    // Server certificate file; re-read on every reload.
    cert_path: PathBuf,
    // Server private-key file; re-read on every reload.
    key_path: PathBuf,
    // When `Some`, configs are built via `load_server_config_with_client_auth`
    // with this CA path; when `None`, via `load_server_config`.
    client_ca_path: Option<PathBuf>,
}
51
52impl RotatingTlsConfig {
53 pub fn load(
59 cert_path: PathBuf,
60 key_path: PathBuf,
61 client_ca_path: Option<PathBuf>,
62 ) -> Result<Self, TlsConfigError> {
63 let cfg = match &client_ca_path {
64 Some(ca) => load_server_config_with_client_auth(&cert_path, &key_path, ca)?,
65 None => load_server_config(&cert_path, &key_path)?,
66 };
67 Ok(Self {
68 inner: Arc::new(RwLock::new(cfg)),
69 cert_path,
70 key_path,
71 client_ca_path,
72 })
73 }
74
75 #[must_use]
79 pub fn current(&self) -> Arc<ServerConfig> {
80 match self.inner.read() {
81 Ok(g) => Arc::clone(&g),
82 Err(poisoned) => Arc::clone(&poisoned.into_inner()),
83 }
84 }
85
86 pub fn reload(&self) -> Result<(), TlsConfigError> {
93 let new_cfg = match &self.client_ca_path {
94 Some(ca) => load_server_config_with_client_auth(&self.cert_path, &self.key_path, ca)?,
95 None => load_server_config(&self.cert_path, &self.key_path)?,
96 };
97 let mut g = match self.inner.write() {
98 Ok(g) => g,
99 Err(poisoned) => poisoned.into_inner(),
100 };
101 *g = new_cfg;
102 Ok(())
103 }
104}
105
106pub fn serve_tls_handshake(
118 cfg: Arc<ServerConfig>,
119 mut stream: TcpStream,
120 handshake_timeout: Duration,
121) -> std::io::Result<(TcpStream, ServerConnection, Option<AuthSubject>)> {
122 stream.set_read_timeout(Some(handshake_timeout))?;
123 stream.set_write_timeout(Some(handshake_timeout))?;
124 let mut conn = ServerConnection::new(cfg).map_err(|e| {
125 std::io::Error::new(std::io::ErrorKind::InvalidData, format!("rustls: {e}"))
126 })?;
127
128 while conn.is_handshaking() {
130 if conn.wants_write() {
131 let mut sink = TcpWriter(&mut stream);
132 conn.write_tls(&mut sink)?;
133 }
134 if conn.wants_read() {
135 let mut src = TcpReader(&mut stream);
136 let n = conn.read_tls(&mut src)?;
137 if n == 0 {
138 return Err(std::io::Error::new(
139 std::io::ErrorKind::UnexpectedEof,
140 "tls handshake eof",
141 ));
142 }
143 conn.process_new_packets().map_err(|e| {
144 std::io::Error::new(std::io::ErrorKind::InvalidData, format!("rustls: {e}"))
145 })?;
146 }
147 }
148 while conn.wants_write() {
150 let mut sink = TcpWriter(&mut stream);
151 conn.write_tls(&mut sink)?;
152 }
153
154 let mtls_subject = extract_mtls_subject(&conn);
155 Ok((stream, conn, mtls_subject))
156}
157
/// Borrowing adapter that exposes a [`TcpStream`] as a [`Read`] source for
/// `ServerConnection::read_tls`. Every call is forwarded unchanged to the stream, with
/// no buffering.
struct TcpReader<'a>(&'a mut TcpStream);

impl Read for TcpReader<'_> {
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        // Reborrow the wrapped stream and dispatch to `TcpStream`'s own `Read` impl.
        Read::read(&mut *self.0, buf)
    }
}
165
/// Borrowing adapter that exposes a [`TcpStream`] as a [`Write`] sink for
/// `ServerConnection::write_tls`. Both `write` and `flush` forward directly to the
/// stream.
struct TcpWriter<'a>(&'a mut TcpStream);

impl Write for TcpWriter<'_> {
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        Write::write(&mut *self.0, buf)
    }

    fn flush(&mut self) -> std::io::Result<()> {
        Write::flush(&mut *self.0)
    }
}
176
177pub fn build_client_tls_connector(
191 ca_pem_path: Option<&std::path::Path>,
192 client_cert_pem_path: Option<&std::path::Path>,
193 client_key_pem_path: Option<&std::path::Path>,
194) -> Result<Arc<ClientConfig>, TlsConfigError> {
195 use crate::tls::{read_certs, read_private_key};
196
197 let mut roots = rustls::RootCertStore::empty();
198 if let Some(ca) = ca_pem_path {
199 for c in read_certs(ca)? {
200 roots
201 .add(c)
202 .map_err(|e| TlsConfigError::Rustls(format!("ca add: {e}")))?;
203 }
204 }
205 let provider = rustls::crypto::ring::default_provider();
206 let builder = ClientConfig::builder_with_provider(Arc::new(provider))
207 .with_safe_default_protocol_versions()
208 .map_err(|e| TlsConfigError::Rustls(format!("{e}")))?
209 .with_root_certificates(roots);
210
211 let cfg = match (client_cert_pem_path, client_key_pem_path) {
212 (Some(c), Some(k)) => {
213 let certs = read_certs(c)?;
214 let key = read_private_key(k)?;
215 builder
216 .with_client_auth_cert(certs, key)
217 .map_err(|e| TlsConfigError::Rustls(format!("client auth: {e}")))?
218 }
219 (None, None) => builder.with_no_client_auth(),
220 _ => {
221 return Err(TlsConfigError::Rustls(
222 "client cert and key must be set together".into(),
223 ));
224 }
225 };
226 Ok(Arc::new(cfg))
227}
228
229pub fn parse_server_name(host: &str) -> Result<ServerName<'static>, TlsConfigError> {
235 ServerName::try_from(host.to_string())
236 .map_err(|e| TlsConfigError::Rustls(format!("invalid server name '{host}': {e}")))
237}
238
// Unit tests for the rotating server config, the client connector builder, and
// server-name parsing. Certificates are generated on the fly with `rcgen`.
#[cfg(test)]
#[allow(clippy::expect_used, clippy::unwrap_used)]
mod tests {
    use super::*;
    // `Write` (needed for `write_all`) is probably already reachable through `super::*`,
    // since the parent imports `std::io::{Read, Write}`. The anonymous import keeps it
    // available regardless, hence the `unused_imports` allow.
    #[allow(unused_imports)]
    use std::io::Write as _;

    /// Writes `body` to `<temp_dir>/zd-bridge-conn-<name>-<pid>/<name>` and returns the path.
    /// The directory name includes the process id so concurrent test runs do not collide.
    /// Directory-creation errors are ignored; file creation and writing panic on failure.
    fn write_temp(name: &str, body: &[u8]) -> PathBuf {
        let dir =
            std::env::temp_dir().join(format!("zd-bridge-conn-{}-{}", name, std::process::id()));
        let _ = std::fs::create_dir_all(&dir);
        let p = dir.join(name);
        let mut f = std::fs::File::create(&p).unwrap();
        f.write_all(body).unwrap();
        p
    }

    /// Generates a fresh self-signed certificate for `localhost`.
    /// Returns `(certificate PEM, private-key PEM)`.
    fn gen_self_signed() -> (String, String) {
        let ck = rcgen::generate_simple_self_signed(vec!["localhost".to_string()]).unwrap();
        (ck.cert.pem(), ck.key_pair.serialize_pem())
    }

    #[test]
    fn rotating_config_load_and_current_works() {
        let (cert, key) = gen_self_signed();
        let c = write_temp("rcert.pem", cert.as_bytes());
        let k = write_temp("rkey.pem", key.as_bytes());
        let r = RotatingTlsConfig::load(c, k, None).expect("load");
        let cur1 = r.current();
        let cur2 = r.current();
        // Without a reload in between, both calls must hand out the same underlying Arc.
        assert!(Arc::ptr_eq(&cur1, &cur2));
    }

    #[test]
    fn rotating_config_reload_swaps_inner_arc() {
        let (cert1, key1) = gen_self_signed();
        let c = write_temp("rcert2.pem", cert1.as_bytes());
        let k = write_temp("rkey2.pem", key1.as_bytes());
        let r = RotatingTlsConfig::load(c.clone(), k.clone(), None).expect("load");
        let before = r.current();
        let (cert2, key2) = gen_self_signed();
        // Rotate the files on disk in place, then ask the config to pick them up.
        std::fs::write(&c, cert2.as_bytes()).unwrap();
        std::fs::write(&k, key2.as_bytes()).unwrap();
        r.reload().expect("reload");
        let after = r.current();
        // A successful reload installs a new Arc. This checks the swap happened, not
        // that the new certificate contents were actually used.
        assert!(!Arc::ptr_eq(&before, &after));
    }

    #[test]
    fn rotating_config_reload_with_bad_path_keeps_old() {
        // NOTE(review): despite the name, the path stays valid — the test overwrites the
        // certificate file with unparseable PEM contents.
        let (cert, key) = gen_self_signed();
        let c = write_temp("rcert3.pem", cert.as_bytes());
        let k = write_temp("rkey3.pem", key.as_bytes());
        let r = RotatingTlsConfig::load(c.clone(), k.clone(), None).expect("load");
        let before = r.current();
        std::fs::write(&c, b"-----BEGIN GARBAGE-----\n-----END GARBAGE-----\n").unwrap();
        // The loader finds no certificate block in the garbage PEM.
        let err = r.reload().unwrap_err();
        assert!(matches!(err, TlsConfigError::NoCertificateInPem));
        let after = r.current();
        // A failed reload must leave the previously active config in place.
        assert!(Arc::ptr_eq(&before, &after));
    }

    #[test]
    fn parse_server_name_accepts_dns_hostname() {
        let _ = parse_server_name("example.com").expect("dns");
    }

    #[test]
    fn parse_server_name_accepts_ip() {
        let _ = parse_server_name("127.0.0.1").expect("ip");
    }

    #[test]
    fn build_client_tls_connector_no_auth_succeeds() {
        // CA only, no client certificate: the config should build.
        let (cert, _key) = gen_self_signed();
        let ca = write_temp("ca.pem", cert.as_bytes());
        let cfg = build_client_tls_connector(Some(&ca), None, None).expect("client cfg");
        assert!(Arc::strong_count(&cfg) >= 1);
    }

    #[test]
    fn build_client_tls_connector_with_mtls_succeeds() {
        // The same self-signed cert serves as both CA and client certificate.
        let (cert, key) = gen_self_signed();
        let cap = write_temp("ca2.pem", cert.as_bytes());
        let cp = write_temp("cli.pem", cert.as_bytes());
        let kp = write_temp("clikey.pem", key.as_bytes());
        let cfg = build_client_tls_connector(Some(&cap), Some(&cp), Some(&kp)).expect("mtls");
        assert!(Arc::strong_count(&cfg) >= 1);
    }

    #[test]
    fn build_client_tls_connector_partial_auth_rejected() {
        // A client certificate without its key must be rejected.
        let (cert, _key) = gen_self_signed();
        let cap = write_temp("ca3.pem", cert.as_bytes());
        let cp = write_temp("cli2.pem", cert.as_bytes());
        let err = build_client_tls_connector(Some(&cap), Some(&cp), None).unwrap_err();
        assert!(matches!(err, TlsConfigError::Rustls(_)));
    }
}