Skip to main content

zlayer_proxy/
sni_resolver.rs

//! SNI-based TLS Certificate Resolver
//!
//! This module provides dynamic TLS certificate selection based on Server Name
//! Indication (SNI). It allows serving different certificates for different domains
//! at runtime, with support for wildcard certificates.
//!
//! # Example
//!
//! ```rust,ignore
//! use zlayer_proxy::SniCertResolver;
//!
//! let resolver = SniCertResolver::new();
//!
//! // Load certificates for specific domains
//! resolver.load_cert("example.com", cert_pem, key_pem)?;
//! resolver.load_cert("*.example.com", wildcard_cert_pem, wildcard_key_pem)?;
//!
//! // Set a fallback certificate
//! resolver.set_default_cert(default_cert_pem, default_key_pem)?;
//! ```
21
22use crate::error::{ProxyError, Result};
23use dashmap::DashMap;
24use rustls::pki_types::{CertificateDer, PrivateKeyDer};
25use rustls::server::{ClientHello, ResolvesServerCert};
26use rustls::sign::CertifiedKey;
27use std::io::BufReader;
28use std::sync::{Arc, RwLock};
29use tracing::{debug, trace, warn};
30
31/// SNI-based certificate resolver for dynamic TLS certificate selection
32///
33/// This resolver maintains a mapping of domain names to TLS certificates,
34/// allowing the proxy to serve different certificates for different domains.
35/// It supports:
36///
37/// - Exact domain matching (e.g., `api.example.com`)
38/// - Wildcard certificates (e.g., `*.example.com`)
39/// - A default/fallback certificate for unmatched domains
40///
41/// The resolver is thread-safe and supports concurrent certificate updates.
42#[derive(Debug)]
43pub struct SniCertResolver {
44    /// Domain -> `CertifiedKey` mapping
45    certs: DashMap<String, Arc<CertifiedKey>>,
46    /// Default/fallback certificate (optional)
47    default_cert: RwLock<Option<Arc<CertifiedKey>>>,
48}
49
50impl SniCertResolver {
51    /// Create a new empty SNI certificate resolver
52    ///
53    /// # Example
54    ///
55    /// ```rust
56    /// use zlayer_proxy::SniCertResolver;
57    ///
58    /// let resolver = SniCertResolver::new();
59    /// ```
60    #[must_use]
61    pub fn new() -> Self {
62        Self {
63            certs: DashMap::new(),
64            default_cert: RwLock::new(None),
65        }
66    }
67
68    /// Load a certificate for a specific domain
69    ///
70    /// Parses the PEM-encoded certificate chain and private key, then stores
71    /// the resulting `CertifiedKey` for the given domain.
72    ///
73    /// # Arguments
74    ///
75    /// * `domain` - The domain name (e.g., `example.com` or `*.example.com`)
76    /// * `cert_pem` - PEM-encoded certificate chain
77    /// * `key_pem` - PEM-encoded private key
78    ///
79    /// # Errors
80    ///
81    /// Returns an error if:
82    /// - The certificate PEM cannot be parsed
83    /// - The private key PEM cannot be parsed
84    /// - The key is not compatible with the certificate
85    ///
86    /// # Example
87    ///
88    /// ```rust,ignore
89    /// resolver.load_cert("example.com", cert_pem, key_pem)?;
90    /// ```
91    pub fn load_cert(&self, domain: &str, cert_pem: &str, key_pem: &str) -> Result<()> {
92        let certified_key = create_certified_key(cert_pem, key_pem)?;
93        let domain_normalized = normalize_domain(domain);
94
95        debug!(domain = %domain_normalized, "Loaded TLS certificate");
96        self.certs
97            .insert(domain_normalized, Arc::new(certified_key));
98
99        Ok(())
100    }
101
102    /// Set the default/fallback certificate
103    ///
104    /// This certificate is used when no domain-specific certificate matches
105    /// the client's SNI request.
106    ///
107    /// # Arguments
108    ///
109    /// * `cert_pem` - PEM-encoded certificate chain
110    /// * `key_pem` - PEM-encoded private key
111    ///
112    /// # Errors
113    ///
114    /// Returns an error if the certificate or key cannot be parsed.
115    ///
116    /// # Example
117    ///
118    /// ```rust,ignore
119    /// resolver.set_default_cert(default_cert_pem, default_key_pem)?;
120    /// ```
121    /// # Panics
122    ///
123    /// Panics if the internal `RwLock` is poisoned.
124    pub fn set_default_cert(&self, cert_pem: &str, key_pem: &str) -> Result<()> {
125        let certified_key = create_certified_key(cert_pem, key_pem)?;
126
127        debug!("Set default TLS certificate");
128        let mut default = self.default_cert.write().expect("RwLock poisoned");
129        *default = Some(Arc::new(certified_key));
130
131        Ok(())
132    }
133
134    /// Remove a certificate for a specific domain
135    ///
136    /// # Arguments
137    ///
138    /// * `domain` - The domain name to remove the certificate for
139    ///
140    /// # Example
141    ///
142    /// ```rust,ignore
143    /// resolver.remove_cert("example.com");
144    /// ```
145    pub fn remove_cert(&self, domain: &str) {
146        let domain_normalized = normalize_domain(domain);
147        if self.certs.remove(&domain_normalized).is_some() {
148            debug!(domain = %domain_normalized, "Removed TLS certificate");
149        }
150    }
151
152    /// Refresh/update a certificate for an existing domain
153    ///
154    /// This is equivalent to calling `load_cert` but semantically indicates
155    /// an update to an existing certificate (e.g., for certificate renewal).
156    ///
157    /// # Arguments
158    ///
159    /// * `domain` - The domain name
160    /// * `cert_pem` - New PEM-encoded certificate chain
161    /// * `key_pem` - New PEM-encoded private key
162    ///
163    /// # Errors
164    ///
165    /// Returns an error if the certificate or key cannot be parsed.
166    ///
167    /// # Example
168    ///
169    /// ```rust,ignore
170    /// resolver.refresh_cert("example.com", new_cert_pem, new_key_pem)?;
171    /// ```
172    pub fn refresh_cert(&self, domain: &str, cert_pem: &str, key_pem: &str) -> Result<()> {
173        let certified_key = create_certified_key(cert_pem, key_pem)?;
174        let domain_normalized = normalize_domain(domain);
175
176        debug!(domain = %domain_normalized, "Refreshed TLS certificate");
177        self.certs
178            .insert(domain_normalized, Arc::new(certified_key));
179
180        Ok(())
181    }
182
183    /// Check if a certificate exists for a domain
184    ///
185    /// # Arguments
186    ///
187    /// * `domain` - The domain name to check
188    ///
189    /// # Returns
190    ///
191    /// `true` if a certificate is loaded for the exact domain name
192    #[must_use]
193    pub fn has_cert(&self, domain: &str) -> bool {
194        let domain_normalized = normalize_domain(domain);
195        self.certs.contains_key(&domain_normalized)
196    }
197
198    /// Get the number of loaded certificates
199    #[must_use]
200    pub fn cert_count(&self) -> usize {
201        self.certs.len()
202    }
203
204    /// List all domains with loaded certificates
205    #[must_use]
206    pub fn domains(&self) -> Vec<String> {
207        self.certs.iter().map(|r| r.key().clone()).collect()
208    }
209
210    /// Check if a default/fallback certificate is configured
211    #[must_use]
212    pub fn has_default_cert(&self) -> bool {
213        self.default_cert.read().is_ok_and(|guard| guard.is_some())
214    }
215
216    /// Internal method to resolve a certificate for a given server name
217    fn resolve_cert(&self, server_name: Option<&str>) -> Option<Arc<CertifiedKey>> {
218        let server_name = server_name?;
219        let normalized = normalize_domain(server_name);
220
221        // Try exact match first
222        if let Some(cert) = self.certs.get(&normalized) {
223            trace!(domain = %normalized, "Exact certificate match");
224            return Some(Arc::clone(cert.value()));
225        }
226
227        // Try wildcard match (e.g., for "foo.example.com", try "*.example.com")
228        if let Some(wildcard_domain) = get_wildcard_domain(&normalized) {
229            if let Some(cert) = self.certs.get(&wildcard_domain) {
230                trace!(
231                    domain = %normalized,
232                    wildcard = %wildcard_domain,
233                    "Wildcard certificate match"
234                );
235                return Some(Arc::clone(cert.value()));
236            }
237        }
238
239        // Fall back to default certificate
240        // Note: We use std::sync::RwLock since ResolvesServerCert::resolve is sync
241        if let Ok(guard) = self.default_cert.read() {
242            if let Some(default) = guard.as_ref() {
243                trace!(domain = %normalized, "Using default certificate");
244                return Some(Arc::clone(default));
245            }
246        }
247
248        warn!(domain = %normalized, "No certificate found");
249        None
250    }
251}
252
253impl Default for SniCertResolver {
254    fn default() -> Self {
255        Self::new()
256    }
257}
258
259impl ResolvesServerCert for SniCertResolver {
260    fn resolve(&self, client_hello: ClientHello<'_>) -> Option<Arc<CertifiedKey>> {
261        let server_name = client_hello.server_name();
262        self.resolve_cert(server_name)
263    }
264}
265
266/// Create a `CertifiedKey` from PEM-encoded certificate and private key
267///
268/// # Arguments
269///
270/// * `cert_pem` - PEM-encoded certificate chain (may contain multiple certificates)
271/// * `key_pem` - PEM-encoded private key (PKCS#1, PKCS#8, or SEC1 format)
272///
273/// # Errors
274///
275/// Returns an error if:
276/// - The certificate PEM cannot be parsed
277/// - No certificates are found in the PEM
278/// - The private key PEM cannot be parsed
279/// - No private key is found in the PEM
280/// - The key cannot be converted to a signing key
281fn create_certified_key(cert_pem: &str, key_pem: &str) -> Result<CertifiedKey> {
282    // Parse certificates
283    let certs = parse_certificates(cert_pem)?;
284    if certs.is_empty() {
285        return Err(ProxyError::Tls("No certificates found in PEM".to_string()));
286    }
287
288    // Parse private key
289    let key = parse_private_key(key_pem)?;
290
291    // Create signing key using rustls crypto provider
292    let signing_key = rustls::crypto::ring::sign::any_supported_type(&key)
293        .map_err(|e| ProxyError::Tls(format!("Failed to create signing key: {e}")))?;
294
295    Ok(CertifiedKey::new(certs, signing_key))
296}
297
298/// Parse PEM-encoded certificates
299fn parse_certificates(pem: &str) -> Result<Vec<CertificateDer<'static>>> {
300    let mut reader = BufReader::new(pem.as_bytes());
301    let certs: Vec<CertificateDer<'static>> = rustls_pemfile::certs(&mut reader)
302        .collect::<std::result::Result<Vec<_>, _>>()
303        .map_err(|e| ProxyError::Tls(format!("Failed to parse certificate PEM: {e}")))?;
304
305    Ok(certs)
306}
307
308/// Parse a PEM-encoded private key
309fn parse_private_key(pem: &str) -> Result<PrivateKeyDer<'static>> {
310    let mut reader = BufReader::new(pem.as_bytes());
311
312    // Try to read different key formats
313    loop {
314        match rustls_pemfile::read_one(&mut reader) {
315            Ok(Some(rustls_pemfile::Item::Pkcs1Key(key))) => {
316                return Ok(PrivateKeyDer::Pkcs1(key));
317            }
318            Ok(Some(rustls_pemfile::Item::Pkcs8Key(key))) => {
319                return Ok(PrivateKeyDer::Pkcs8(key));
320            }
321            Ok(Some(rustls_pemfile::Item::Sec1Key(key))) => {
322                return Ok(PrivateKeyDer::Sec1(key));
323            }
324            Ok(Some(_)) => {
325                // Skip non-key items (like certificates)
326            }
327            Ok(None) => {
328                return Err(ProxyError::Tls("No private key found in PEM".to_string()));
329            }
330            Err(e) => {
331                return Err(ProxyError::Tls(format!(
332                    "Failed to parse private key PEM: {e}"
333                )));
334            }
335        }
336    }
337}
338
/// Canonicalize a domain name for use as a lookup key.
///
/// Leading/trailing whitespace is stripped and the name is lowercased, so
/// `" Example.COM "` and `"example.com"` map to the same key.
fn normalize_domain(domain: &str) -> String {
    let trimmed = domain.trim();
    trimmed.to_lowercase()
}
343
/// Compute the wildcard name that would cover `domain`.
///
/// The leftmost label is replaced by `*`: `foo.example.com` gives
/// `*.example.com`. If removing that label would leave fewer than two labels
/// (`example.com`, `localhost`), there is no wildcard and `None` is returned,
/// so a `*.com`-style key is never produced.
fn get_wildcard_domain(domain: &str) -> Option<String> {
    let (_leftmost, parent) = domain.split_once('.')?;
    parent.contains('.').then(|| format!("*.{parent}"))
}
358
#[cfg(test)]
mod tests {
    use super::*;

