stakpak_shared/
cert_utils.rs

1use anyhow::Result;
2use rcgen::{
3    BasicConstraints, CertificateParams, DistinguishedName, DnType, IsCa, KeyUsagePurpose, SanType,
4};
5use rustls::pki_types::{CertificateDer, PrivateKeyDer};
6use rustls::{ClientConfig, RootCertStore, ServerConfig};
7use std::sync::Arc;
8use time::OffsetDateTime;
9
/// Certificate material for mutual TLS, as produced by [`CertificateChain::generate`]:
/// a self-signed CA plus a server and a client certificate.
///
/// The server and client certificates are signed with `ca_cert` whenever they
/// are serialized (config builders and PEM getters). `Debug` is implemented
/// manually and prints placeholders instead of certificate contents.
pub struct CertificateChain {
    /// Self-signed certificate authority; the trust anchor both peers use.
    pub ca_cert: rcgen::Certificate,
    /// Server certificate (CN "Stakpak MCP Server"); signed by `ca_cert` on serialization.
    pub server_cert: rcgen::Certificate,
    /// Client certificate (CN "Stakpak MCP Client"); signed by `ca_cert` on serialization.
    pub client_cert: rcgen::Certificate,
}
15
16impl std::fmt::Debug for CertificateChain {
17    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
18        f.debug_struct("CertificateChain")
19            .field("ca_cert", &"<certificate>")
20            .field("server_cert", &"<certificate>")
21            .field("client_cert", &"<certificate>")
22            .finish()
23    }
24}
25
26impl CertificateChain {
27    pub fn generate() -> Result<Self> {
28        // Generate CA certificate
29        let mut ca_params = CertificateParams::default();
30        ca_params.distinguished_name = DistinguishedName::new();
31        ca_params
32            .distinguished_name
33            .push(DnType::CommonName, "Stakpak MCP CA");
34        ca_params
35            .distinguished_name
36            .push(DnType::OrganizationName, "Stakpak");
37        ca_params.distinguished_name.push(DnType::CountryName, "US");
38
39        ca_params.is_ca = IsCa::Ca(BasicConstraints::Unconstrained);
40        ca_params.key_usages = vec![
41            KeyUsagePurpose::KeyCertSign,
42            KeyUsagePurpose::CrlSign,
43            KeyUsagePurpose::DigitalSignature,
44        ];
45
46        ca_params.not_before = OffsetDateTime::now_utc() - time::Duration::seconds(60);
47        ca_params.not_after = OffsetDateTime::now_utc() + time::Duration::days(365);
48
49        let ca_cert = rcgen::Certificate::from_params(ca_params)?;
50
51        // Generate server certificate
52        let mut server_params = CertificateParams::default();
53        server_params.distinguished_name = DistinguishedName::new();
54        server_params
55            .distinguished_name
56            .push(DnType::CommonName, "Stakpak MCP Server");
57        server_params
58            .distinguished_name
59            .push(DnType::OrganizationName, "Stakpak");
60        server_params
61            .distinguished_name
62            .push(DnType::CountryName, "US");
63
64        server_params.subject_alt_names = vec![
65            SanType::DnsName("localhost".to_string()),
66            SanType::IpAddress(std::net::IpAddr::V4(std::net::Ipv4Addr::new(0, 0, 0, 0))),
67            SanType::IpAddress(std::net::IpAddr::V4(std::net::Ipv4Addr::new(127, 0, 0, 1))),
68        ];
69
70        server_params.key_usages = vec![
71            KeyUsagePurpose::DigitalSignature,
72            KeyUsagePurpose::KeyEncipherment,
73        ];
74
75        server_params.not_before = OffsetDateTime::now_utc() - time::Duration::seconds(60);
76        server_params.not_after = OffsetDateTime::now_utc() + time::Duration::days(365);
77
78        let server_cert = rcgen::Certificate::from_params(server_params)?;
79
80        // Generate client certificate
81        let mut client_params = CertificateParams::default();
82        client_params.distinguished_name = DistinguishedName::new();
83        client_params
84            .distinguished_name
85            .push(DnType::CommonName, "Stakpak MCP Client");
86        client_params
87            .distinguished_name
88            .push(DnType::OrganizationName, "Stakpak");
89        client_params
90            .distinguished_name
91            .push(DnType::CountryName, "US");
92
93        client_params.key_usages = vec![
94            KeyUsagePurpose::DigitalSignature,
95            KeyUsagePurpose::KeyEncipherment,
96        ];
97
98        client_params.not_before = OffsetDateTime::now_utc() - time::Duration::seconds(60);
99        client_params.not_after = OffsetDateTime::now_utc() + time::Duration::days(365);
100
101        let client_cert = rcgen::Certificate::from_params(client_params)?;
102
103        Ok(CertificateChain {
104            ca_cert,
105            server_cert,
106            client_cert,
107        })
108    }
109
110    pub fn create_server_config(&self) -> Result<ServerConfig> {
111        // Sign server certificate with CA
112        let server_cert_der = self.server_cert.serialize_der_with_signer(&self.ca_cert)?;
113        let server_key_der = self.server_cert.serialize_private_key_der();
114
115        let server_cert_chain = vec![CertificateDer::from(server_cert_der)];
116        let server_private_key = PrivateKeyDer::try_from(server_key_der)
117            .map_err(|e| anyhow::anyhow!("Failed to convert server private key: {:?}", e))?;
118
119        // Set up root certificate store to trust our CA (for client cert validation)
120        let mut root_cert_store = RootCertStore::empty();
121        let ca_cert_der = self.ca_cert.serialize_der()?;
122        root_cert_store.add(CertificateDer::from(ca_cert_der))?;
123
124        // Create client certificate verifier that requires client certificates
125        let client_cert_verifier =
126            rustls::server::WebPkiClientVerifier::builder(Arc::new(root_cert_store))
127                .build()
128                .map_err(|e| anyhow::anyhow!("Failed to build client cert verifier: {}", e))?;
129
130        let config = ServerConfig::builder()
131            .with_client_cert_verifier(client_cert_verifier)
132            .with_single_cert(server_cert_chain, server_private_key)?;
133
134        Ok(config)
135    }
136
137    pub fn create_client_config(&self) -> Result<ClientConfig> {
138        // Sign client certificate with CA
139        let client_cert_der = self.client_cert.serialize_der_with_signer(&self.ca_cert)?;
140        let client_key_der = self.client_cert.serialize_private_key_der();
141
142        let client_cert_chain = vec![CertificateDer::from(client_cert_der)];
143        let client_private_key = PrivateKeyDer::try_from(client_key_der)
144            .map_err(|e| anyhow::anyhow!("Failed to convert client private key: {:?}", e))?;
145
146        // Set up root certificate store to trust our CA (for server cert validation)
147        let mut root_cert_store = RootCertStore::empty();
148        let ca_cert_der = self.ca_cert.serialize_der()?;
149        root_cert_store.add(CertificateDer::from(ca_cert_der))?;
150
151        let config = ClientConfig::builder()
152            .with_root_certificates(root_cert_store)
153            .with_client_auth_cert(client_cert_chain, client_private_key)?;
154
155        Ok(config)
156    }
157
158    pub fn get_ca_cert_pem(&self) -> Result<String> {
159        Ok(self.ca_cert.serialize_pem()?)
160    }
161
162    pub fn get_server_cert_pem(&self) -> Result<String> {
163        Ok(self.server_cert.serialize_pem_with_signer(&self.ca_cert)?)
164    }
165
166    pub fn get_client_cert_pem(&self) -> Result<String> {
167        Ok(self.client_cert.serialize_pem_with_signer(&self.ca_cert)?)
168    }
169
170    pub fn get_server_key_pem(&self) -> Result<String> {
171        Ok(self.server_cert.serialize_private_key_pem())
172    }
173
174    pub fn get_client_key_pem(&self) -> Result<String> {
175        Ok(self.client_cert.serialize_private_key_pem())
176    }
177}
178
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{Router, response::Json, routing::get};
    use axum_server::tls_rustls::RustlsConfig;
    use reqwest::Client;
    use serde_json::json;
    use std::net::SocketAddr;
    use std::sync::{Arc, Once};
    use tokio::task::JoinHandle;
    use tokio::time::{Duration, timeout};

    /// Installs aws-lc-rs as the process-wide rustls crypto provider, once.
    ///
    /// `install_default` errors if a provider is already installed (e.g. by
    /// another test module in the same binary). That is harmless, so the error
    /// is ignored rather than panicking and poisoning the `Once`.
    fn init_crypto_provider() {
        static INIT: Once = Once::new();
        INIT.call_once(|| {
            let _ = rustls::crypto::aws_lc_rs::default_provider().install_default();
        });
    }

    /// Returns a loopback address whose ephemeral port is currently free.
    ///
    /// The probe listener is dropped on return so the server under test can
    /// bind the address itself. Keeping it open would make the server's bind
    /// fail with `AddrInUse`, and a "handshake must fail" test would then pass
    /// without any server running.
    fn free_loopback_addr() -> SocketAddr {
        let probe =
            std::net::TcpListener::bind("127.0.0.1:0").expect("Failed to bind listener");
        probe.local_addr().expect("Failed to get local address")
    }

    /// Serves `app` over mTLS on a free loopback port.
    ///
    /// Returns the bound address and the server task. Asserts the task is
    /// still alive after a short start-up delay, so tests never run against a
    /// server that failed to bind.
    async fn spawn_mtls_server(
        server_config: ServerConfig,
        app: Router,
    ) -> (SocketAddr, JoinHandle<()>) {
        let addr = free_loopback_addr();
        let rustls_config = RustlsConfig::from_config(Arc::new(server_config));

        let handle = tokio::spawn(async move {
            axum_server::bind_rustls(addr, rustls_config)
                .serve(app.into_make_service())
                .await
                .expect("mTLS test server failed");
        });

        // Give the server time to bind before clients connect.
        tokio::time::sleep(Duration::from_millis(300)).await;
        assert!(
            !handle.is_finished(),
            "mTLS test server exited before accepting connections"
        );

        (addr, handle)
    }

    /// Router with a single `/test` route returning `{"status": "success"}`.
    fn status_app() -> Router {
        Router::new().route(
            "/test",
            get(|| async { Json(json!({"status": "success"})) }),
        )
    }

    /// reqwest client that uses `client_config` for TLS (and thus mTLS).
    fn mtls_client(client_config: ClientConfig) -> Client {
        Client::builder()
            .use_preconfigured_tls(client_config)
            .build()
            .expect("Failed to build client")
    }

    /// Sends `GET /test` and asserts it does not succeed, and that the server
    /// is still running afterwards (so the failure came from the TLS layer).
    async fn assert_request_rejected(
        client: &Client,
        addr: SocketAddr,
        server: &JoinHandle<()>,
        reason: &str,
    ) {
        let result = timeout(
            Duration::from_secs(5),
            client.get(format!("https://{}/test", addr)).send(),
        )
        .await;

        assert!(!matches!(result, Ok(Ok(_))), "{}", reason);
        assert!(
            !server.is_finished(),
            "mTLS test server stopped unexpectedly"
        );
    }

    #[tokio::test]
    async fn test_mtls_handshake_success() {
        init_crypto_provider();
        let cert_chain =
            CertificateChain::generate().expect("Failed to generate certificate chain");
        let server_config = cert_chain
            .create_server_config()
            .expect("Failed to create server config");
        let client_config = cert_chain
            .create_client_config()
            .expect("Failed to create client config");

        let (addr, server_handle) = spawn_mtls_server(server_config, status_app()).await;
        let client = mtls_client(client_config);

        let url = format!("https://{}/test", addr);
        println!("Testing mTLS connection to: {}", url);

        let response = timeout(Duration::from_secs(10), client.get(&url).send())
            .await
            .expect("Request timed out")
            .expect("Failed to send request");

        assert!(
            response.status().is_success(),
            "Request should succeed with valid mTLS"
        );

        let body: serde_json::Value = response.json().await.expect("Failed to parse JSON");
        assert_eq!(body["status"], "success");

        server_handle.abort();
    }

    #[tokio::test]
    async fn test_mtls_handshake_failure_no_client_cert() {
        init_crypto_provider();
        let cert_chain =
            CertificateChain::generate().expect("Failed to generate certificate chain");
        let server_config = cert_chain
            .create_server_config()
            .expect("Failed to create server config");

        let (addr, server_handle) = spawn_mtls_server(server_config, status_app()).await;

        // Accept the self-signed server certificate, but present no client certificate.
        let client = Client::builder()
            .danger_accept_invalid_certs(true)
            .build()
            .expect("Failed to build client");

        assert_request_rejected(
            &client,
            addr,
            &server_handle,
            "Request should fail without client certificate",
        )
        .await;

        server_handle.abort();
    }

    #[tokio::test]
    async fn test_mtls_handshake_failure_wrong_ca() {
        init_crypto_provider();
        let cert_chain1 =
            CertificateChain::generate().expect("Failed to generate certificate chain 1");
        let cert_chain2 =
            CertificateChain::generate().expect("Failed to generate certificate chain 2");

        let server_config = cert_chain1
            .create_server_config()
            .expect("Failed to create server config");
        // Client trusts (and is issued by) a different CA than the server.
        let client_config = cert_chain2
            .create_client_config()
            .expect("Failed to create client config");

        let (addr, server_handle) = spawn_mtls_server(server_config, status_app()).await;
        let client = mtls_client(client_config);

        assert_request_rejected(
            &client,
            addr,
            &server_handle,
            "Request should fail with wrong CA certificates",
        )
        .await;

        server_handle.abort();
    }

    #[tokio::test]
    async fn test_mtls_rejects_server_cert_as_client_identity() {
        init_crypto_provider();
        let cert_chain =
            CertificateChain::generate().expect("Failed to generate certificate chain");
        let server_config = cert_chain
            .create_server_config()
            .expect("Failed to create server config");

        // Present the *server* certificate as the client identity. It chains to
        // the right CA, but its EKU only allows serverAuth, so the server's
        // client-certificate verifier must reject it.
        let mut roots = RootCertStore::empty();
        roots
            .add(CertificateDer::from(
                cert_chain.ca_cert.serialize_der().expect("Failed to serialize CA"),
            ))
            .expect("Failed to add CA to root store");
        let server_cert_der = cert_chain
            .server_cert
            .serialize_der_with_signer(&cert_chain.ca_cert)
            .expect("Failed to sign server certificate");
        let server_key = PrivateKeyDer::try_from(cert_chain.server_cert.serialize_private_key_der())
            .expect("Failed to convert server private key");
        let impostor_config = ClientConfig::builder()
            .with_root_certificates(roots)
            .with_client_auth_cert(vec![CertificateDer::from(server_cert_der)], server_key)
            .expect("Failed to create impostor client config");

        let (addr, server_handle) = spawn_mtls_server(server_config, status_app()).await;
        let client = mtls_client(impostor_config);

        assert_request_rejected(
            &client,
            addr,
            &server_handle,
            "Server certificate must not be accepted as a client certificate",
        )
        .await;

        server_handle.abort();
    }

    #[tokio::test]
    async fn test_certificate_chain_generation() {
        init_crypto_provider();
        let cert_chain =
            CertificateChain::generate().expect("Failed to generate certificate chain");

        let ca_pem = cert_chain.get_ca_cert_pem().expect("Failed to get CA PEM");
        let server_pem = cert_chain
            .get_server_cert_pem()
            .expect("Failed to get server PEM");
        let client_pem = cert_chain
            .get_client_cert_pem()
            .expect("Failed to get client PEM");
        let server_key_pem = cert_chain
            .get_server_key_pem()
            .expect("Failed to get server key PEM");
        let client_key_pem = cert_chain
            .get_client_key_pem()
            .expect("Failed to get client key PEM");

        for cert_pem in [&ca_pem, &server_pem, &client_pem] {
            assert!(cert_pem.contains("-----BEGIN CERTIFICATE-----"));
            assert!(cert_pem.contains("-----END CERTIFICATE-----"));
        }
        for key_pem in [&server_key_pem, &client_key_pem] {
            assert!(key_pem.contains("-----BEGIN PRIVATE KEY-----"));
            assert!(key_pem.contains("-----END PRIVATE KEY-----"));
        }
    }

    #[tokio::test]
    async fn test_server_config_creation() {
        init_crypto_provider();
        let cert_chain =
            CertificateChain::generate().expect("Failed to generate certificate chain");
        // Succeeding without an error is the assertion here.
        cert_chain
            .create_server_config()
            .expect("Failed to create server config");
    }

    #[tokio::test]
    async fn test_client_config_creation() {
        init_crypto_provider();
        let cert_chain =
            CertificateChain::generate().expect("Failed to generate certificate chain");
        // Succeeding without an error is the assertion here.
        cert_chain
            .create_client_config()
            .expect("Failed to create client config");
    }

    #[tokio::test]
    async fn test_mtls_multiple_requests() {
        init_crypto_provider();
        let cert_chain =
            CertificateChain::generate().expect("Failed to generate certificate chain");
        let server_config = cert_chain
            .create_server_config()
            .expect("Failed to create server config");
        let client_config = cert_chain
            .create_client_config()
            .expect("Failed to create client config");

        let app = Router::new()
            .route(
                "/test1",
                get(|| async { Json(json!({"endpoint": "test1"})) }),
            )
            .route(
                "/test2",
                get(|| async { Json(json!({"endpoint": "test2"})) }),
            )
            .route(
                "/test3",
                get(|| async { Json(json!({"endpoint": "test3"})) }),
            );

        let (addr, server_handle) = spawn_mtls_server(server_config, app).await;
        let client = mtls_client(client_config);

        for endpoint in ["test1", "test2", "test3"] {
            let response = timeout(
                Duration::from_secs(10),
                client.get(format!("https://{}/{}", addr, endpoint)).send(),
            )
            .await
            .expect("Request timed out")
            .expect("Failed to send request");

            assert!(
                response.status().is_success(),
                "Request to {} should succeed",
                endpoint
            );

            let body: serde_json::Value = response.json().await.expect("Failed to parse JSON");
            assert_eq!(body["endpoint"], endpoint);
        }

        server_handle.abort();
    }

    #[tokio::test]
    async fn test_mtls_configuration_compatibility() {
        init_crypto_provider();
        let cert_chain =
            CertificateChain::generate().expect("Failed to generate certificate chain");
        let server_config = cert_chain
            .create_server_config()
            .expect("Failed to create server config");
        let client_config = cert_chain
            .create_client_config()
            .expect("Failed to create client config");

        // Both configs must plug into the HTTP stacks used by the MCP server/client.
        let _client = Client::builder()
            .use_preconfigured_tls(client_config)
            .build()
            .expect("Failed to build reqwest client with mTLS config");
        let _rustls_config = RustlsConfig::from_config(Arc::new(server_config));

        assert!(cert_chain.get_ca_cert_pem().is_ok());
        assert!(cert_chain.get_server_cert_pem().is_ok());
        assert!(cert_chain.get_client_cert_pem().is_ok());
        assert!(cert_chain.get_server_key_pem().is_ok());
        assert!(cert_chain.get_client_key_pem().is_ok());
    }

    #[tokio::test]
    async fn test_mtls_certificate_validation() {
        init_crypto_provider();
        let cert_chain1 =
            CertificateChain::generate().expect("Failed to generate certificate chain 1");
        let cert_chain2 =
            CertificateChain::generate().expect("Failed to generate certificate chain 2");

        // Configs from unrelated chains build fine; the mismatch only surfaces
        // during the handshake (covered by `test_mtls_handshake_failure_wrong_ca`).
        let server_config1 = cert_chain1
            .create_server_config()
            .expect("Failed to create server config 1");
        let client_config2 = cert_chain2
            .create_client_config()
            .expect("Failed to create client config 2");

        let _client = Client::builder()
            .use_preconfigured_tls(client_config2)
            .build()
            .expect("Failed to build client with different CA");
        let _rustls_config = RustlsConfig::from_config(Arc::new(server_config1));
    }
}