Skip to main content

stakpak_shared/
cert_utils.rs

1use anyhow::{Context, Result};
2use rcgen::{
3    BasicConstraints, CertificateParams, DistinguishedName, DnType, IsCa, KeyUsagePurpose, SanType,
4};
5use rustls::pki_types::pem::PemObject;
6use rustls::pki_types::{CertificateDer, PrivateKeyDer};
7use rustls::{ClientConfig, RootCertStore, ServerConfig};
8use std::path::Path;
9use std::sync::Arc;
10use time::OffsetDateTime;
11
/// Bundle of three `rcgen` certificates — a CA, a server leaf and a client
/// leaf — that the methods on this type turn into rustls configurations and
/// PEM strings.
pub struct CertificateChain {
    /// Certificate used as the signer when the two leaf certificates are serialized.
    pub ca_cert: rcgen::Certificate,
    /// Certificate (and key pair) presented by the server side.
    pub server_cert: rcgen::Certificate,
    /// Certificate (and key pair) presented by the client side.
    pub client_cert: rcgen::Certificate,
}
17
18impl std::fmt::Debug for CertificateChain {
19    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
20        f.debug_struct("CertificateChain")
21            .field("ca_cert", &"<certificate>")
22            .field("server_cert", &"<certificate>")
23            .field("client_cert", &"<certificate>")
24            .finish()
25    }
26}
27
28impl CertificateChain {
29    pub fn generate() -> Result<Self> {
30        // Generate CA certificate
31        let mut ca_params = CertificateParams::default();
32        ca_params.distinguished_name = DistinguishedName::new();
33        ca_params
34            .distinguished_name
35            .push(DnType::CommonName, "Stakpak MCP CA");
36        ca_params
37            .distinguished_name
38            .push(DnType::OrganizationName, "Stakpak");
39        ca_params.distinguished_name.push(DnType::CountryName, "US");
40
41        ca_params.is_ca = IsCa::Ca(BasicConstraints::Unconstrained);
42        ca_params.key_usages = vec![
43            KeyUsagePurpose::KeyCertSign,
44            KeyUsagePurpose::CrlSign,
45            KeyUsagePurpose::DigitalSignature,
46        ];
47
48        ca_params.not_before = OffsetDateTime::now_utc() - time::Duration::seconds(60);
49        ca_params.not_after = OffsetDateTime::now_utc() + time::Duration::days(365);
50
51        let ca_cert = rcgen::Certificate::from_params(ca_params)?;
52
53        // Generate server certificate
54        let mut server_params = CertificateParams::default();
55        server_params.distinguished_name = DistinguishedName::new();
56        server_params
57            .distinguished_name
58            .push(DnType::CommonName, "Stakpak MCP Server");
59        server_params
60            .distinguished_name
61            .push(DnType::OrganizationName, "Stakpak");
62        server_params
63            .distinguished_name
64            .push(DnType::CountryName, "US");
65
66        server_params.subject_alt_names = vec![
67            SanType::DnsName("localhost".to_string()),
68            SanType::IpAddress(std::net::IpAddr::V4(std::net::Ipv4Addr::new(0, 0, 0, 0))),
69            SanType::IpAddress(std::net::IpAddr::V4(std::net::Ipv4Addr::new(127, 0, 0, 1))),
70        ];
71
72        server_params.key_usages = vec![
73            KeyUsagePurpose::DigitalSignature,
74            KeyUsagePurpose::KeyEncipherment,
75        ];
76
77        server_params.not_before = OffsetDateTime::now_utc() - time::Duration::seconds(60);
78        server_params.not_after = OffsetDateTime::now_utc() + time::Duration::days(365);
79
80        let server_cert = rcgen::Certificate::from_params(server_params)?;
81
82        // Generate client certificate
83        let mut client_params = CertificateParams::default();
84        client_params.distinguished_name = DistinguishedName::new();
85        client_params
86            .distinguished_name
87            .push(DnType::CommonName, "Stakpak MCP Client");
88        client_params
89            .distinguished_name
90            .push(DnType::OrganizationName, "Stakpak");
91        client_params
92            .distinguished_name
93            .push(DnType::CountryName, "US");
94
95        client_params.key_usages = vec![
96            KeyUsagePurpose::DigitalSignature,
97            KeyUsagePurpose::KeyEncipherment,
98        ];
99
100        client_params.not_before = OffsetDateTime::now_utc() - time::Duration::seconds(60);
101        client_params.not_after = OffsetDateTime::now_utc() + time::Duration::days(365);
102
103        let client_cert = rcgen::Certificate::from_params(client_params)?;
104
105        Ok(CertificateChain {
106            ca_cert,
107            server_cert,
108            client_cert,
109        })
110    }
111
112    pub fn create_server_config(&self) -> Result<ServerConfig> {
113        // Sign server certificate with CA
114        let server_cert_der = self.server_cert.serialize_der_with_signer(&self.ca_cert)?;
115        let server_key_der = self.server_cert.serialize_private_key_der();
116
117        let server_cert_chain = vec![CertificateDer::from(server_cert_der)];
118        let server_private_key = PrivateKeyDer::try_from(server_key_der)
119            .map_err(|e| anyhow::anyhow!("Failed to convert server private key: {:?}", e))?;
120
121        // Set up root certificate store to trust our CA (for client cert validation)
122        let mut root_cert_store = RootCertStore::empty();
123        let ca_cert_der = self.ca_cert.serialize_der()?;
124        root_cert_store.add(CertificateDer::from(ca_cert_der))?;
125
126        // Create client certificate verifier that requires client certificates
127        let client_cert_verifier =
128            rustls::server::WebPkiClientVerifier::builder(Arc::new(root_cert_store))
129                .build()
130                .map_err(|e| anyhow::anyhow!("Failed to build client cert verifier: {}", e))?;
131
132        let config = ServerConfig::builder()
133            .with_client_cert_verifier(client_cert_verifier)
134            .with_single_cert(server_cert_chain, server_private_key)?;
135
136        Ok(config)
137    }
138
139    pub fn create_client_config(&self) -> Result<ClientConfig> {
140        // Sign client certificate with CA
141        let client_cert_der = self.client_cert.serialize_der_with_signer(&self.ca_cert)?;
142        let client_key_der = self.client_cert.serialize_private_key_der();
143
144        let client_cert_chain = vec![CertificateDer::from(client_cert_der)];
145        let client_private_key = PrivateKeyDer::try_from(client_key_der)
146            .map_err(|e| anyhow::anyhow!("Failed to convert client private key: {:?}", e))?;
147
148        // Set up root certificate store to trust our CA (for server cert validation)
149        let mut root_cert_store = RootCertStore::empty();
150        let ca_cert_der = self.ca_cert.serialize_der()?;
151        root_cert_store.add(CertificateDer::from(ca_cert_der))?;
152
153        let config = ClientConfig::builder()
154            .with_root_certificates(root_cert_store)
155            .with_client_auth_cert(client_cert_chain, client_private_key)?;
156
157        Ok(config)
158    }
159
160    pub fn get_ca_cert_pem(&self) -> Result<String> {
161        Ok(self.ca_cert.serialize_pem()?)
162    }
163
164    pub fn get_server_cert_pem(&self) -> Result<String> {
165        Ok(self.server_cert.serialize_pem_with_signer(&self.ca_cert)?)
166    }
167
168    pub fn get_client_cert_pem(&self) -> Result<String> {
169        Ok(self.client_cert.serialize_pem_with_signer(&self.ca_cert)?)
170    }
171
172    pub fn get_server_key_pem(&self) -> Result<String> {
173        Ok(self.server_cert.serialize_private_key_pem())
174    }
175
176    pub fn get_client_key_pem(&self) -> Result<String> {
177        Ok(self.client_cert.serialize_private_key_pem())
178    }
179}
180
181/// A single-sided mTLS identity: a CA that signs one leaf certificate.
182///
183/// Each side of the mTLS connection generates its own `MtlsIdentity`. Only the
184/// CA certificate (public) is shared with the peer — private keys never leave
185/// the process that generated them.
186///
187/// ```text
188/// Host (client)                         Container (server)
189/// ─────────────────                     ─────────────────────
190/// MtlsIdentity::generate_client()       MtlsIdentity::generate_server()
191///   ├─ client CA cert  ──────────────►  trusted by server (verifies client)
192///   ├─ client leaf cert (in memory)     ├─ server CA cert  ◄── output to stdout
193///   └─ client leaf key (in memory)      ├─ server leaf cert (in memory)
194///                                       └─ server leaf key (in memory)
195///
196/// host trusts server CA cert ◄──────── parsed from stdout
197/// ```
pub struct MtlsIdentity {
    // CA certificate; its PEM is what `ca_cert_pem` exposes, and it signs `leaf_cert`.
    ca_cert: rcgen::Certificate,
    // Leaf certificate (and key pair) this side presents during the TLS handshake.
    leaf_cert: rcgen::Certificate,
}
202
203impl std::fmt::Debug for MtlsIdentity {
204    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
205        f.debug_struct("MtlsIdentity")
206            .field("ca_cert", &"<certificate>")
207            .field("leaf_cert", &"<certificate>")
208            .finish()
209    }
210}
211
212impl MtlsIdentity {
213    /// Generate a CA + leaf certificate for a given role.
214    fn generate_leaf(common_name: &str, san: Vec<SanType>) -> Result<Self> {
215        // CA certificate
216        let mut ca_params = CertificateParams::default();
217        ca_params.distinguished_name = DistinguishedName::new();
218        ca_params
219            .distinguished_name
220            .push(DnType::CommonName, format!("{common_name} CA"));
221        ca_params
222            .distinguished_name
223            .push(DnType::OrganizationName, "Stakpak");
224        ca_params.is_ca = IsCa::Ca(BasicConstraints::Unconstrained);
225        ca_params.key_usages = vec![
226            KeyUsagePurpose::KeyCertSign,
227            KeyUsagePurpose::CrlSign,
228            KeyUsagePurpose::DigitalSignature,
229        ];
230        ca_params.not_before = OffsetDateTime::now_utc() - time::Duration::seconds(60);
231        ca_params.not_after = OffsetDateTime::now_utc() + time::Duration::days(365);
232        let ca_cert = rcgen::Certificate::from_params(ca_params)?;
233
234        // Leaf certificate
235        let mut leaf_params = CertificateParams::default();
236        leaf_params.distinguished_name = DistinguishedName::new();
237        leaf_params
238            .distinguished_name
239            .push(DnType::CommonName, common_name);
240        leaf_params
241            .distinguished_name
242            .push(DnType::OrganizationName, "Stakpak");
243        leaf_params.subject_alt_names = san;
244        leaf_params.key_usages = vec![
245            KeyUsagePurpose::DigitalSignature,
246            KeyUsagePurpose::KeyEncipherment,
247        ];
248        leaf_params.not_before = OffsetDateTime::now_utc() - time::Duration::seconds(60);
249        leaf_params.not_after = OffsetDateTime::now_utc() + time::Duration::days(365);
250        let leaf_cert = rcgen::Certificate::from_params(leaf_params)?;
251
252        Ok(Self { ca_cert, leaf_cert })
253    }
254
255    /// Generate a client identity (CA + client leaf cert).
256    pub fn generate_client() -> Result<Self> {
257        Self::generate_leaf("Stakpak MCP Client", vec![])
258    }
259
260    /// Generate a server identity (CA + server leaf cert with localhost SANs).
261    pub fn generate_server() -> Result<Self> {
262        Self::generate_leaf(
263            "Stakpak MCP Server",
264            vec![
265                SanType::DnsName("localhost".to_string()),
266                SanType::IpAddress(std::net::IpAddr::V4(std::net::Ipv4Addr::new(0, 0, 0, 0))),
267                SanType::IpAddress(std::net::IpAddr::V4(std::net::Ipv4Addr::new(127, 0, 0, 1))),
268            ],
269        )
270    }
271
272    /// Get the CA certificate PEM (public, safe to share with the peer).
273    pub fn ca_cert_pem(&self) -> Result<String> {
274        Ok(self.ca_cert.serialize_pem()?)
275    }
276
277    /// Build a `rustls::ServerConfig` that serves with this identity's leaf
278    /// cert and trusts the given client CA PEM for client authentication.
279    pub fn create_server_config(&self, trusted_client_ca_pem: &str) -> Result<ServerConfig> {
280        let leaf_cert_der = self.leaf_cert.serialize_der_with_signer(&self.ca_cert)?;
281        let leaf_key_der = self.leaf_cert.serialize_private_key_der();
282
283        let cert_chain = vec![CertificateDer::from(leaf_cert_der)];
284        let private_key = PrivateKeyDer::try_from(leaf_key_der)
285            .map_err(|e| anyhow::anyhow!("Failed to convert server private key: {:?}", e))?;
286
287        let mut root_store = RootCertStore::empty();
288        for cert in CertificateDer::pem_slice_iter(trusted_client_ca_pem.as_bytes()) {
289            let cert = cert.context("Failed to parse trusted client CA PEM")?;
290            root_store
291                .add(cert)
292                .context("Failed to add trusted client CA to root store")?;
293        }
294
295        let verifier = rustls::server::WebPkiClientVerifier::builder(Arc::new(root_store))
296            .build()
297            .map_err(|e| anyhow::anyhow!("Failed to build client cert verifier: {}", e))?;
298
299        let config = ServerConfig::builder()
300            .with_client_cert_verifier(verifier)
301            .with_single_cert(cert_chain, private_key)?;
302
303        Ok(config)
304    }
305
306    /// Build a `rustls::ClientConfig` that authenticates with this identity's
307    /// leaf cert and trusts the given server CA PEM.
308    pub fn create_client_config(&self, trusted_server_ca_pem: &str) -> Result<ClientConfig> {
309        let leaf_cert_der = self.leaf_cert.serialize_der_with_signer(&self.ca_cert)?;
310        let leaf_key_der = self.leaf_cert.serialize_private_key_der();
311
312        let cert_chain = vec![CertificateDer::from(leaf_cert_der)];
313        let private_key = PrivateKeyDer::try_from(leaf_key_der)
314            .map_err(|e| anyhow::anyhow!("Failed to convert client private key: {:?}", e))?;
315
316        let mut root_store = RootCertStore::empty();
317        for cert in CertificateDer::pem_slice_iter(trusted_server_ca_pem.as_bytes()) {
318            let cert = cert.context("Failed to parse trusted server CA PEM")?;
319            root_store
320                .add(cert)
321                .context("Failed to add trusted server CA to root store")?;
322        }
323
324        let config = ClientConfig::builder()
325            .with_root_certificates(root_store)
326            .with_client_auth_cert(cert_chain, private_key)?;
327
328        Ok(config)
329    }
330}
331
332/// A certificate chain loaded from PEM files on disk.
333///
334/// Expects a directory containing:
335/// - `ca.pem` — CA certificate (PEM)
336/// - `server.pem` — Server certificate (PEM)
337/// - `server-key.pem` — Server private key (PEM)
338/// - `client.pem` — Client certificate (PEM)
339/// - `client-key.pem` — Client private key (PEM)
pub struct LoadedCertificateChain {
    /// Contents of `ca.pem`; parsed into the trusted root store by both configs.
    pub ca_cert_pem: String,
    /// Contents of `server.pem`; the certificate chain the server presents.
    pub server_cert_pem: String,
    /// Contents of `server-key.pem`; the server's private key.
    pub server_key_pem: String,
    /// Contents of `client.pem`; the certificate chain the client presents.
    pub client_cert_pem: String,
    /// Contents of `client-key.pem`; the client's private key.
    pub client_key_pem: String,
}
347
348impl LoadedCertificateChain {
349    pub fn load_from_dir(dir: &Path) -> Result<Self> {
350        let ca_cert_pem = std::fs::read_to_string(dir.join("ca.pem"))
351            .with_context(|| format!("Failed to read ca.pem from {}", dir.display()))?;
352        let server_cert_pem = std::fs::read_to_string(dir.join("server.pem"))
353            .with_context(|| format!("Failed to read server.pem from {}", dir.display()))?;
354        let server_key_pem = std::fs::read_to_string(dir.join("server-key.pem"))
355            .with_context(|| format!("Failed to read server-key.pem from {}", dir.display()))?;
356        let client_cert_pem = std::fs::read_to_string(dir.join("client.pem"))
357            .with_context(|| format!("Failed to read client.pem from {}", dir.display()))?;
358        let client_key_pem = std::fs::read_to_string(dir.join("client-key.pem"))
359            .with_context(|| format!("Failed to read client-key.pem from {}", dir.display()))?;
360
361        Ok(Self {
362            ca_cert_pem,
363            server_cert_pem,
364            server_key_pem,
365            client_cert_pem,
366            client_key_pem,
367        })
368    }
369
370    fn parse_root_cert_store(&self) -> Result<RootCertStore> {
371        let mut root_cert_store = RootCertStore::empty();
372        for cert in CertificateDer::pem_slice_iter(self.ca_cert_pem.as_bytes()) {
373            let cert = cert.context("Failed to parse CA certificate PEM")?;
374            root_cert_store
375                .add(cert)
376                .context("Failed to add CA certificate to root store")?;
377        }
378        Ok(root_cert_store)
379    }
380
381    pub fn create_server_config(&self) -> Result<ServerConfig> {
382        let server_certs: Vec<CertificateDer<'static>> =
383            CertificateDer::pem_slice_iter(self.server_cert_pem.as_bytes())
384                .collect::<std::result::Result<Vec<_>, _>>()
385                .context("Failed to parse server certificate PEM")?;
386
387        let server_key = PrivateKeyDer::from_pem_slice(self.server_key_pem.as_bytes())
388            .context("Failed to parse server private key PEM")?;
389
390        let root_cert_store = self.parse_root_cert_store()?;
391
392        let client_cert_verifier =
393            rustls::server::WebPkiClientVerifier::builder(Arc::new(root_cert_store))
394                .build()
395                .map_err(|e| anyhow::anyhow!("Failed to build client cert verifier: {}", e))?;
396
397        let config = ServerConfig::builder()
398            .with_client_cert_verifier(client_cert_verifier)
399            .with_single_cert(server_certs, server_key)?;
400
401        Ok(config)
402    }
403
404    pub fn create_client_config(&self) -> Result<ClientConfig> {
405        let client_certs: Vec<CertificateDer<'static>> =
406            CertificateDer::pem_slice_iter(self.client_cert_pem.as_bytes())
407                .collect::<std::result::Result<Vec<_>, _>>()
408                .context("Failed to parse client certificate PEM")?;
409
410        let client_key = PrivateKeyDer::from_pem_slice(self.client_key_pem.as_bytes())
411            .context("Failed to parse client private key PEM")?;
412
413        let root_cert_store = self.parse_root_cert_store()?;
414
415        let config = ClientConfig::builder()
416            .with_root_certificates(root_cert_store)
417            .with_client_auth_cert(client_certs, client_key)?;
418
419        Ok(config)
420    }
421}
422
423#[cfg(test)]
424mod tests {
425    use super::*;
426    use axum::{Router, response::Json, routing::get};
427    use axum_server::tls_rustls::RustlsConfig;
428    use reqwest::Client;
429    use serde_json::json;
430    use std::sync::Arc;
431    use tokio::net::TcpListener;
432    use tokio::time::{Duration, timeout};
433
434    fn init_crypto_provider() {
435        use std::sync::Once;
436        static INIT: Once = Once::new();
437        INIT.call_once(|| {
438            rustls::crypto::aws_lc_rs::default_provider()
439                .install_default()
440                .expect("Failed to install crypto provider");
441        });
442    }
443
444    #[tokio::test]
445    async fn test_mtls_handshake_success() {
446        init_crypto_provider();
447        // Generate certificate chain
448        let cert_chain =
449            CertificateChain::generate().expect("Failed to generate certificate chain");
450
451        // Create server config
452        let server_config = cert_chain
453            .create_server_config()
454            .expect("Failed to create server config");
455
456        // Create client config
457        let client_config = cert_chain
458            .create_client_config()
459            .expect("Failed to create client config");
460
461        // Create a simple axum app
462        let app = Router::new().route(
463            "/test",
464            get(|| async { Json(json!({"status": "success"})) }),
465        );
466
467        // Start server with mTLS
468        let rustls_config = RustlsConfig::from_config(Arc::new(server_config));
469
470        // Use a fixed port for testing
471        let test_port = 8443;
472        let server_addr = format!("127.0.0.1:{}", test_port).parse().unwrap();
473
474        let server_handle = tokio::spawn(async move {
475            axum_server::bind_rustls(server_addr, rustls_config)
476                .serve(app.into_make_service())
477                .await
478        });
479
480        // Give server time to start
481        tokio::time::sleep(Duration::from_millis(500)).await;
482
483        // Create reqwest client with mTLS config
484        let client = Client::builder()
485            .use_preconfigured_tls(client_config)
486            .build()
487            .expect("Failed to build client");
488
489        // Test successful mTLS connection
490        let url = format!("https://127.0.0.1:{}/test", test_port);
491        println!("Testing mTLS connection to: {}", url);
492
493        let response = timeout(Duration::from_secs(10), client.get(&url).send())
494            .await
495            .expect("Request timed out")
496            .expect("Failed to send request");
497
498        assert!(
499            response.status().is_success(),
500            "Request should succeed with valid mTLS"
501        );
502
503        let body: serde_json::Value = response.json().await.expect("Failed to parse JSON");
504        assert_eq!(body["status"], "success");
505
506        // Shutdown server
507        server_handle.abort();
508    }
509
510    #[tokio::test]
511    async fn test_mtls_handshake_failure_no_client_cert() {
512        init_crypto_provider();
513        // Generate certificate chain
514        let cert_chain =
515            CertificateChain::generate().expect("Failed to generate certificate chain");
516
517        // Create server config (requires client certs)
518        let server_config = cert_chain
519            .create_server_config()
520            .expect("Failed to create server config");
521
522        // Create a simple axum app
523        let app = Router::new().route(
524            "/test",
525            get(|| async { Json(json!({"status": "success"})) }),
526        );
527
528        // Start server with mTLS
529        let listener = TcpListener::bind("127.0.0.1:0")
530            .await
531            .expect("Failed to bind listener");
532        let server_addr = listener.local_addr().expect("Failed to get local address");
533        let rustls_config = RustlsConfig::from_config(Arc::new(server_config));
534
535        let server_handle = tokio::spawn(async move {
536            axum_server::bind_rustls(server_addr, rustls_config)
537                .serve(app.into_make_service())
538                .await
539        });
540
541        // Give server time to start
542        tokio::time::sleep(Duration::from_millis(100)).await;
543
544        // Create reqwest client without client certificates (should fail)
545        let client = Client::builder()
546            .danger_accept_invalid_certs(true) // Accept self-signed certs but still no client cert
547            .build()
548            .expect("Failed to build client");
549
550        // Test that connection fails without client certificate
551        let result = timeout(
552            Duration::from_secs(5),
553            client
554                .get(format!("https://127.0.0.1:{}/test", server_addr.port()))
555                .send(),
556        )
557        .await;
558
559        // Should fail because no client certificate is provided
560        assert!(
561            result.is_err() || result.unwrap().is_err(),
562            "Request should fail without client certificate"
563        );
564
565        // Shutdown server
566        server_handle.abort();
567    }
568
569    #[tokio::test]
570    async fn test_mtls_handshake_failure_wrong_ca() {
571        init_crypto_provider();
572        // Generate two separate certificate chains
573        let cert_chain1 =
574            CertificateChain::generate().expect("Failed to generate certificate chain 1");
575        let cert_chain2 =
576            CertificateChain::generate().expect("Failed to generate certificate chain 2");
577
578        // Create server config with first cert chain
579        let server_config = cert_chain1
580            .create_server_config()
581            .expect("Failed to create server config");
582
583        // Create client config with second cert chain (different CA)
584        let client_config = cert_chain2
585            .create_client_config()
586            .expect("Failed to create client config");
587
588        // Create a simple axum app
589        let app = Router::new().route(
590            "/test",
591            get(|| async { Json(json!({"status": "success"})) }),
592        );
593
594        // Start server with mTLS
595        let listener = TcpListener::bind("127.0.0.1:0")
596            .await
597            .expect("Failed to bind listener");
598        let server_addr = listener.local_addr().expect("Failed to get local address");
599        let rustls_config = RustlsConfig::from_config(Arc::new(server_config));
600
601        let server_handle = tokio::spawn(async move {
602            axum_server::bind_rustls(server_addr, rustls_config)
603                .serve(app.into_make_service())
604                .await
605        });
606
607        // Give server time to start
608        tokio::time::sleep(Duration::from_millis(100)).await;
609
610        // Create reqwest client with wrong CA certificates
611        let client = Client::builder()
612            .use_preconfigured_tls(client_config)
613            .build()
614            .expect("Failed to build client");
615
616        // Test that connection fails with wrong CA
617        let result = timeout(
618            Duration::from_secs(5),
619            client
620                .get(format!("https://127.0.0.1:{}/test", server_addr.port()))
621                .send(),
622        )
623        .await;
624
625        // Should fail because client and server have different CAs
626        assert!(
627            result.is_err() || result.unwrap().is_err(),
628            "Request should fail with wrong CA certificates"
629        );
630
631        // Shutdown server
632        server_handle.abort();
633    }
634
635    #[tokio::test]
636    async fn test_certificate_chain_generation() {
637        init_crypto_provider();
638        let cert_chain =
639            CertificateChain::generate().expect("Failed to generate certificate chain");
640
641        // Test that we can get PEM representations
642        let ca_pem = cert_chain.get_ca_cert_pem().expect("Failed to get CA PEM");
643        let server_pem = cert_chain
644            .get_server_cert_pem()
645            .expect("Failed to get server PEM");
646        let client_pem = cert_chain
647            .get_client_cert_pem()
648            .expect("Failed to get client PEM");
649        let server_key_pem = cert_chain
650            .get_server_key_pem()
651            .expect("Failed to get server key PEM");
652        let client_key_pem = cert_chain
653            .get_client_key_pem()
654            .expect("Failed to get client key PEM");
655
656        // Verify PEM format
657        assert!(ca_pem.contains("-----BEGIN CERTIFICATE-----"));
658        assert!(ca_pem.contains("-----END CERTIFICATE-----"));
659        assert!(server_pem.contains("-----BEGIN CERTIFICATE-----"));
660        assert!(server_pem.contains("-----END CERTIFICATE-----"));
661        assert!(client_pem.contains("-----BEGIN CERTIFICATE-----"));
662        assert!(client_pem.contains("-----END CERTIFICATE-----"));
663        assert!(server_key_pem.contains("-----BEGIN PRIVATE KEY-----"));
664        assert!(server_key_pem.contains("-----END PRIVATE KEY-----"));
665        assert!(client_key_pem.contains("-----BEGIN PRIVATE KEY-----"));
666        assert!(client_key_pem.contains("-----END PRIVATE KEY-----"));
667    }
668
669    #[tokio::test]
670    async fn test_server_config_creation() {
671        init_crypto_provider();
672        let cert_chain =
673            CertificateChain::generate().expect("Failed to generate certificate chain");
674        let _server_config = cert_chain
675            .create_server_config()
676            .expect("Failed to create server config");
677
678        // Verify server config is created successfully
679        // The fact that it doesn't panic/error is the main test
680        assert!(true, "Server config created successfully");
681    }
682
683    #[tokio::test]
684    async fn test_client_config_creation() {
685        init_crypto_provider();
686        let cert_chain =
687            CertificateChain::generate().expect("Failed to generate certificate chain");
688        let _client_config = cert_chain
689            .create_client_config()
690            .expect("Failed to create client config");
691
692        // Verify client config is created successfully
693        // The fact that it doesn't panic/error is the main test
694        assert!(true, "Client config created successfully");
695    }
696
697    #[tokio::test]
698    async fn test_mtls_multiple_requests() {
699        init_crypto_provider();
700        // Generate certificate chain
701        let cert_chain =
702            CertificateChain::generate().expect("Failed to generate certificate chain");
703
704        // Create server and client configs
705        let server_config = cert_chain
706            .create_server_config()
707            .expect("Failed to create server config");
708        let client_config = cert_chain
709            .create_client_config()
710            .expect("Failed to create client config");
711
712        // Create a simple axum app with multiple routes
713        let app = Router::new()
714            .route(
715                "/test1",
716                get(|| async { Json(json!({"endpoint": "test1"})) }),
717            )
718            .route(
719                "/test2",
720                get(|| async { Json(json!({"endpoint": "test2"})) }),
721            )
722            .route(
723                "/test3",
724                get(|| async { Json(json!({"endpoint": "test3"})) }),
725            );
726
727        // Start server with mTLS
728        let rustls_config = RustlsConfig::from_config(Arc::new(server_config));
729
730        // Use a fixed port for testing
731        let test_port = 8444; // Different port from the first test
732        let server_addr = format!("127.0.0.1:{}", test_port).parse().unwrap();
733
734        let server_handle = tokio::spawn(async move {
735            axum_server::bind_rustls(server_addr, rustls_config)
736                .serve(app.into_make_service())
737                .await
738        });
739
740        // Give server time to start
741        tokio::time::sleep(Duration::from_millis(500)).await;
742
743        // Create reqwest client with mTLS config
744        let client = Client::builder()
745            .use_preconfigured_tls(client_config)
746            .build()
747            .expect("Failed to build client");
748
749        // Test multiple requests to different endpoints
750        for endpoint in ["test1", "test2", "test3"] {
751            let response = timeout(
752                Duration::from_secs(10),
753                client
754                    .get(format!("https://127.0.0.1:{}/{}", test_port, endpoint))
755                    .send(),
756            )
757            .await
758            .expect("Request timed out")
759            .expect("Failed to send request");
760
761            assert!(
762                response.status().is_success(),
763                "Request to {} should succeed",
764                endpoint
765            );
766
767            let body: serde_json::Value = response.json().await.expect("Failed to parse JSON");
768            assert_eq!(body["endpoint"], endpoint);
769        }
770
771        // Shutdown server
772        server_handle.abort();
773    }
774
775    #[tokio::test]
776    async fn test_mtls_configuration_compatibility() {
777        init_crypto_provider();
778
779        // Generate certificate chain
780        let cert_chain =
781            CertificateChain::generate().expect("Failed to generate certificate chain");
782
783        // Create server config - should work without errors
784        let server_config = cert_chain
785            .create_server_config()
786            .expect("Failed to create server config");
787
788        // Create client config - should work without errors
789        let client_config = cert_chain
790            .create_client_config()
791            .expect("Failed to create client config");
792
793        // Verify we can create a reqwest client with the client config
794        let _client = Client::builder()
795            .use_preconfigured_tls(client_config)
796            .build()
797            .expect("Failed to build reqwest client with mTLS config");
798
799        // Verify we can create an axum-server RustlsConfig with the server config
800        let _rustls_config = RustlsConfig::from_config(Arc::new(server_config));
801
802        // Verify certificate chain properties
803        assert!(cert_chain.get_ca_cert_pem().is_ok());
804        assert!(cert_chain.get_server_cert_pem().is_ok());
805        assert!(cert_chain.get_client_cert_pem().is_ok());
806        assert!(cert_chain.get_server_key_pem().is_ok());
807        assert!(cert_chain.get_client_key_pem().is_ok());
808
809        // If we get here, the mTLS configuration is properly set up
810        println!("✅ mTLS configuration successfully created");
811        println!("✅ Reqwest client can be configured with client certificates");
812        println!("✅ Axum server can be configured with server certificates");
813        println!("✅ Certificate chain includes CA, server, and client certificates");
814    }
815
816    #[tokio::test]
817    async fn test_mtls_certificate_validation() {
818        init_crypto_provider();
819
820        // Test that different certificate chains are incompatible
821        let cert_chain1 =
822            CertificateChain::generate().expect("Failed to generate certificate chain 1");
823        let cert_chain2 =
824            CertificateChain::generate().expect("Failed to generate certificate chain 2");
825
826        // Create configs from different chains
827        let server_config1 = cert_chain1
828            .create_server_config()
829            .expect("Failed to create server config 1");
830        let client_config2 = cert_chain2
831            .create_client_config()
832            .expect("Failed to create client config 2");
833
834        // These should be created successfully but would fail in actual connection
835        let _client = Client::builder()
836            .use_preconfigured_tls(client_config2)
837            .build()
838            .expect("Failed to build client with different CA");
839
840        let _rustls_config = RustlsConfig::from_config(Arc::new(server_config1));
841
842        // The configurations are created successfully, but they would fail during handshake
843        // because they use different CAs
844        println!("✅ Different certificate chains create valid but incompatible configurations");
845    }
846}