    // Nothing in `SniCertResolver` is async, so these are plain `#[test]`s;
    // there is no need to start a Tokio runtime for each one.

    #[test]
    fn test_normalize_domain() {
        assert_eq!(normalize_domain("Example.COM"), "example.com");
        assert_eq!(normalize_domain("  foo.bar.com  "), "foo.bar.com");
        assert_eq!(normalize_domain("API.Example.ORG"), "api.example.org");
    }

    #[test]
    fn test_get_wildcard_domain() {
        assert_eq!(
            get_wildcard_domain("foo.example.com"),
            Some("*.example.com".to_string())
        );
        assert_eq!(
            get_wildcard_domain("bar.foo.example.com"),
            Some("*.foo.example.com".to_string())
        );
        assert_eq!(get_wildcard_domain("example.com"), None);
        assert_eq!(get_wildcard_domain("localhost"), None);
        assert_eq!(get_wildcard_domain(""), None);
    }

    #[test]
    fn test_sni_resolver_new() {
        let resolver = SniCertResolver::new();
        assert_eq!(resolver.cert_count(), 0);
        assert!(resolver.domains().is_empty());
        assert!(!resolver.has_default_cert());
    }

    #[test]
    fn test_sni_resolver_default() {
        let resolver = SniCertResolver::default();
        assert_eq!(resolver.cert_count(), 0);
        assert!(!resolver.has_default_cert());
    }

