1use std::collections::HashMap;
49use std::fs::File;
50use std::io::BufReader;
51use std::path::Path;
52use std::sync::Arc;
53use std::time::{Duration, Instant};
54
55use parking_lot::RwLock;
56use rustls::client::ClientConfig;
57use rustls::pki_types::CertificateDer;
58use rustls::server::{ClientHello, ResolvesServerCert};
59use rustls::sign::CertifiedKey;
60use rustls::{RootCertStore, ServerConfig};
61use tracing::{debug, error, info, trace, warn};
62
63use sentinel_config::{TlsConfig, UpstreamTlsConfig};
64
/// Errors raised while loading TLS material or assembling rustls configs.
///
/// Every variant carries a free-form detail string (usually prefixed with the
/// offending file path); `Display` renders it after a fixed description.
#[derive(Debug)]
pub enum TlsError {
    /// A certificate or CA bundle could not be opened, parsed, or was empty.
    CertificateLoad(String),
    /// A private key file could not be opened, parsed, or was empty.
    KeyLoad(String),
    /// The TLS settings are inconsistent or a rustls object could not be built.
    ConfigBuild(String),
    /// The certificate chain and private key could not be combined.
    CertKeyMismatch(String),
    /// A certificate was rejected when added to a trust store.
    InvalidCertificate(String),
    /// An OCSP response could not be obtained.
    OcspFetch(String),
}

impl std::fmt::Display for TlsError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Separate the fixed, per-variant description from the payload so the
        // actual formatting happens in exactly one place.
        let (description, detail) = match self {
            Self::CertificateLoad(detail) => ("Failed to load certificate", detail),
            Self::KeyLoad(detail) => ("Failed to load private key", detail),
            Self::ConfigBuild(detail) => ("Failed to build TLS config", detail),
            Self::CertKeyMismatch(detail) => ("Certificate/key mismatch", detail),
            Self::InvalidCertificate(detail) => ("Invalid certificate", detail),
            Self::OcspFetch(detail) => ("Failed to fetch OCSP response", detail),
        };
        write!(f, "{}: {}", description, detail)
    }
}

