1use std::collections::HashMap;
49use std::fs::File;
50use std::io::BufReader;
51use std::path::Path;
52use std::sync::Arc;
53use std::time::{Duration, Instant};
54
55use parking_lot::RwLock;
56use rustls::client::ClientConfig;
57use rustls::pki_types::CertificateDer;
58use rustls::server::{ClientHello, ResolvesServerCert};
59use rustls::sign::CertifiedKey;
60use rustls::{RootCertStore, ServerConfig};
61use tracing::{debug, error, info, trace, warn};
62
63use sentinel_config::{TlsConfig, UpstreamTlsConfig};
64
/// Errors produced while loading TLS material or assembling TLS configuration.
///
/// Every variant carries a human-readable detail string; the `Display`
/// implementation prefixes it with a short, variant-specific summary.
#[derive(Debug)]
pub enum TlsError {
    /// A certificate file could not be read or yielded no usable certificate.
    CertificateLoad(String),
    /// A private-key file could not be read or yielded no usable key.
    KeyLoad(String),
    /// The TLS configuration could not be assembled.
    ConfigBuild(String),
    /// The certificate and private key could not be combined.
    CertKeyMismatch(String),
    /// A certificate was rejected as invalid.
    InvalidCertificate(String),
    /// An OCSP response could not be obtained.
    OcspFetch(String),
}

impl std::fmt::Display for TlsError {
    /// Writes `"<summary>: <detail>"` for the variant at hand.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Pick the fixed summary and borrow the detail, then format once.
        let (summary, detail) = match self {
            Self::CertificateLoad(detail) => ("Failed to load certificate", detail),
            Self::KeyLoad(detail) => ("Failed to load private key", detail),
            Self::ConfigBuild(detail) => ("Failed to build TLS config", detail),
            Self::CertKeyMismatch(detail) => ("Certificate/key mismatch", detail),
            Self::InvalidCertificate(detail) => ("Invalid certificate", detail),
            Self::OcspFetch(detail) => ("Failed to fetch OCSP response", detail),
        };
        write!(f, "{}: {}", summary, detail)
    }
}