    /// Generate a self-signed certificate (SANs `localhost`, `example.com`) for
    /// testing, returned as `(cert_pem, key_pem)`. A new key pair is generated
    /// on each call.
    fn generate_test_cert() -> (String, String) {
        use rcgen::{generate_simple_self_signed, CertifiedKey as RcgenCertifiedKey};

        let subject_alt_names = vec!["localhost".to_string(), "example.com".to_string()];
        let RcgenCertifiedKey { cert, key_pair } =
            generate_simple_self_signed(subject_alt_names).unwrap();

        (cert.pem(), key_pair.serialize_pem())
    }

    #[test]
    fn test_load_cert() {
        let resolver = SniCertResolver::new();
        let (cert_pem, key_pem) = generate_test_cert();

        let result = resolver.load_cert("example.com", &cert_pem, &key_pem);
        assert!(result.is_ok());
        assert!(resolver.has_cert("example.com"));
        assert_eq!(resolver.cert_count(), 1);
    }

    #[test]
    fn test_load_cert_case_insensitive() {
        let resolver = SniCertResolver::new();
        let (cert_pem, key_pem) = generate_test_cert();

        resolver
            .load_cert("Example.COM", &cert_pem, &key_pem)
            .unwrap();
        assert!(resolver.has_cert("example.com"));
        assert!(resolver.has_cert("EXAMPLE.COM"));
        // Stored under the normalized key only.
        assert_eq!(resolver.domains(), vec!["example.com".to_string()]);
    }

