1use std::collections::HashMap;
49use std::fs::File;
50use std::io::BufReader;
51use std::path::Path;
52use std::sync::Arc;
53use std::time::{Duration, Instant};
54
55use parking_lot::RwLock;
56use rustls::client::ClientConfig;
57use rustls::pki_types::CertificateDer;
58use rustls::server::{ClientHello, ResolvesServerCert};
59use rustls::sign::CertifiedKey;
60use rustls::{RootCertStore, ServerConfig};
61use tracing::{debug, error, info, trace, warn};
62
63use sentinel_config::{TlsConfig, UpstreamTlsConfig};
64
/// Errors raised while loading TLS material, assembling rustls configurations,
/// or fetching OCSP responses. Every variant carries a human-readable detail.
#[derive(Debug)]
pub enum TlsError {
    /// A certificate or CA file could not be opened, parsed, or was empty.
    CertificateLoad(String),
    /// A private key file could not be opened, parsed, or contained no key.
    KeyLoad(String),
    /// A rustls configuration object could not be built.
    ConfigBuild(String),
    /// The private key could not be used together with the certificate chain.
    CertKeyMismatch(String),
    /// A certificate was rejected (e.g. when adding it to a root store).
    InvalidCertificate(String),
    /// Preparing, sending, or reading an OCSP request/response failed.
    OcspFetch(String),
}

impl std::fmt::Display for TlsError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Every message has the shape "<what failed>: <detail>".
        let (what, detail) = match self {
            Self::CertificateLoad(detail) => ("Failed to load certificate", detail),
            Self::KeyLoad(detail) => ("Failed to load private key", detail),
            Self::ConfigBuild(detail) => ("Failed to build TLS config", detail),
            Self::CertKeyMismatch(detail) => ("Certificate/key mismatch", detail),
            Self::InvalidCertificate(detail) => ("Invalid certificate", detail),
            Self::OcspFetch(detail) => ("Failed to fetch OCSP response", detail),
        };
        write!(f, "{}: {}", what, detail)
    }
}