impl std::error::Error for TlsError {}
96
/// Certificate selector keyed on the TLS Server Name Indication (SNI) value.
///
/// Lookup order is: exact hostname, then wildcard domains from the most
/// specific suffix to the least specific, then the default certificate.
#[derive(Debug)]
pub struct SniResolver {
    /// Certificate served when the client sends no SNI or nothing else matches.
    default_cert: Arc<CertifiedKey>,
    /// Exact-match certificates, keyed by lowercased hostname.
    sni_certs: HashMap<String, Arc<CertifiedKey>>,
    /// Wildcard certificates, keyed by the lowercased domain that follows the
    /// `*.` prefix (`*.example.com` is stored under `example.com`).
    wildcard_certs: HashMap<String, Arc<CertifiedKey>>,
}
114
115impl SniResolver {
116 pub fn from_config(config: &TlsConfig) -> Result<Self, TlsError> {
118 let default_cert = load_certified_key(&config.cert_file, &config.key_file)?;
120
121 info!(
122 cert_file = %config.cert_file.display(),
123 "Loaded default TLS certificate"
124 );
125
126 let mut sni_certs = HashMap::new();
127 let mut wildcard_certs = HashMap::new();
128
129 for sni_config in &config.additional_certs {
131 let cert = load_certified_key(&sni_config.cert_file, &sni_config.key_file)?;
132 let cert = Arc::new(cert);
133
134 for hostname in &sni_config.hostnames {
135 let hostname_lower = hostname.to_lowercase();
136
137 if hostname_lower.starts_with("*.") {
138 let domain = hostname_lower.strip_prefix("*.").unwrap().to_string();
140 wildcard_certs.insert(domain.clone(), cert.clone());
141 debug!(
142 pattern = %hostname,
143 domain = %domain,
144 cert_file = %sni_config.cert_file.display(),
145 "Registered wildcard SNI certificate"
146 );
147 } else {
148 sni_certs.insert(hostname_lower.clone(), cert.clone());
150 debug!(
151 hostname = %hostname_lower,
152 cert_file = %sni_config.cert_file.display(),
153 "Registered SNI certificate"
154 );
155 }
156 }
157 }
158
159 info!(
160 exact_certs = sni_certs.len(),
161 wildcard_certs = wildcard_certs.len(),
162 "SNI resolver initialized"
163 );
164
165 Ok(Self {
166 default_cert: Arc::new(default_cert),
167 sni_certs,
168 wildcard_certs,
169 })
170 }
171
172 fn resolve(&self, server_name: Option<&str>) -> Arc<CertifiedKey> {
174 let Some(name) = server_name else {
175 debug!("No SNI provided, using default certificate");
176 return self.default_cert.clone();
177 };
178
179 let name_lower = name.to_lowercase();
180
181 if let Some(cert) = self.sni_certs.get(&name_lower) {
183 debug!(hostname = %name_lower, "SNI exact match found");
184 return cert.clone();
185 }
186
187 let parts: Vec<&str> = name_lower.split('.').collect();
190 for i in 1..parts.len() {
191 let domain = parts[i..].join(".");
192 if let Some(cert) = self.wildcard_certs.get(&domain) {
193 debug!(
194 hostname = %name_lower,
195 wildcard_domain = %domain,
196 "SNI wildcard match found"
197 );
198 return cert.clone();
199 }
200 }
201
202 debug!(
203 hostname = %name_lower,
204 "No SNI match found, using default certificate"
205 );
206 self.default_cert.clone()
207 }
208}
209
210impl ResolvesServerCert for SniResolver {
211 fn resolve(&self, client_hello: ClientHello<'_>) -> Option<Arc<CertifiedKey>> {
212 Some(self.resolve(client_hello.server_name()))
213 }
214}
215
/// SNI resolver whose certificate set can be replaced at runtime.
pub struct HotReloadableSniResolver {
    /// Resolver currently used for handshakes; replaced wholesale on reload.
    inner: RwLock<Arc<SniResolver>>,
    /// Configuration the certificates are (re)loaded from.
    config: RwLock<TlsConfig>,
    /// Instant of construction or of the most recent successful reload/update.
    last_reload: RwLock<Instant>,
}
233
234impl std::fmt::Debug for HotReloadableSniResolver {
235 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
236 f.debug_struct("HotReloadableSniResolver")
237 .field("last_reload", &*self.last_reload.read())
238 .finish()
239 }
240}
241
242impl HotReloadableSniResolver {
243 pub fn from_config(config: TlsConfig) -> Result<Self, TlsError> {
245 let resolver = SniResolver::from_config(&config)?;
246
247 Ok(Self {
248 inner: RwLock::new(Arc::new(resolver)),
249 config: RwLock::new(config),
250 last_reload: RwLock::new(Instant::now()),
251 })
252 }
253
254 pub fn reload(&self) -> Result<(), TlsError> {
259 let config = self.config.read();
260
261 info!(
262 cert_file = %config.cert_file.display(),
263 sni_count = config.additional_certs.len(),
264 "Reloading TLS certificates"
265 );
266
267 let new_resolver = SniResolver::from_config(&config)?;
269
270 *self.inner.write() = Arc::new(new_resolver);
272 *self.last_reload.write() = Instant::now();
273
274 info!("TLS certificates reloaded successfully");
275 Ok(())
276 }
277
278 pub fn update_config(&self, new_config: TlsConfig) -> Result<(), TlsError> {
280 let new_resolver = SniResolver::from_config(&new_config)?;
282
283 *self.config.write() = new_config;
285 *self.inner.write() = Arc::new(new_resolver);
286 *self.last_reload.write() = Instant::now();
287
288 info!("TLS configuration updated and certificates reloaded");
289 Ok(())
290 }
291
292 pub fn last_reload_age(&self) -> Duration {
294 self.last_reload.read().elapsed()
295 }
296}
297
298impl ResolvesServerCert for HotReloadableSniResolver {
299 fn resolve(&self, client_hello: ClientHello<'_>) -> Option<Arc<CertifiedKey>> {
300 Some(self.inner.read().resolve(client_hello.server_name()))
301 }
302}
303
/// Registry of hot-reloadable resolvers, used to reload the certificates of
/// every registered listener in one call.
pub struct CertificateReloader {
    /// Registered resolvers, keyed by listener id.
    resolvers: RwLock<HashMap<String, Arc<HotReloadableSniResolver>>>,
}
311
312impl CertificateReloader {
313 pub fn new() -> Self {
315 Self {
316 resolvers: RwLock::new(HashMap::new()),
317 }
318 }
319
320 pub fn register(&self, listener_id: &str, resolver: Arc<HotReloadableSniResolver>) {
322 debug!(listener_id = %listener_id, "Registering TLS resolver for hot-reload");
323 self.resolvers
324 .write()
325 .insert(listener_id.to_string(), resolver);
326 }
327
328 pub fn reload_all(&self) -> (usize, Vec<(String, TlsError)>) {
332 let resolvers = self.resolvers.read();
333 let mut success_count = 0;
334 let mut errors = Vec::new();
335
336 info!(
337 listener_count = resolvers.len(),
338 "Reloading certificates for all TLS listeners"
339 );
340
341 for (listener_id, resolver) in resolvers.iter() {
342 match resolver.reload() {
343 Ok(()) => {
344 success_count += 1;
345 debug!(listener_id = %listener_id, "Certificate reload successful");
346 }
347 Err(e) => {
348 error!(listener_id = %listener_id, error = %e, "Certificate reload failed");
349 errors.push((listener_id.clone(), e));
350 }
351 }
352 }
353
354 if errors.is_empty() {
355 info!(
356 success_count = success_count,
357 "All certificates reloaded successfully"
358 );
359 } else {
360 warn!(
361 success_count = success_count,
362 error_count = errors.len(),
363 "Certificate reload completed with errors"
364 );
365 }
366
367 (success_count, errors)
368 }
369
370 pub fn status(&self) -> HashMap<String, Duration> {
372 self.resolvers
373 .read()
374 .iter()
375 .map(|(id, resolver)| (id.clone(), resolver.last_reload_age()))
376 .collect()
377 }
378}
379
380impl Default for CertificateReloader {
381 fn default() -> Self {
382 Self::new()
383 }
384}
385
/// A cached OCSP response together with its bookkeeping timestamps.
#[derive(Debug, Clone)]
pub struct OcspCacheEntry {
    /// Raw OCSP response bytes (presumably DER-encoded; confirm with the producer).
    pub response: Vec<u8>,
    /// When this response was fetched.
    pub fetched_at: Instant,
    /// Presumably the instant after which the response is no longer valid,
    /// if known. TODO confirm against the code that fills the cache.
    pub expires_at: Option<Instant>,
}
400
/// In-memory cache of OCSP responses intended for stapling.
pub struct OcspStapler {
    /// Cached responses, looked up by certificate fingerprint string.
    cache: RwLock<HashMap<String, OcspCacheEntry>>,
    /// Maximum age of a cached response before `get_response` stops serving it.
    refresh_interval: Duration,
}
410
411impl OcspStapler {
412 pub fn new() -> Self {
414 Self {
415 cache: RwLock::new(HashMap::new()),
416 refresh_interval: Duration::from_secs(3600), }
418 }
419
420 pub fn with_refresh_interval(interval: Duration) -> Self {
422 Self {
423 cache: RwLock::new(HashMap::new()),
424 refresh_interval: interval,
425 }
426 }
427
428 pub fn get_response(&self, cert_fingerprint: &str) -> Option<Vec<u8>> {
430 let cache = self.cache.read();
431 if let Some(entry) = cache.get(cert_fingerprint) {
432 if entry.fetched_at.elapsed() < self.refresh_interval {
434 trace!(fingerprint = %cert_fingerprint, "OCSP cache hit");
435 return Some(entry.response.clone());
436 }
437 trace!(fingerprint = %cert_fingerprint, "OCSP cache expired");
438 }
439 None
440 }
441
442 pub fn fetch_ocsp_response(
447 &self,
448 _cert_der: &[u8],
449 _issuer_der: &[u8],
450 ) -> Result<Vec<u8>, TlsError> {
451 warn!("OCSP stapling fetch not yet implemented - certificates will work without stapling");
464 Err(TlsError::OcspFetch(
465 "OCSP responder URL extraction requires x509-parser dependency".to_string(),
466 ))
467 }
468
469 pub fn prefetch_for_config(&self, config: &TlsConfig) -> Vec<String> {
471 let mut warnings = Vec::new();
472
473 if !config.ocsp_stapling {
474 trace!("OCSP stapling disabled in config");
475 return warnings;
476 }
477
478 info!("Prefetching OCSP responses for certificates");
479
480 warnings.push("OCSP stapling prefetch not yet fully implemented".to_string());
483
484 warnings
485 }
486
487 pub fn clear_cache(&self) {
489 self.cache.write().clear();
490 info!("OCSP cache cleared");
491 }
492}
493
494impl Default for OcspStapler {
495 fn default() -> Self {
496 Self::new()
497 }
498}
499
500pub fn build_upstream_tls_config(config: &UpstreamTlsConfig) -> Result<ClientConfig, TlsError> {
509 let mut root_store = RootCertStore::empty();
510
511 if let Some(ca_path) = &config.ca_cert {
513 let ca_file = File::open(ca_path)
514 .map_err(|e| TlsError::CertificateLoad(format!("{}: {}", ca_path.display(), e)))?;
515 let mut ca_reader = BufReader::new(ca_file);
516
517 let certs: Vec<CertificateDer<'static>> = rustls_pemfile::certs(&mut ca_reader)
518 .collect::<Result<Vec<_>, _>>()
519 .map_err(|e| TlsError::CertificateLoad(format!("{}: {}", ca_path.display(), e)))?;
520
521 for cert in certs {
522 root_store.add(cert).map_err(|e| {
523 TlsError::InvalidCertificate(format!("Failed to add CA certificate: {}", e))
524 })?;
525 }
526
527 debug!(
528 ca_file = %ca_path.display(),
529 cert_count = root_store.len(),
530 "Loaded upstream CA certificates"
531 );
532 } else if !config.insecure_skip_verify {
533 root_store = RootCertStore {
535 roots: webpki_roots::TLS_SERVER_ROOTS.to_vec(),
536 };
537 trace!("Using webpki-roots for upstream TLS verification");
538 }
539
540 let builder = ClientConfig::builder().with_root_certificates(root_store);
542
543 let client_config = if let (Some(cert_path), Some(key_path)) =
544 (&config.client_cert, &config.client_key)
545 {
546 let cert_file = File::open(cert_path)
548 .map_err(|e| TlsError::CertificateLoad(format!("{}: {}", cert_path.display(), e)))?;
549 let mut cert_reader = BufReader::new(cert_file);
550
551 let certs: Vec<CertificateDer<'static>> = rustls_pemfile::certs(&mut cert_reader)
552 .collect::<Result<Vec<_>, _>>()
553 .map_err(|e| TlsError::CertificateLoad(format!("{}: {}", cert_path.display(), e)))?;
554
555 if certs.is_empty() {
556 return Err(TlsError::CertificateLoad(format!(
557 "{}: No certificates found",
558 cert_path.display()
559 )));
560 }
561
562 let key_file = File::open(key_path)
564 .map_err(|e| TlsError::KeyLoad(format!("{}: {}", key_path.display(), e)))?;
565 let mut key_reader = BufReader::new(key_file);
566
567 let key = rustls_pemfile::private_key(&mut key_reader)
568 .map_err(|e| TlsError::KeyLoad(format!("{}: {}", key_path.display(), e)))?
569 .ok_or_else(|| {
570 TlsError::KeyLoad(format!("{}: No private key found", key_path.display()))
571 })?;
572
573 info!(
574 cert_file = %cert_path.display(),
575 "Configured mTLS client certificate for upstream connections"
576 );
577
578 builder
579 .with_client_auth_cert(certs, key)
580 .map_err(|e| TlsError::CertKeyMismatch(format!("Failed to set client auth: {}", e)))?
581 } else {
582 builder.with_no_client_auth()
584 };
585
586 debug!("Upstream TLS configuration built successfully");
587 Ok(client_config)
588}
589
590pub fn validate_upstream_tls_config(config: &UpstreamTlsConfig) -> Result<(), TlsError> {
592 if let Some(ca_path) = &config.ca_cert {
594 if !ca_path.exists() {
595 return Err(TlsError::CertificateLoad(format!(
596 "Upstream CA certificate not found: {}",
597 ca_path.display()
598 )));
599 }
600 }
601
602 if let Some(cert_path) = &config.client_cert {
604 if !cert_path.exists() {
605 return Err(TlsError::CertificateLoad(format!(
606 "Upstream client certificate not found: {}",
607 cert_path.display()
608 )));
609 }
610
611 match &config.client_key {
613 Some(key_path) if !key_path.exists() => {
614 return Err(TlsError::KeyLoad(format!(
615 "Upstream client key not found: {}",
616 key_path.display()
617 )));
618 }
619 None => {
620 return Err(TlsError::ConfigBuild(
621 "client_cert specified without client_key".to_string(),
622 ));
623 }
624 _ => {}
625 }
626 }
627
628 if config.client_key.is_some() && config.client_cert.is_none() {
629 return Err(TlsError::ConfigBuild(
630 "client_key specified without client_cert".to_string(),
631 ));
632 }
633
634 Ok(())
635}
636
637fn load_certified_key(cert_path: &Path, key_path: &Path) -> Result<CertifiedKey, TlsError> {
643 let cert_file = File::open(cert_path)
645 .map_err(|e| TlsError::CertificateLoad(format!("{}: {}", cert_path.display(), e)))?;
646 let mut cert_reader = BufReader::new(cert_file);
647
648 let certs: Vec<CertificateDer<'static>> = rustls_pemfile::certs(&mut cert_reader)
649 .collect::<Result<Vec<_>, _>>()
650 .map_err(|e| TlsError::CertificateLoad(format!("{}: {}", cert_path.display(), e)))?;
651
652 if certs.is_empty() {
653 return Err(TlsError::CertificateLoad(format!(
654 "{}: No certificates found in file",
655 cert_path.display()
656 )));
657 }
658
659 let key_file = File::open(key_path)
661 .map_err(|e| TlsError::KeyLoad(format!("{}: {}", key_path.display(), e)))?;
662 let mut key_reader = BufReader::new(key_file);
663
664 let key = rustls_pemfile::private_key(&mut key_reader)
665 .map_err(|e| TlsError::KeyLoad(format!("{}: {}", key_path.display(), e)))?
666 .ok_or_else(|| {
667 TlsError::KeyLoad(format!(
668 "{}: No private key found in file",
669 key_path.display()
670 ))
671 })?;
672
673 let provider = rustls::crypto::CryptoProvider::get_default()
675 .cloned()
676 .unwrap_or_else(|| Arc::new(rustls::crypto::aws_lc_rs::default_provider()));
677
678 let signing_key = provider
679 .key_provider
680 .load_private_key(key)
681 .map_err(|e| TlsError::CertKeyMismatch(format!("Failed to load private key: {:?}", e)))?;
682
683 Ok(CertifiedKey::new(certs, signing_key))
684}
685
686pub fn load_client_ca(ca_path: &Path) -> Result<RootCertStore, TlsError> {
688 let ca_file = File::open(ca_path)
689 .map_err(|e| TlsError::CertificateLoad(format!("{}: {}", ca_path.display(), e)))?;
690 let mut ca_reader = BufReader::new(ca_file);
691
692 let mut root_store = RootCertStore::empty();
693
694 let certs: Vec<CertificateDer<'static>> = rustls_pemfile::certs(&mut ca_reader)
695 .collect::<Result<Vec<_>, _>>()
696 .map_err(|e| TlsError::CertificateLoad(format!("{}: {}", ca_path.display(), e)))?;
697
698 for cert in certs {
699 root_store.add(cert).map_err(|e| {
700 TlsError::InvalidCertificate(format!("Failed to add CA certificate: {}", e))
701 })?;
702 }
703
704 if root_store.is_empty() {
705 return Err(TlsError::CertificateLoad(format!(
706 "{}: No CA certificates found",
707 ca_path.display()
708 )));
709 }
710
711 info!(
712 ca_file = %ca_path.display(),
713 cert_count = root_store.len(),
714 "Loaded client CA certificates"
715 );
716
717 Ok(root_store)
718}
719
720pub fn build_server_config(config: &TlsConfig) -> Result<ServerConfig, TlsError> {
722 let resolver = SniResolver::from_config(config)?;
723
724 let builder = ServerConfig::builder();
725
726 let server_config = if config.client_auth {
728 if let Some(ca_path) = &config.ca_file {
729 let root_store = load_client_ca(ca_path)?;
730 let verifier = rustls::server::WebPkiClientVerifier::builder(Arc::new(root_store))
731 .build()
732 .map_err(|e| {
733 TlsError::ConfigBuild(format!("Failed to build client verifier: {}", e))
734 })?;
735
736 info!("mTLS enabled: client certificates required");
737
738 builder
739 .with_client_cert_verifier(verifier)
740 .with_cert_resolver(Arc::new(resolver))
741 } else {
742 warn!("client_auth enabled but no ca_file specified, disabling client auth");
743 builder
744 .with_no_client_auth()
745 .with_cert_resolver(Arc::new(resolver))
746 }
747 } else {
748 builder
749 .with_no_client_auth()
750 .with_cert_resolver(Arc::new(resolver))
751 };
752
753 let mut config = server_config;
755 config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
756
757 debug!("TLS configuration built successfully");
758
759 Ok(config)
760}
761
762pub fn validate_tls_config(config: &TlsConfig) -> Result<(), TlsError> {
764 if !config.cert_file.exists() {
766 return Err(TlsError::CertificateLoad(format!(
767 "Certificate file not found: {}",
768 config.cert_file.display()
769 )));
770 }
771 if !config.key_file.exists() {
772 return Err(TlsError::KeyLoad(format!(
773 "Key file not found: {}",
774 config.key_file.display()
775 )));
776 }
777
778 for sni in &config.additional_certs {
780 if !sni.cert_file.exists() {
781 return Err(TlsError::CertificateLoad(format!(
782 "SNI certificate file not found: {}",
783 sni.cert_file.display()
784 )));
785 }
786 if !sni.key_file.exists() {
787 return Err(TlsError::KeyLoad(format!(
788 "SNI key file not found: {}",
789 sni.key_file.display()
790 )));
791 }
792 }
793
794 if config.client_auth {
796 if let Some(ca_path) = &config.ca_file {
797 if !ca_path.exists() {
798 return Err(TlsError::CertificateLoad(format!(
799 "CA certificate file not found: {}",
800 ca_path.display()
801 )));
802 }
803 }
804 }
805
806 Ok(())
807}
808
#[cfg(test)]
mod tests {

    /// Parent-domain suffixes of a multi-label name are formed by dropping
    /// leading labels one at a time.
    #[test]
    fn test_wildcard_matching() {
        let labels: Vec<&str> = "foo.bar.example.com".split('.').collect();
        assert_eq!(labels.len(), 4);

        let suffix_from = |start: usize| labels[start..].join(".");

        assert_eq!(suffix_from(1), "bar.example.com");
        assert_eq!(suffix_from(2), "example.com");
    }

    /// Hostnames are normalized by lowercasing.
    #[test]
    fn test_hostname_normalization() {
        assert_eq!("Example.COM".to_lowercase(), "example.com");
    }
}