    #[test]
    fn test_load_cert_missing_key_fails() {
        let resolver = SniCertResolver::new();
        let (cert_pem, _) = generate_test_cert();

        assert!(resolver.load_cert("example.com", &cert_pem, "").is_err());
        // A failed load must not leave a partial entry behind.
        assert!(!resolver.has_cert("example.com"));
        assert_eq!(resolver.cert_count(), 0);
    }

    #[test]
    fn test_remove_cert() {
        let resolver = SniCertResolver::new();
        let (cert_pem, key_pem) = generate_test_cert();

        resolver
            .load_cert("example.com", &cert_pem, &key_pem)
            .unwrap();
        assert!(resolver.has_cert("example.com"));

        resolver.remove_cert("example.com");
        assert!(!resolver.has_cert("example.com"));
        assert_eq!(resolver.cert_count(), 0);
    }

    #[test]
    fn test_remove_cert_case_insensitive() {
        let resolver = SniCertResolver::new();
        let (cert_pem, key_pem) = generate_test_cert();

        resolver
            .load_cert("example.com", &cert_pem, &key_pem)
            .unwrap();
        resolver.remove_cert("  EXAMPLE.com ");
        assert!(!resolver.has_cert("example.com"));
    }

    #[test]
    fn test_remove_missing_cert_is_noop() {
        let resolver = SniCertResolver::new();
        resolver.remove_cert("missing.example.com");
        assert_eq!(resolver.cert_count(), 0);
    }