impl std::error::Error for TlsError {}
96
97#[derive(Debug)]
105pub struct SniResolver {
106 default_cert: Arc<CertifiedKey>,
108 sni_certs: HashMap<String, Arc<CertifiedKey>>,
111 wildcard_certs: HashMap<String, Arc<CertifiedKey>>,
113}
114
115impl SniResolver {
116 pub fn from_config(config: &TlsConfig) -> Result<Self, TlsError> {
118 let default_cert = load_certified_key(&config.cert_file, &config.key_file)?;
120
121 info!(
122 cert_file = %config.cert_file.display(),
123 "Loaded default TLS certificate"
124 );
125
126 let mut sni_certs = HashMap::new();
127 let mut wildcard_certs = HashMap::new();
128
129 for sni_config in &config.additional_certs {
131 let cert = load_certified_key(&sni_config.cert_file, &sni_config.key_file)?;
132 let cert = Arc::new(cert);
133
134 for hostname in &sni_config.hostnames {
135 let hostname_lower = hostname.to_lowercase();
136
137 if hostname_lower.starts_with("*.") {
138 let domain = hostname_lower.strip_prefix("*.").unwrap().to_string();
140 wildcard_certs.insert(domain.clone(), cert.clone());
141 debug!(
142 pattern = %hostname,
143 domain = %domain,
144 cert_file = %sni_config.cert_file.display(),
145 "Registered wildcard SNI certificate"
146 );
147 } else {
148 sni_certs.insert(hostname_lower.clone(), cert.clone());
150 debug!(
151 hostname = %hostname_lower,
152 cert_file = %sni_config.cert_file.display(),
153 "Registered SNI certificate"
154 );
155 }
156 }
157 }
158
159 info!(
160 exact_certs = sni_certs.len(),
161 wildcard_certs = wildcard_certs.len(),
162 "SNI resolver initialized"
163 );
164
165 Ok(Self {
166 default_cert: Arc::new(default_cert),
167 sni_certs,
168 wildcard_certs,
169 })
170 }
171
172 pub fn resolve(&self, server_name: Option<&str>) -> Arc<CertifiedKey> {
177 let Some(name) = server_name else {
178 debug!("No SNI provided, using default certificate");
179 return self.default_cert.clone();
180 };
181
182 let name_lower = name.to_lowercase();
183
184 if let Some(cert) = self.sni_certs.get(&name_lower) {
186 debug!(hostname = %name_lower, "SNI exact match found");
187 return cert.clone();
188 }
189
190 let parts: Vec<&str> = name_lower.split('.').collect();
193 for i in 1..parts.len() {
194 let domain = parts[i..].join(".");
195 if let Some(cert) = self.wildcard_certs.get(&domain) {
196 debug!(
197 hostname = %name_lower,
198 wildcard_domain = %domain,
199 "SNI wildcard match found"
200 );
201 return cert.clone();
202 }
203 }
204
205 debug!(
206 hostname = %name_lower,
207 "No SNI match found, using default certificate"
208 );
209 self.default_cert.clone()
210 }
211}
212
213impl ResolvesServerCert for SniResolver {
214 fn resolve(&self, client_hello: ClientHello<'_>) -> Option<Arc<CertifiedKey>> {
215 Some(self.resolve(client_hello.server_name()))
216 }
217}
218
219pub struct HotReloadableSniResolver {
229 inner: RwLock<Arc<SniResolver>>,
231 config: RwLock<TlsConfig>,
233 last_reload: RwLock<Instant>,
235}
236
237impl std::fmt::Debug for HotReloadableSniResolver {
238 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
239 f.debug_struct("HotReloadableSniResolver")
240 .field("last_reload", &*self.last_reload.read())
241 .finish()
242 }
243}
244
245impl HotReloadableSniResolver {
246 pub fn from_config(config: TlsConfig) -> Result<Self, TlsError> {
248 let resolver = SniResolver::from_config(&config)?;
249
250 Ok(Self {
251 inner: RwLock::new(Arc::new(resolver)),
252 config: RwLock::new(config),
253 last_reload: RwLock::new(Instant::now()),
254 })
255 }
256
257 pub fn reload(&self) -> Result<(), TlsError> {
262 let config = self.config.read();
263
264 info!(
265 cert_file = %config.cert_file.display(),
266 sni_count = config.additional_certs.len(),
267 "Reloading TLS certificates"
268 );
269
270 let new_resolver = SniResolver::from_config(&config)?;
272
273 *self.inner.write() = Arc::new(new_resolver);
275 *self.last_reload.write() = Instant::now();
276
277 info!("TLS certificates reloaded successfully");
278 Ok(())
279 }
280
281 pub fn update_config(&self, new_config: TlsConfig) -> Result<(), TlsError> {
283 let new_resolver = SniResolver::from_config(&new_config)?;
285
286 *self.config.write() = new_config;
288 *self.inner.write() = Arc::new(new_resolver);
289 *self.last_reload.write() = Instant::now();
290
291 info!("TLS configuration updated and certificates reloaded");
292 Ok(())
293 }
294
295 pub fn last_reload_age(&self) -> Duration {
297 self.last_reload.read().elapsed()
298 }
299
300 pub fn resolve(&self, server_name: Option<&str>) -> Arc<CertifiedKey> {
304 self.inner.read().resolve(server_name)
305 }
306}
307
308impl ResolvesServerCert for HotReloadableSniResolver {
309 fn resolve(&self, client_hello: ClientHello<'_>) -> Option<Arc<CertifiedKey>> {
310 Some(self.inner.read().resolve(client_hello.server_name()))
311 }
312}
313
314pub struct CertificateReloader {
318 resolvers: RwLock<HashMap<String, Arc<HotReloadableSniResolver>>>,
320}
321
322impl CertificateReloader {
323 pub fn new() -> Self {
325 Self {
326 resolvers: RwLock::new(HashMap::new()),
327 }
328 }
329
330 pub fn register(&self, listener_id: &str, resolver: Arc<HotReloadableSniResolver>) {
332 debug!(listener_id = %listener_id, "Registering TLS resolver for hot-reload");
333 self.resolvers
334 .write()
335 .insert(listener_id.to_string(), resolver);
336 }
337
338 pub fn reload_all(&self) -> (usize, Vec<(String, TlsError)>) {
342 let resolvers = self.resolvers.read();
343 let mut success_count = 0;
344 let mut errors = Vec::new();
345
346 info!(
347 listener_count = resolvers.len(),
348 "Reloading certificates for all TLS listeners"
349 );
350
351 for (listener_id, resolver) in resolvers.iter() {
352 match resolver.reload() {
353 Ok(()) => {
354 success_count += 1;
355 debug!(listener_id = %listener_id, "Certificate reload successful");
356 }
357 Err(e) => {
358 error!(listener_id = %listener_id, error = %e, "Certificate reload failed");
359 errors.push((listener_id.clone(), e));
360 }
361 }
362 }
363
364 if errors.is_empty() {
365 info!(
366 success_count = success_count,
367 "All certificates reloaded successfully"
368 );
369 } else {
370 warn!(
371 success_count = success_count,
372 error_count = errors.len(),
373 "Certificate reload completed with errors"
374 );
375 }
376
377 (success_count, errors)
378 }
379
380 pub fn status(&self) -> HashMap<String, Duration> {
382 self.resolvers
383 .read()
384 .iter()
385 .map(|(id, resolver)| (id.clone(), resolver.last_reload_age()))
386 .collect()
387 }
388}
389
390impl Default for CertificateReloader {
391 fn default() -> Self {
392 Self::new()
393 }
394}
395
/// One cached OCSP response for a single certificate.
#[derive(Debug, Clone)]
pub struct OcspCacheEntry {
    /// Raw response body returned by the OCSP responder (presumably a
    /// DER-encoded OCSPResponse — it is not parsed or validated when stored).
    pub response: Vec<u8>,
    /// When the response was fetched; used to judge freshness.
    pub fetched_at: Instant,
    /// Optional hard expiry for the response; `None` when unknown.
    pub expires_at: Option<Instant>,
}
410
411pub struct OcspStapler {
415 cache: RwLock<HashMap<String, OcspCacheEntry>>,
417 refresh_interval: Duration,
419}
420
421impl OcspStapler {
422 pub fn new() -> Self {
424 Self {
425 cache: RwLock::new(HashMap::new()),
426 refresh_interval: Duration::from_secs(3600), }
428 }
429
430 pub fn with_refresh_interval(interval: Duration) -> Self {
432 Self {
433 cache: RwLock::new(HashMap::new()),
434 refresh_interval: interval,
435 }
436 }
437
438 pub fn get_response(&self, cert_fingerprint: &str) -> Option<Vec<u8>> {
440 let cache = self.cache.read();
441 if let Some(entry) = cache.get(cert_fingerprint) {
442 if entry.fetched_at.elapsed() < self.refresh_interval {
444 trace!(fingerprint = %cert_fingerprint, "OCSP cache hit");
445 return Some(entry.response.clone());
446 }
447 trace!(fingerprint = %cert_fingerprint, "OCSP cache expired");
448 }
449 None
450 }
451
452 pub fn fetch_ocsp_response(
457 &self,
458 cert_der: &[u8],
459 issuer_der: &[u8],
460 ) -> Result<Vec<u8>, TlsError> {
461 use x509_parser::prelude::*;
462
463 let (_, cert) = X509Certificate::from_der(cert_der)
465 .map_err(|e| TlsError::OcspFetch(format!("Failed to parse certificate: {}", e)))?;
466
467 let (_, issuer) = X509Certificate::from_der(issuer_der)
469 .map_err(|e| TlsError::OcspFetch(format!("Failed to parse issuer certificate: {}", e)))?;
470
471 let ocsp_url = extract_ocsp_responder_url(&cert)?;
473 debug!(url = %ocsp_url, "Found OCSP responder URL");
474
475 let ocsp_request = build_ocsp_request(&cert, &issuer)?;
477
478 let response = send_ocsp_request_sync(&ocsp_url, &ocsp_request)?;
481
482 let fingerprint = calculate_cert_fingerprint(cert_der);
484
485 let entry = OcspCacheEntry {
487 response: response.clone(),
488 fetched_at: Instant::now(),
489 expires_at: None, };
491 self.cache.write().insert(fingerprint, entry);
492
493 info!("Successfully fetched and cached OCSP response");
494 Ok(response)
495 }
496
497 pub async fn fetch_ocsp_response_async(
499 &self,
500 cert_der: &[u8],
501 issuer_der: &[u8],
502 ) -> Result<Vec<u8>, TlsError> {
503 use x509_parser::prelude::*;
504
505 let (_, cert) = X509Certificate::from_der(cert_der)
507 .map_err(|e| TlsError::OcspFetch(format!("Failed to parse certificate: {}", e)))?;
508
509 let (_, issuer) = X509Certificate::from_der(issuer_der)
511 .map_err(|e| TlsError::OcspFetch(format!("Failed to parse issuer certificate: {}", e)))?;
512
513 let ocsp_url = extract_ocsp_responder_url(&cert)?;
515 debug!(url = %ocsp_url, "Found OCSP responder URL");
516
517 let ocsp_request = build_ocsp_request(&cert, &issuer)?;
519
520 let response = send_ocsp_request_async(&ocsp_url, &ocsp_request).await?;
522
523 let fingerprint = calculate_cert_fingerprint(cert_der);
525
526 let entry = OcspCacheEntry {
528 response: response.clone(),
529 fetched_at: Instant::now(),
530 expires_at: None,
531 };
532 self.cache.write().insert(fingerprint, entry);
533
534 info!("Successfully fetched and cached OCSP response (async)");
535 Ok(response)
536 }
537
538 pub fn prefetch_for_config(&self, config: &TlsConfig) -> Vec<String> {
540 let mut warnings = Vec::new();
541
542 if !config.ocsp_stapling {
543 trace!("OCSP stapling disabled in config");
544 return warnings;
545 }
546
547 info!("Prefetching OCSP responses for certificates");
548
549 warnings.push("OCSP stapling prefetch not yet fully implemented".to_string());
552
553 warnings
554 }
555
556 pub fn clear_cache(&self) {
558 self.cache.write().clear();
559 info!("OCSP cache cleared");
560 }
561}
562
563impl Default for OcspStapler {
564 fn default() -> Self {
565 Self::new()
566 }
567}
568
569fn extract_ocsp_responder_url(cert: &x509_parser::certificate::X509Certificate) -> Result<String, TlsError> {
575 use x509_parser::prelude::*;
576
577 let aia = cert
579 .extensions()
580 .iter()
581 .find(|ext| ext.oid == oid_registry::OID_PKIX_AUTHORITY_INFO_ACCESS)
582 .ok_or_else(|| TlsError::OcspFetch(
583 "Certificate does not have Authority Information Access extension".to_string()
584 ))?;
585
586 let aia_value = match aia.parsed_extension() {
588 ParsedExtension::AuthorityInfoAccess(aia) => aia,
589 _ => return Err(TlsError::OcspFetch(
590 "Failed to parse Authority Information Access extension".to_string()
591 )),
592 };
593
594 for access in &aia_value.accessdescs {
596 if access.access_method == oid_registry::OID_PKIX_ACCESS_DESCRIPTOR_OCSP {
597 match &access.access_location {
598 GeneralName::URI(url) => {
599 return Ok(url.to_string());
600 }
601 _ => continue,
602 }
603 }
604 }
605
606 Err(TlsError::OcspFetch(
607 "Certificate AIA does not contain OCSP responder URL".to_string()
608 ))
609}
610
611fn build_ocsp_request(
615 cert: &x509_parser::certificate::X509Certificate,
616 issuer: &x509_parser::certificate::X509Certificate,
617) -> Result<Vec<u8>, TlsError> {
618 use sha2::{Sha256, Digest};
619
620 let issuer_name_hash = {
627 let mut hasher = Sha256::new();
628 hasher.update(issuer.subject().as_raw());
629 hasher.finalize()
630 };
631
632 let issuer_key_hash = {
634 let mut hasher = Sha256::new();
635 hasher.update(issuer.public_key().subject_public_key.data.as_ref());
636 hasher.finalize()
637 };
638
639 let serial = cert.serial.to_bytes_be();
641
642 let request = build_ocsp_request_der(
645 &issuer_name_hash,
646 &issuer_key_hash,
647 &serial,
648 );
649
650 Ok(request)
651}
652
653fn build_ocsp_request_der(
655 issuer_name_hash: &[u8],
656 issuer_key_hash: &[u8],
657 serial_number: &[u8],
658) -> Vec<u8> {
659 let sha256_oid: &[u8] = &[0x60, 0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x01];
661
662 let hash_algorithm = der_sequence(&[
664 &der_oid(sha256_oid),
665 &der_null(),
666 ]);
667
668 let cert_id = der_sequence(&[
669 &hash_algorithm,
670 &der_octet_string(issuer_name_hash),
671 &der_octet_string(issuer_key_hash),
672 &der_integer(serial_number),
673 ]);
674
675 let request = der_sequence(&[&cert_id]);
677
678 let request_list = der_sequence(&[&request]);
680
681 let tbs_request = der_sequence(&[&request_list]);
683
684 der_sequence(&[&tbs_request])
686}
687
688fn der_sequence(items: &[&[u8]]) -> Vec<u8> {
690 let mut content = Vec::new();
691 for item in items {
692 content.extend_from_slice(item);
693 }
694 let mut result = vec![0x30]; result.extend(der_length(content.len()));
696 result.extend(content);
697 result
698}
699
700fn der_oid(oid: &[u8]) -> Vec<u8> {
701 let mut result = vec![0x06]; result.extend(der_length(oid.len()));
703 result.extend_from_slice(oid);
704 result
705}
706
/// DER encoding of ASN.1 NULL: tag 0x05 with zero-length content.
fn der_null() -> Vec<u8> {
    [0x05u8, 0x00].to_vec()
}
710
711fn der_octet_string(data: &[u8]) -> Vec<u8> {
712 let mut result = vec![0x04]; result.extend(der_length(data.len()));
714 result.extend_from_slice(data);
715 result
716}
717
718fn der_integer(data: &[u8]) -> Vec<u8> {
719 let mut result = vec![0x02]; let data = match data.iter().position(|&b| b != 0) {
722 Some(pos) => &data[pos..],
723 None => &[0],
724 };
725 if !data.is_empty() && data[0] & 0x80 != 0 {
727 result.extend(der_length(data.len() + 1));
728 result.push(0x00);
729 } else {
730 result.extend(der_length(data.len()));
731 }
732 result.extend_from_slice(data);
733 result
734}
735
/// Encodes a DER length field.
///
/// Lengths below 128 use the one-byte short form. Larger lengths use the
/// long form: `0x80 | n` followed by the `n` big-endian length bytes with no
/// leading zeros. The previous version emitted only two length bytes for
/// every length >= 256, silently truncating lengths >= 65536 into malformed DER.
fn der_length(len: usize) -> Vec<u8> {
    if len < 0x80 {
        return vec![len as u8];
    }

    let be_bytes = len.to_be_bytes();
    let skip = be_bytes.iter().take_while(|&&b| b == 0).count();
    // Non-empty because len >= 0x80; at most size_of::<usize>() bytes, which
    // always fits in the 7-bit byte count.
    let significant = &be_bytes[skip..];

    let mut encoded = Vec::with_capacity(1 + significant.len());
    encoded.push(0x80 | significant.len() as u8);
    encoded.extend_from_slice(significant);
    encoded
}
745
746fn send_ocsp_request_sync(url: &str, request: &[u8]) -> Result<Vec<u8>, TlsError> {
748 use std::io::{Read, Write};
749 use std::net::TcpStream;
750 use std::time::Duration;
751
752 let url = url::Url::parse(url)
754 .map_err(|e| TlsError::OcspFetch(format!("Invalid OCSP URL: {}", e)))?;
755
756 let host = url.host_str()
757 .ok_or_else(|| TlsError::OcspFetch("OCSP URL has no host".to_string()))?;
758 let port = url.port().unwrap_or(80);
759 let path = if url.path().is_empty() { "/" } else { url.path() };
760
761 let addr = format!("{}:{}", host, port);
763 let mut stream = TcpStream::connect(&addr)
764 .map_err(|e| TlsError::OcspFetch(format!("Failed to connect to OCSP responder: {}", e)))?;
765
766 stream.set_read_timeout(Some(Duration::from_secs(10)))
767 .map_err(|e| TlsError::OcspFetch(format!("Failed to set timeout: {}", e)))?;
768 stream.set_write_timeout(Some(Duration::from_secs(10)))
769 .map_err(|e| TlsError::OcspFetch(format!("Failed to set timeout: {}", e)))?;
770
771 let http_request = format!(
773 "POST {} HTTP/1.1\r\n\
774 Host: {}\r\n\
775 Content-Type: application/ocsp-request\r\n\
776 Content-Length: {}\r\n\
777 Connection: close\r\n\
778 \r\n",
779 path, host, request.len()
780 );
781
782 stream.write_all(http_request.as_bytes())
784 .map_err(|e| TlsError::OcspFetch(format!("Failed to send OCSP request: {}", e)))?;
785 stream.write_all(request)
786 .map_err(|e| TlsError::OcspFetch(format!("Failed to send OCSP request body: {}", e)))?;
787
788 let mut response = Vec::new();
790 stream.read_to_end(&mut response)
791 .map_err(|e| TlsError::OcspFetch(format!("Failed to read OCSP response: {}", e)))?;
792
793 let headers_end = response.windows(4)
795 .position(|w| w == b"\r\n\r\n")
796 .ok_or_else(|| TlsError::OcspFetch("Invalid HTTP response: no headers end".to_string()))?;
797
798 let body = &response[headers_end + 4..];
799 if body.is_empty() {
800 return Err(TlsError::OcspFetch("Empty OCSP response body".to_string()));
801 }
802
803 Ok(body.to_vec())
804}
805
806async fn send_ocsp_request_async(url: &str, request: &[u8]) -> Result<Vec<u8>, TlsError> {
808 let client = reqwest::Client::builder()
809 .timeout(Duration::from_secs(10))
810 .build()
811 .map_err(|e| TlsError::OcspFetch(format!("Failed to create HTTP client: {}", e)))?;
812
813 let response = client
814 .post(url)
815 .header("Content-Type", "application/ocsp-request")
816 .body(request.to_vec())
817 .send()
818 .await
819 .map_err(|e| TlsError::OcspFetch(format!("OCSP request failed: {}", e)))?;
820
821 if !response.status().is_success() {
822 return Err(TlsError::OcspFetch(format!(
823 "OCSP responder returned status: {}",
824 response.status()
825 )));
826 }
827
828 let body = response.bytes().await
829 .map_err(|e| TlsError::OcspFetch(format!("Failed to read OCSP response: {}", e)))?;
830
831 Ok(body.to_vec())
832}
833
834fn calculate_cert_fingerprint(cert_der: &[u8]) -> String {
836 use sha2::{Sha256, Digest};
837 let mut hasher = Sha256::new();
838 hasher.update(cert_der);
839 let result = hasher.finalize();
840 hex::encode(result)
841}
842
843pub fn load_client_cert_key(
861 cert_path: &Path,
862 key_path: &Path,
863) -> Result<Arc<pingora_core::utils::tls::CertKey>, TlsError> {
864 let cert_file = File::open(cert_path)
866 .map_err(|e| TlsError::CertificateLoad(format!("{}: {}", cert_path.display(), e)))?;
867 let mut cert_reader = BufReader::new(cert_file);
868
869 let cert_ders: Vec<Vec<u8>> = rustls_pemfile::certs(&mut cert_reader)
871 .collect::<Result<Vec<_>, _>>()
872 .map_err(|e| TlsError::CertificateLoad(format!("{}: {}", cert_path.display(), e)))?
873 .into_iter()
874 .map(|c| c.to_vec())
875 .collect();
876
877 if cert_ders.is_empty() {
878 return Err(TlsError::CertificateLoad(format!(
879 "{}: No certificates found in PEM file",
880 cert_path.display()
881 )));
882 }
883
884 let key_file = File::open(key_path)
886 .map_err(|e| TlsError::KeyLoad(format!("{}: {}", key_path.display(), e)))?;
887 let mut key_reader = BufReader::new(key_file);
888
889 let key_der = rustls_pemfile::private_key(&mut key_reader)
891 .map_err(|e| TlsError::KeyLoad(format!("{}: {}", key_path.display(), e)))?
892 .ok_or_else(|| {
893 TlsError::KeyLoad(format!(
894 "{}: No private key found in PEM file",
895 key_path.display()
896 ))
897 })?
898 .secret_der()
899 .to_vec();
900
901 let cert_key = pingora_core::utils::tls::CertKey::new(cert_ders, key_der);
903
904 debug!(
905 cert_path = %cert_path.display(),
906 key_path = %key_path.display(),
907 "Loaded mTLS client certificate for upstream connections"
908 );
909
910 Ok(Arc::new(cert_key))
911}
912
913pub fn build_upstream_tls_config(config: &UpstreamTlsConfig) -> Result<ClientConfig, TlsError> {
918 let mut root_store = RootCertStore::empty();
919
920 if let Some(ca_path) = &config.ca_cert {
922 let ca_file = File::open(ca_path)
923 .map_err(|e| TlsError::CertificateLoad(format!("{}: {}", ca_path.display(), e)))?;
924 let mut ca_reader = BufReader::new(ca_file);
925
926 let certs: Vec<CertificateDer<'static>> = rustls_pemfile::certs(&mut ca_reader)
927 .collect::<Result<Vec<_>, _>>()
928 .map_err(|e| TlsError::CertificateLoad(format!("{}: {}", ca_path.display(), e)))?;
929
930 for cert in certs {
931 root_store.add(cert).map_err(|e| {
932 TlsError::InvalidCertificate(format!("Failed to add CA certificate: {}", e))
933 })?;
934 }
935
936 debug!(
937 ca_file = %ca_path.display(),
938 cert_count = root_store.len(),
939 "Loaded upstream CA certificates"
940 );
941 } else if !config.insecure_skip_verify {
942 root_store = RootCertStore {
944 roots: webpki_roots::TLS_SERVER_ROOTS.to_vec(),
945 };
946 trace!("Using webpki-roots for upstream TLS verification");
947 }
948
949 let builder = ClientConfig::builder().with_root_certificates(root_store);
951
952 let client_config = if let (Some(cert_path), Some(key_path)) =
953 (&config.client_cert, &config.client_key)
954 {
955 let cert_file = File::open(cert_path)
957 .map_err(|e| TlsError::CertificateLoad(format!("{}: {}", cert_path.display(), e)))?;
958 let mut cert_reader = BufReader::new(cert_file);
959
960 let certs: Vec<CertificateDer<'static>> = rustls_pemfile::certs(&mut cert_reader)
961 .collect::<Result<Vec<_>, _>>()
962 .map_err(|e| TlsError::CertificateLoad(format!("{}: {}", cert_path.display(), e)))?;
963
964 if certs.is_empty() {
965 return Err(TlsError::CertificateLoad(format!(
966 "{}: No certificates found",
967 cert_path.display()
968 )));
969 }
970
971 let key_file = File::open(key_path)
973 .map_err(|e| TlsError::KeyLoad(format!("{}: {}", key_path.display(), e)))?;
974 let mut key_reader = BufReader::new(key_file);
975
976 let key = rustls_pemfile::private_key(&mut key_reader)
977 .map_err(|e| TlsError::KeyLoad(format!("{}: {}", key_path.display(), e)))?
978 .ok_or_else(|| {
979 TlsError::KeyLoad(format!("{}: No private key found", key_path.display()))
980 })?;
981
982 info!(
983 cert_file = %cert_path.display(),
984 "Configured mTLS client certificate for upstream connections"
985 );
986
987 builder
988 .with_client_auth_cert(certs, key)
989 .map_err(|e| TlsError::CertKeyMismatch(format!("Failed to set client auth: {}", e)))?
990 } else {
991 builder.with_no_client_auth()
993 };
994
995 debug!("Upstream TLS configuration built successfully");
996 Ok(client_config)
997}
998
999pub fn validate_upstream_tls_config(config: &UpstreamTlsConfig) -> Result<(), TlsError> {
1001 if let Some(ca_path) = &config.ca_cert {
1003 if !ca_path.exists() {
1004 return Err(TlsError::CertificateLoad(format!(
1005 "Upstream CA certificate not found: {}",
1006 ca_path.display()
1007 )));
1008 }
1009 }
1010
1011 if let Some(cert_path) = &config.client_cert {
1013 if !cert_path.exists() {
1014 return Err(TlsError::CertificateLoad(format!(
1015 "Upstream client certificate not found: {}",
1016 cert_path.display()
1017 )));
1018 }
1019
1020 match &config.client_key {
1022 Some(key_path) if !key_path.exists() => {
1023 return Err(TlsError::KeyLoad(format!(
1024 "Upstream client key not found: {}",
1025 key_path.display()
1026 )));
1027 }
1028 None => {
1029 return Err(TlsError::ConfigBuild(
1030 "client_cert specified without client_key".to_string(),
1031 ));
1032 }
1033 _ => {}
1034 }
1035 }
1036
1037 if config.client_key.is_some() && config.client_cert.is_none() {
1038 return Err(TlsError::ConfigBuild(
1039 "client_key specified without client_cert".to_string(),
1040 ));
1041 }
1042
1043 Ok(())
1044}
1045
1046fn load_certified_key(cert_path: &Path, key_path: &Path) -> Result<CertifiedKey, TlsError> {
1052 let cert_file = File::open(cert_path)
1054 .map_err(|e| TlsError::CertificateLoad(format!("{}: {}", cert_path.display(), e)))?;
1055 let mut cert_reader = BufReader::new(cert_file);
1056
1057 let certs: Vec<CertificateDer<'static>> = rustls_pemfile::certs(&mut cert_reader)
1058 .collect::<Result<Vec<_>, _>>()
1059 .map_err(|e| TlsError::CertificateLoad(format!("{}: {}", cert_path.display(), e)))?;
1060
1061 if certs.is_empty() {
1062 return Err(TlsError::CertificateLoad(format!(
1063 "{}: No certificates found in file",
1064 cert_path.display()
1065 )));
1066 }
1067
1068 let key_file = File::open(key_path)
1070 .map_err(|e| TlsError::KeyLoad(format!("{}: {}", key_path.display(), e)))?;
1071 let mut key_reader = BufReader::new(key_file);
1072
1073 let key = rustls_pemfile::private_key(&mut key_reader)
1074 .map_err(|e| TlsError::KeyLoad(format!("{}: {}", key_path.display(), e)))?
1075 .ok_or_else(|| {
1076 TlsError::KeyLoad(format!(
1077 "{}: No private key found in file",
1078 key_path.display()
1079 ))
1080 })?;
1081
1082 let provider = rustls::crypto::CryptoProvider::get_default()
1084 .cloned()
1085 .unwrap_or_else(|| Arc::new(rustls::crypto::aws_lc_rs::default_provider()));
1086
1087 let signing_key = provider
1088 .key_provider
1089 .load_private_key(key)
1090 .map_err(|e| TlsError::CertKeyMismatch(format!("Failed to load private key: {:?}", e)))?;
1091
1092 Ok(CertifiedKey::new(certs, signing_key))
1093}
1094
1095pub fn load_client_ca(ca_path: &Path) -> Result<RootCertStore, TlsError> {
1097 let ca_file = File::open(ca_path)
1098 .map_err(|e| TlsError::CertificateLoad(format!("{}: {}", ca_path.display(), e)))?;
1099 let mut ca_reader = BufReader::new(ca_file);
1100
1101 let mut root_store = RootCertStore::empty();
1102
1103 let certs: Vec<CertificateDer<'static>> = rustls_pemfile::certs(&mut ca_reader)
1104 .collect::<Result<Vec<_>, _>>()
1105 .map_err(|e| TlsError::CertificateLoad(format!("{}: {}", ca_path.display(), e)))?;
1106
1107 for cert in certs {
1108 root_store.add(cert).map_err(|e| {
1109 TlsError::InvalidCertificate(format!("Failed to add CA certificate: {}", e))
1110 })?;
1111 }
1112
1113 if root_store.is_empty() {
1114 return Err(TlsError::CertificateLoad(format!(
1115 "{}: No CA certificates found",
1116 ca_path.display()
1117 )));
1118 }
1119
1120 info!(
1121 ca_file = %ca_path.display(),
1122 cert_count = root_store.len(),
1123 "Loaded client CA certificates"
1124 );
1125
1126 Ok(root_store)
1127}
1128
1129pub fn build_server_config(config: &TlsConfig) -> Result<ServerConfig, TlsError> {
1131 let resolver = SniResolver::from_config(config)?;
1132
1133 let builder = ServerConfig::builder();
1134
1135 let server_config = if config.client_auth {
1137 if let Some(ca_path) = &config.ca_file {
1138 let root_store = load_client_ca(ca_path)?;
1139 let verifier = rustls::server::WebPkiClientVerifier::builder(Arc::new(root_store))
1140 .build()
1141 .map_err(|e| {
1142 TlsError::ConfigBuild(format!("Failed to build client verifier: {}", e))
1143 })?;
1144
1145 info!("mTLS enabled: client certificates required");
1146
1147 builder
1148 .with_client_cert_verifier(verifier)
1149 .with_cert_resolver(Arc::new(resolver))
1150 } else {
1151 warn!("client_auth enabled but no ca_file specified, disabling client auth");
1152 builder
1153 .with_no_client_auth()
1154 .with_cert_resolver(Arc::new(resolver))
1155 }
1156 } else {
1157 builder
1158 .with_no_client_auth()
1159 .with_cert_resolver(Arc::new(resolver))
1160 };
1161
1162 let mut config = server_config;
1164 config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
1165
1166 debug!("TLS configuration built successfully");
1167
1168 Ok(config)
1169}
1170
1171pub fn validate_tls_config(config: &TlsConfig) -> Result<(), TlsError> {
1173 if !config.cert_file.exists() {
1175 return Err(TlsError::CertificateLoad(format!(
1176 "Certificate file not found: {}",
1177 config.cert_file.display()
1178 )));
1179 }
1180 if !config.key_file.exists() {
1181 return Err(TlsError::KeyLoad(format!(
1182 "Key file not found: {}",
1183 config.key_file.display()
1184 )));
1185 }
1186
1187 for sni in &config.additional_certs {
1189 if !sni.cert_file.exists() {
1190 return Err(TlsError::CertificateLoad(format!(
1191 "SNI certificate file not found: {}",
1192 sni.cert_file.display()
1193 )));
1194 }
1195 if !sni.key_file.exists() {
1196 return Err(TlsError::KeyLoad(format!(
1197 "SNI key file not found: {}",
1198 sni.key_file.display()
1199 )));
1200 }
1201 }
1202
1203 if config.client_auth {
1205 if let Some(ca_path) = &config.ca_file {
1206 if !ca_path.exists() {
1207 return Err(TlsError::CertificateLoad(format!(
1208 "CA certificate file not found: {}",
1209 ca_path.display()
1210 )));
1211 }
1212 }
1213 }
1214
1215 Ok(())
1216}
1217
#[cfg(test)]
mod tests {
    use super::*;