impl std::error::Error for TlsError {}
96
97#[derive(Debug)]
105pub struct SniResolver {
106 default_cert: Arc<CertifiedKey>,
108 sni_certs: HashMap<String, Arc<CertifiedKey>>,
111 wildcard_certs: HashMap<String, Arc<CertifiedKey>>,
113}
114
115impl SniResolver {
116 pub fn from_config(config: &TlsConfig) -> Result<Self, TlsError> {
118 let default_cert = load_certified_key(&config.cert_file, &config.key_file)?;
120
121 info!(
122 cert_file = %config.cert_file.display(),
123 "Loaded default TLS certificate"
124 );
125
126 let mut sni_certs = HashMap::new();
127 let mut wildcard_certs = HashMap::new();
128
129 for sni_config in &config.additional_certs {
131 let cert = load_certified_key(&sni_config.cert_file, &sni_config.key_file)?;
132 let cert = Arc::new(cert);
133
134 for hostname in &sni_config.hostnames {
135 let hostname_lower = hostname.to_lowercase();
136
137 if hostname_lower.starts_with("*.") {
138 let domain = hostname_lower.strip_prefix("*.").unwrap().to_string();
140 wildcard_certs.insert(domain.clone(), cert.clone());
141 debug!(
142 pattern = %hostname,
143 domain = %domain,
144 cert_file = %sni_config.cert_file.display(),
145 "Registered wildcard SNI certificate"
146 );
147 } else {
148 sni_certs.insert(hostname_lower.clone(), cert.clone());
150 debug!(
151 hostname = %hostname_lower,
152 cert_file = %sni_config.cert_file.display(),
153 "Registered SNI certificate"
154 );
155 }
156 }
157 }
158
159 info!(
160 exact_certs = sni_certs.len(),
161 wildcard_certs = wildcard_certs.len(),
162 "SNI resolver initialized"
163 );
164
165 Ok(Self {
166 default_cert: Arc::new(default_cert),
167 sni_certs,
168 wildcard_certs,
169 })
170 }
171
172 pub fn resolve(&self, server_name: Option<&str>) -> Arc<CertifiedKey> {
177 let Some(name) = server_name else {
178 debug!("No SNI provided, using default certificate");
179 return self.default_cert.clone();
180 };
181
182 let name_lower = name.to_lowercase();
183
184 if let Some(cert) = self.sni_certs.get(&name_lower) {
186 debug!(hostname = %name_lower, "SNI exact match found");
187 return cert.clone();
188 }
189
190 let parts: Vec<&str> = name_lower.split('.').collect();
193 for i in 1..parts.len() {
194 let domain = parts[i..].join(".");
195 if let Some(cert) = self.wildcard_certs.get(&domain) {
196 debug!(
197 hostname = %name_lower,
198 wildcard_domain = %domain,
199 "SNI wildcard match found"
200 );
201 return cert.clone();
202 }
203 }
204
205 debug!(
206 hostname = %name_lower,
207 "No SNI match found, using default certificate"
208 );
209 self.default_cert.clone()
210 }
211}
212
213impl ResolvesServerCert for SniResolver {
214 fn resolve(&self, client_hello: ClientHello<'_>) -> Option<Arc<CertifiedKey>> {
215 Some(self.resolve(client_hello.server_name()))
216 }
217}
218
219pub struct HotReloadableSniResolver {
229 inner: RwLock<Arc<SniResolver>>,
231 config: RwLock<TlsConfig>,
233 last_reload: RwLock<Instant>,
235}
236
237impl std::fmt::Debug for HotReloadableSniResolver {
238 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
239 f.debug_struct("HotReloadableSniResolver")
240 .field("last_reload", &*self.last_reload.read())
241 .finish()
242 }
243}
244
245impl HotReloadableSniResolver {
246 pub fn from_config(config: TlsConfig) -> Result<Self, TlsError> {
248 let resolver = SniResolver::from_config(&config)?;
249
250 Ok(Self {
251 inner: RwLock::new(Arc::new(resolver)),
252 config: RwLock::new(config),
253 last_reload: RwLock::new(Instant::now()),
254 })
255 }
256
257 pub fn reload(&self) -> Result<(), TlsError> {
262 let config = self.config.read();
263
264 info!(
265 cert_file = %config.cert_file.display(),
266 sni_count = config.additional_certs.len(),
267 "Reloading TLS certificates"
268 );
269
270 let new_resolver = SniResolver::from_config(&config)?;
272
273 *self.inner.write() = Arc::new(new_resolver);
275 *self.last_reload.write() = Instant::now();
276
277 info!("TLS certificates reloaded successfully");
278 Ok(())
279 }
280
281 pub fn update_config(&self, new_config: TlsConfig) -> Result<(), TlsError> {
283 let new_resolver = SniResolver::from_config(&new_config)?;
285
286 *self.config.write() = new_config;
288 *self.inner.write() = Arc::new(new_resolver);
289 *self.last_reload.write() = Instant::now();
290
291 info!("TLS configuration updated and certificates reloaded");
292 Ok(())
293 }
294
295 pub fn last_reload_age(&self) -> Duration {
297 self.last_reload.read().elapsed()
298 }
299
300 pub fn resolve(&self, server_name: Option<&str>) -> Arc<CertifiedKey> {
304 self.inner.read().resolve(server_name)
305 }
306}
307
308impl ResolvesServerCert for HotReloadableSniResolver {
309 fn resolve(&self, client_hello: ClientHello<'_>) -> Option<Arc<CertifiedKey>> {
310 Some(self.inner.read().resolve(client_hello.server_name()))
311 }
312}
313
314pub struct CertificateReloader {
318 resolvers: RwLock<HashMap<String, Arc<HotReloadableSniResolver>>>,
320}
321
322impl CertificateReloader {
323 pub fn new() -> Self {
325 Self {
326 resolvers: RwLock::new(HashMap::new()),
327 }
328 }
329
330 pub fn register(&self, listener_id: &str, resolver: Arc<HotReloadableSniResolver>) {
332 debug!(listener_id = %listener_id, "Registering TLS resolver for hot-reload");
333 self.resolvers
334 .write()
335 .insert(listener_id.to_string(), resolver);
336 }
337
338 pub fn reload_all(&self) -> (usize, Vec<(String, TlsError)>) {
342 let resolvers = self.resolvers.read();
343 let mut success_count = 0;
344 let mut errors = Vec::new();
345
346 info!(
347 listener_count = resolvers.len(),
348 "Reloading certificates for all TLS listeners"
349 );
350
351 for (listener_id, resolver) in resolvers.iter() {
352 match resolver.reload() {
353 Ok(()) => {
354 success_count += 1;
355 debug!(listener_id = %listener_id, "Certificate reload successful");
356 }
357 Err(e) => {
358 error!(listener_id = %listener_id, error = %e, "Certificate reload failed");
359 errors.push((listener_id.clone(), e));
360 }
361 }
362 }
363
364 if errors.is_empty() {
365 info!(
366 success_count = success_count,
367 "All certificates reloaded successfully"
368 );
369 } else {
370 warn!(
371 success_count = success_count,
372 error_count = errors.len(),
373 "Certificate reload completed with errors"
374 );
375 }
376
377 (success_count, errors)
378 }
379
380 pub fn status(&self) -> HashMap<String, Duration> {
382 self.resolvers
383 .read()
384 .iter()
385 .map(|(id, resolver)| (id.clone(), resolver.last_reload_age()))
386 .collect()
387 }
388}
389
390impl Default for CertificateReloader {
391 fn default() -> Self {
392 Self::new()
393 }
394}
395
396#[derive(Debug, Clone)]
402pub struct OcspCacheEntry {
403 pub response: Vec<u8>,
405 pub fetched_at: Instant,
407 pub expires_at: Option<Instant>,
409}
410
411pub struct OcspStapler {
415 cache: RwLock<HashMap<String, OcspCacheEntry>>,
417 refresh_interval: Duration,
419}
420
421impl OcspStapler {
422 pub fn new() -> Self {
424 Self {
425 cache: RwLock::new(HashMap::new()),
426 refresh_interval: Duration::from_secs(3600), }
428 }
429
430 pub fn with_refresh_interval(interval: Duration) -> Self {
432 Self {
433 cache: RwLock::new(HashMap::new()),
434 refresh_interval: interval,
435 }
436 }
437
438 pub fn get_response(&self, cert_fingerprint: &str) -> Option<Vec<u8>> {
440 let cache = self.cache.read();
441 if let Some(entry) = cache.get(cert_fingerprint) {
442 if entry.fetched_at.elapsed() < self.refresh_interval {
444 trace!(fingerprint = %cert_fingerprint, "OCSP cache hit");
445 return Some(entry.response.clone());
446 }
447 trace!(fingerprint = %cert_fingerprint, "OCSP cache expired");
448 }
449 None
450 }
451
452 pub fn fetch_ocsp_response(
457 &self,
458 _cert_der: &[u8],
459 _issuer_der: &[u8],
460 ) -> Result<Vec<u8>, TlsError> {
461 warn!("OCSP stapling fetch not yet implemented - certificates will work without stapling");
474 Err(TlsError::OcspFetch(
475 "OCSP responder URL extraction requires x509-parser dependency".to_string(),
476 ))
477 }
478
479 pub fn prefetch_for_config(&self, config: &TlsConfig) -> Vec<String> {
481 let mut warnings = Vec::new();
482
483 if !config.ocsp_stapling {
484 trace!("OCSP stapling disabled in config");
485 return warnings;
486 }
487
488 info!("Prefetching OCSP responses for certificates");
489
490 warnings.push("OCSP stapling prefetch not yet fully implemented".to_string());
493
494 warnings
495 }
496
497 pub fn clear_cache(&self) {
499 self.cache.write().clear();
500 info!("OCSP cache cleared");
501 }
502}
503
504impl Default for OcspStapler {
505 fn default() -> Self {
506 Self::new()
507 }
508}
509
510pub fn build_upstream_tls_config(config: &UpstreamTlsConfig) -> Result<ClientConfig, TlsError> {
519 let mut root_store = RootCertStore::empty();
520
521 if let Some(ca_path) = &config.ca_cert {
523 let ca_file = File::open(ca_path)
524 .map_err(|e| TlsError::CertificateLoad(format!("{}: {}", ca_path.display(), e)))?;
525 let mut ca_reader = BufReader::new(ca_file);
526
527 let certs: Vec<CertificateDer<'static>> = rustls_pemfile::certs(&mut ca_reader)
528 .collect::<Result<Vec<_>, _>>()
529 .map_err(|e| TlsError::CertificateLoad(format!("{}: {}", ca_path.display(), e)))?;
530
531 for cert in certs {
532 root_store.add(cert).map_err(|e| {
533 TlsError::InvalidCertificate(format!("Failed to add CA certificate: {}", e))
534 })?;
535 }
536
537 debug!(
538 ca_file = %ca_path.display(),
539 cert_count = root_store.len(),
540 "Loaded upstream CA certificates"
541 );
542 } else if !config.insecure_skip_verify {
543 root_store = RootCertStore {
545 roots: webpki_roots::TLS_SERVER_ROOTS.to_vec(),
546 };
547 trace!("Using webpki-roots for upstream TLS verification");
548 }
549
550 let builder = ClientConfig::builder().with_root_certificates(root_store);
552
553 let client_config = if let (Some(cert_path), Some(key_path)) =
554 (&config.client_cert, &config.client_key)
555 {
556 let cert_file = File::open(cert_path)
558 .map_err(|e| TlsError::CertificateLoad(format!("{}: {}", cert_path.display(), e)))?;
559 let mut cert_reader = BufReader::new(cert_file);
560
561 let certs: Vec<CertificateDer<'static>> = rustls_pemfile::certs(&mut cert_reader)
562 .collect::<Result<Vec<_>, _>>()
563 .map_err(|e| TlsError::CertificateLoad(format!("{}: {}", cert_path.display(), e)))?;
564
565 if certs.is_empty() {
566 return Err(TlsError::CertificateLoad(format!(
567 "{}: No certificates found",
568 cert_path.display()
569 )));
570 }
571
572 let key_file = File::open(key_path)
574 .map_err(|e| TlsError::KeyLoad(format!("{}: {}", key_path.display(), e)))?;
575 let mut key_reader = BufReader::new(key_file);
576
577 let key = rustls_pemfile::private_key(&mut key_reader)
578 .map_err(|e| TlsError::KeyLoad(format!("{}: {}", key_path.display(), e)))?
579 .ok_or_else(|| {
580 TlsError::KeyLoad(format!("{}: No private key found", key_path.display()))
581 })?;
582
583 info!(
584 cert_file = %cert_path.display(),
585 "Configured mTLS client certificate for upstream connections"
586 );
587
588 builder
589 .with_client_auth_cert(certs, key)
590 .map_err(|e| TlsError::CertKeyMismatch(format!("Failed to set client auth: {}", e)))?
591 } else {
592 builder.with_no_client_auth()
594 };
595
596 debug!("Upstream TLS configuration built successfully");
597 Ok(client_config)
598}
599
600pub fn validate_upstream_tls_config(config: &UpstreamTlsConfig) -> Result<(), TlsError> {
602 if let Some(ca_path) = &config.ca_cert {
604 if !ca_path.exists() {
605 return Err(TlsError::CertificateLoad(format!(
606 "Upstream CA certificate not found: {}",
607 ca_path.display()
608 )));
609 }
610 }
611
612 if let Some(cert_path) = &config.client_cert {
614 if !cert_path.exists() {
615 return Err(TlsError::CertificateLoad(format!(
616 "Upstream client certificate not found: {}",
617 cert_path.display()
618 )));
619 }
620
621 match &config.client_key {
623 Some(key_path) if !key_path.exists() => {
624 return Err(TlsError::KeyLoad(format!(
625 "Upstream client key not found: {}",
626 key_path.display()
627 )));
628 }
629 None => {
630 return Err(TlsError::ConfigBuild(
631 "client_cert specified without client_key".to_string(),
632 ));
633 }
634 _ => {}
635 }
636 }
637
638 if config.client_key.is_some() && config.client_cert.is_none() {
639 return Err(TlsError::ConfigBuild(
640 "client_key specified without client_cert".to_string(),
641 ));
642 }
643
644 Ok(())
645}
646
647fn load_certified_key(cert_path: &Path, key_path: &Path) -> Result<CertifiedKey, TlsError> {
653 let cert_file = File::open(cert_path)
655 .map_err(|e| TlsError::CertificateLoad(format!("{}: {}", cert_path.display(), e)))?;
656 let mut cert_reader = BufReader::new(cert_file);
657
658 let certs: Vec<CertificateDer<'static>> = rustls_pemfile::certs(&mut cert_reader)
659 .collect::<Result<Vec<_>, _>>()
660 .map_err(|e| TlsError::CertificateLoad(format!("{}: {}", cert_path.display(), e)))?;
661
662 if certs.is_empty() {
663 return Err(TlsError::CertificateLoad(format!(
664 "{}: No certificates found in file",
665 cert_path.display()
666 )));
667 }
668
669 let key_file = File::open(key_path)
671 .map_err(|e| TlsError::KeyLoad(format!("{}: {}", key_path.display(), e)))?;
672 let mut key_reader = BufReader::new(key_file);
673
674 let key = rustls_pemfile::private_key(&mut key_reader)
675 .map_err(|e| TlsError::KeyLoad(format!("{}: {}", key_path.display(), e)))?
676 .ok_or_else(|| {
677 TlsError::KeyLoad(format!(
678 "{}: No private key found in file",
679 key_path.display()
680 ))
681 })?;
682
683 let provider = rustls::crypto::CryptoProvider::get_default()
685 .cloned()
686 .unwrap_or_else(|| Arc::new(rustls::crypto::aws_lc_rs::default_provider()));
687
688 let signing_key = provider
689 .key_provider
690 .load_private_key(key)
691 .map_err(|e| TlsError::CertKeyMismatch(format!("Failed to load private key: {:?}", e)))?;
692
693 Ok(CertifiedKey::new(certs, signing_key))
694}
695
696pub fn load_client_ca(ca_path: &Path) -> Result<RootCertStore, TlsError> {
698 let ca_file = File::open(ca_path)
699 .map_err(|e| TlsError::CertificateLoad(format!("{}: {}", ca_path.display(), e)))?;
700 let mut ca_reader = BufReader::new(ca_file);
701
702 let mut root_store = RootCertStore::empty();
703
704 let certs: Vec<CertificateDer<'static>> = rustls_pemfile::certs(&mut ca_reader)
705 .collect::<Result<Vec<_>, _>>()
706 .map_err(|e| TlsError::CertificateLoad(format!("{}: {}", ca_path.display(), e)))?;
707
708 for cert in certs {
709 root_store.add(cert).map_err(|e| {
710 TlsError::InvalidCertificate(format!("Failed to add CA certificate: {}", e))
711 })?;
712 }
713
714 if root_store.is_empty() {
715 return Err(TlsError::CertificateLoad(format!(
716 "{}: No CA certificates found",
717 ca_path.display()
718 )));
719 }
720
721 info!(
722 ca_file = %ca_path.display(),
723 cert_count = root_store.len(),
724 "Loaded client CA certificates"
725 );
726
727 Ok(root_store)
728}
729
730pub fn build_server_config(config: &TlsConfig) -> Result<ServerConfig, TlsError> {
732 let resolver = SniResolver::from_config(config)?;
733
734 let builder = ServerConfig::builder();
735
736 let server_config = if config.client_auth {
738 if let Some(ca_path) = &config.ca_file {
739 let root_store = load_client_ca(ca_path)?;
740 let verifier = rustls::server::WebPkiClientVerifier::builder(Arc::new(root_store))
741 .build()
742 .map_err(|e| {
743 TlsError::ConfigBuild(format!("Failed to build client verifier: {}", e))
744 })?;
745
746 info!("mTLS enabled: client certificates required");
747
748 builder
749 .with_client_cert_verifier(verifier)
750 .with_cert_resolver(Arc::new(resolver))
751 } else {
752 warn!("client_auth enabled but no ca_file specified, disabling client auth");
753 builder
754 .with_no_client_auth()
755 .with_cert_resolver(Arc::new(resolver))
756 }
757 } else {
758 builder
759 .with_no_client_auth()
760 .with_cert_resolver(Arc::new(resolver))
761 };
762
763 let mut config = server_config;
765 config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
766
767 debug!("TLS configuration built successfully");
768
769 Ok(config)
770}
771
772pub fn validate_tls_config(config: &TlsConfig) -> Result<(), TlsError> {
774 if !config.cert_file.exists() {
776 return Err(TlsError::CertificateLoad(format!(
777 "Certificate file not found: {}",
778 config.cert_file.display()
779 )));
780 }
781 if !config.key_file.exists() {
782 return Err(TlsError::KeyLoad(format!(
783 "Key file not found: {}",
784 config.key_file.display()
785 )));
786 }
787
788 for sni in &config.additional_certs {
790 if !sni.cert_file.exists() {
791 return Err(TlsError::CertificateLoad(format!(
792 "SNI certificate file not found: {}",
793 sni.cert_file.display()
794 )));
795 }
796 if !sni.key_file.exists() {
797 return Err(TlsError::KeyLoad(format!(
798 "SNI key file not found: {}",
799 sni.key_file.display()
800 )));
801 }
802 }
803
804 if config.client_auth {
806 if let Some(ca_path) = &config.ca_file {
807 if !ca_path.exists() {
808 return Err(TlsError::CertificateLoad(format!(
809 "CA certificate file not found: {}",
810 ca_path.display()
811 )));
812 }
813 }
814 }
815
816 Ok(())
817}
818
#[cfg(test)]
mod tests {
    /// Splitting on `.` and rejoining successive suffixes yields the expected
    /// parent domains.
    #[test]
    fn test_wildcard_matching() {
        let labels: Vec<&str> = "foo.bar.example.com".split('.').collect();
        assert_eq!(labels.len(), 4);

        let suffixes: Vec<String> = (1..labels.len())
            .map(|skip| labels[skip..].join("."))
            .collect();
        assert_eq!(suffixes[0], "bar.example.com");
        assert_eq!(suffixes[1], "example.com");
    }

    /// Hostnames are normalized with `to_lowercase`.
    #[test]
    fn test_hostname_normalization() {
        assert_eq!("Example.COM".to_lowercase(), "example.com");
    }
}