    #[test]
    fn test_refresh_cert() {
        let resolver = SniCertResolver::new();
        let (cert_pem, key_pem) = generate_test_cert();

        // Load initial cert
        resolver
            .load_cert("example.com", &cert_pem, &key_pem)
            .unwrap();
        let before = resolver
            .resolve_cert(Some("example.com"))
            .expect("initial certificate should resolve");

        // Refresh with new cert
        let (new_cert_pem, new_key_pem) = generate_test_cert();
        let result = resolver.refresh_cert("example.com", &new_cert_pem, &new_key_pem);
        assert!(result.is_ok());
        assert_eq!(resolver.cert_count(), 1);

        // The resolver must now hand out the refreshed key, not the old one.
        let after = resolver
            .resolve_cert(Some("example.com"))
            .expect("refreshed certificate should resolve");
        assert!(!std::sync::Arc::ptr_eq(&before, &after));
    }

    #[test]
    fn test_set_default_cert() {
        let resolver = SniCertResolver::new();
        let (cert_pem, key_pem) = generate_test_cert();

        let result = resolver.set_default_cert(&cert_pem, &key_pem);
        assert!(result.is_ok());

        // Default cert should not show up in cert_count or domains
        assert_eq!(resolver.cert_count(), 0);
        assert!(resolver.domains().is_empty());
    }

    #[test]
    fn test_has_default_cert() {
        let resolver = SniCertResolver::new();
        assert!(!resolver.has_default_cert());

        let (cert_pem, key_pem) = generate_test_cert();
        resolver.set_default_cert(&cert_pem, &key_pem).unwrap();

        assert!(resolver.has_default_cert());
    }

    #[test]
    fn test_domains() {
        let resolver = SniCertResolver::new();
        let (cert_pem, key_pem) = generate_test_cert();

        resolver
            .load_cert("api.example.com", &cert_pem, &key_pem)
            .unwrap();
        resolver
            .load_cert("web.example.com", &cert_pem, &key_pem)
            .unwrap();

        let domains = resolver.domains();
        assert_eq!(domains.len(), 2);
        assert!(domains.contains(&"api.example.com".to_string()));
        assert!(domains.contains(&"web.example.com".to_string()));
    }