    /// Label splitting of a multi-level hostname (pure string operations).
    #[test]
    fn test_wildcard_matching() {
        let name = "foo.bar.example.com";
        let parts: Vec<&str> = name.split('.').collect();

        assert_eq!(parts.len(), 4);

        let domain1 = parts[1..].join(".");
        assert_eq!(domain1, "bar.example.com");

        let domain2 = parts[2..].join(".");
        assert_eq!(domain2, "example.com");
    }

    /// Hostnames are compared after lowercasing.
    #[test]
    fn test_hostname_normalization() {
        let hostname = "Example.COM";
        let normalized = hostname.to_lowercase();
        assert_eq!(normalized, "example.com");
    }

    #[test]
    fn test_der_length_encoding() {
        assert_eq!(der_length(0), vec![0x00]);
        assert_eq!(der_length(127), vec![0x7f]);
        assert_eq!(der_length(128), vec![0x81, 0x80]);
        assert_eq!(der_length(255), vec![0x81, 0xff]);
        assert_eq!(der_length(256), vec![0x82, 0x01, 0x00]);
        assert_eq!(der_length(65_535), vec![0x82, 0xff, 0xff]);
    }

    #[test]
    fn test_der_primitives() {
        assert_eq!(der_null(), vec![0x05, 0x00]);
        assert_eq!(der_oid(&[0x2a, 0x03]), vec![0x06, 0x02, 0x2a, 0x03]);
        assert_eq!(der_octet_string(&[0xde, 0xad]), vec![0x04, 0x02, 0xde, 0xad]);

        let long = der_octet_string(&[0u8; 200]);
        assert_eq!(&long[..3], &[0x04, 0x81, 0xc8]);
        assert_eq!(long.len(), 203);

        let first: &[u8] = &[0x05, 0x00];
        let second: &[u8] = &[0x02, 0x01, 0x01];
        assert_eq!(
            der_sequence(&[first, second]),
            vec![0x30, 0x05, 0x05, 0x00, 0x02, 0x01, 0x01]
        );
    }