    #[test]
    fn test_resolve_exact_match() {
        let resolver = SniCertResolver::new();
        let (cert_pem, key_pem) = generate_test_cert();

        resolver
            .load_cert("example.com", &cert_pem, &key_pem)
            .unwrap();

        let result = resolver.resolve_cert(Some("example.com"));
        assert!(result.is_some());
    }

    #[test]
    fn test_resolve_exact_match_preferred_over_wildcard() {
        let resolver = SniCertResolver::new();
        let (exact_cert, exact_key) = generate_test_cert();
        let (wild_cert, wild_key) = generate_test_cert();

        resolver
            .load_cert("api.example.com", &exact_cert, &exact_key)
            .unwrap();
        resolver
            .load_cert("*.example.com", &wild_cert, &wild_key)
            .unwrap();

        let exact = resolver.resolve_cert(Some("api.example.com")).unwrap();
        let wildcard = resolver.resolve_cert(Some("web.example.com")).unwrap();
        assert!(!std::sync::Arc::ptr_eq(&exact, &wildcard));

        // Lookups are case-insensitive and stable for the same name.
        let again = resolver.resolve_cert(Some("API.Example.com")).unwrap();
        assert!(std::sync::Arc::ptr_eq(&exact, &again));
    }

    #[test]
    fn test_resolve_wildcard_match() {
        let resolver = SniCertResolver::new();
        let (cert_pem, key_pem) = generate_test_cert();

        // Load wildcard cert
        resolver
            .load_cert("*.example.com", &cert_pem, &key_pem)
            .unwrap();

        // Should match subdomains
        let result = resolver.resolve_cert(Some("api.example.com"));
        assert!(result.is_some());

        let result = resolver.resolve_cert(Some("web.example.com"));
        assert!(result.is_some());

        // Should not match base domain
        let result = resolver.resolve_cert(Some("example.com"));
        assert!(result.is_none());

        // Wildcards cover exactly one label
        let result = resolver.resolve_cert(Some("a.b.example.com"));
        assert!(result.is_none());
    }

    #[test]
    fn test_resolve_default_fallback() {
        let resolver = SniCertResolver::new();
        let (cert_pem, key_pem) = generate_test_cert();

        resolver.set_default_cert(&cert_pem, &key_pem).unwrap();

        // Unknown domain should fall back to default
        let result = resolver.resolve_cert(Some("unknown.com"));
        assert!(result.is_some());
    }

    #[test]
    fn test_resolve_no_match() {
        let resolver = SniCertResolver::new();
        let (cert_pem, key_pem) = generate_test_cert();

        resolver
            .load_cert("example.com", &cert_pem, &key_pem)
            .unwrap();

        // No default, different domain
        let result = resolver.resolve_cert(Some("other.com"));
        assert!(result.is_none());
    }

    #[test]
    fn test_resolve_none_server_name() {
        let resolver = SniCertResolver::new();

        // No server name provided and no default configured
        let result = resolver.resolve_cert(None);
        assert!(result.is_none());
    }

    #[test]
    fn test_invalid_cert_pem() {
        let result = parse_certificates("not a valid PEM");
        assert!(result.is_ok()); // Will succeed but return empty vec
        assert!(result.unwrap().is_empty());
    }

    #[test]
    fn test_invalid_key_pem() {
        let result = parse_private_key("not a valid PEM");
        assert!(result.is_err());
    }

    #[test]
    fn test_parse_private_key_skips_certificates() {
        let (cert_pem, key_pem) = generate_test_cert();

        // A combined bundle (certificate first) must still yield the key.
        let bundle = format!("{cert_pem}{key_pem}");
        assert!(parse_private_key(&bundle).is_ok());
    }

    #[test]
    fn test_create_certified_key_empty_certs() {
        let (_, key_pem) = generate_test_cert();
        let result = create_certified_key("", &key_pem);
        assert!(result.is_err());
    }

    #[test]
    fn test_create_certified_key_empty_key() {
        let (cert_pem, _) = generate_test_cert();
        let result = create_certified_key(&cert_pem, "");
        assert!(result.is_err());
    }
}