    #[test]
    fn test_der_integer_is_minimal_and_non_negative() {
        assert_eq!(der_integer(&[]), vec![0x02, 0x01, 0x00]);
        assert_eq!(der_integer(&[0x00, 0x00]), vec![0x02, 0x01, 0x00]);
        assert_eq!(der_integer(&[0x00, 0x00, 0x05]), vec![0x02, 0x01, 0x05]);
        assert_eq!(der_integer(&[0x80]), vec![0x02, 0x02, 0x00, 0x80]);
        assert_eq!(der_integer(&[0x00, 0xff, 0x01]), vec![0x02, 0x03, 0x00, 0xff, 0x01]);
        assert_eq!(der_integer(&[0x7f, 0xff]), vec![0x02, 0x02, 0x7f, 0xff]);
    }

    #[test]
    fn test_ocsp_request_layout() {
        let req = build_ocsp_request_der(&[0x11; 32], &[0x22; 32], &[0x01]);
        assert_eq!(req.len(), 96);
        // OCSPRequest / TBSRequest / requestList / Request / CertID headers.
        assert_eq!(
            &req[..10],
            &[0x30, 0x5e, 0x30, 0x5c, 0x30, 0x5a, 0x30, 0x58, 0x30, 0x56]
        );
        // AlgorithmIdentifier { id-sha256, NULL }.
        assert_eq!(
            &req[10..25],
            &[0x30, 0x0d, 0x06, 0x09, 0x60, 0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x01, 0x05, 0x00]
        );
        assert_eq!(&req[25..27], &[0x04, 0x20]);
        assert!(req[27..59].iter().all(|&b| b == 0x11));
        assert_eq!(&req[59..61], &[0x04, 0x20]);
        assert!(req[61..93].iter().all(|&b| b == 0x22));
        assert_eq!(&req[93..], &[0x02, 0x01, 0x01]);
    }

    #[test]
    fn test_cert_fingerprint_is_hex_sha256() {
        assert_eq!(
            calculate_cert_fingerprint(b""),
            "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
        );
        assert_eq!(
            calculate_cert_fingerprint(b"abc"),
            "ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"
        );
    }
}