1use anyhow::Result;
2use rcgen::{
3 BasicConstraints, CertificateParams, DistinguishedName, DnType, IsCa, KeyUsagePurpose, SanType,
4};
5use rustls::pki_types::{CertificateDer, PrivateKeyDer};
6use rustls::{ClientConfig, RootCertStore, ServerConfig};
7use std::sync::Arc;
8use time::OffsetDateTime;
9
10pub struct CertificateChain {
11 pub ca_cert: rcgen::Certificate,
12 pub server_cert: rcgen::Certificate,
13 pub client_cert: rcgen::Certificate,
14}
15
16impl CertificateChain {
17 pub fn generate() -> Result<Self> {
18 let mut ca_params = CertificateParams::default();
20 ca_params.distinguished_name = DistinguishedName::new();
21 ca_params
22 .distinguished_name
23 .push(DnType::CommonName, "Stakpak MCP CA");
24 ca_params
25 .distinguished_name
26 .push(DnType::OrganizationName, "Stakpak");
27 ca_params.distinguished_name.push(DnType::CountryName, "US");
28
29 ca_params.is_ca = IsCa::Ca(BasicConstraints::Unconstrained);
30 ca_params.key_usages = vec![
31 KeyUsagePurpose::KeyCertSign,
32 KeyUsagePurpose::CrlSign,
33 KeyUsagePurpose::DigitalSignature,
34 ];
35
36 ca_params.not_before = OffsetDateTime::now_utc() - time::Duration::seconds(60);
37 ca_params.not_after = OffsetDateTime::now_utc() + time::Duration::days(365);
38
39 let ca_cert = rcgen::Certificate::from_params(ca_params)?;
40
41 let mut server_params = CertificateParams::default();
43 server_params.distinguished_name = DistinguishedName::new();
44 server_params
45 .distinguished_name
46 .push(DnType::CommonName, "Stakpak MCP Server");
47 server_params
48 .distinguished_name
49 .push(DnType::OrganizationName, "Stakpak");
50 server_params
51 .distinguished_name
52 .push(DnType::CountryName, "US");
53
54 server_params.subject_alt_names = vec![
55 SanType::DnsName("localhost".to_string()),
56 SanType::IpAddress(std::net::IpAddr::V4(std::net::Ipv4Addr::new(0, 0, 0, 0))),
57 SanType::IpAddress(std::net::IpAddr::V4(std::net::Ipv4Addr::new(127, 0, 0, 1))),
58 ];
59
60 server_params.key_usages = vec![
61 KeyUsagePurpose::DigitalSignature,
62 KeyUsagePurpose::KeyEncipherment,
63 ];
64
65 server_params.not_before = OffsetDateTime::now_utc() - time::Duration::seconds(60);
66 server_params.not_after = OffsetDateTime::now_utc() + time::Duration::days(365);
67
68 let server_cert = rcgen::Certificate::from_params(server_params)?;
69
70 let mut client_params = CertificateParams::default();
72 client_params.distinguished_name = DistinguishedName::new();
73 client_params
74 .distinguished_name
75 .push(DnType::CommonName, "Stakpak MCP Client");
76 client_params
77 .distinguished_name
78 .push(DnType::OrganizationName, "Stakpak");
79 client_params
80 .distinguished_name
81 .push(DnType::CountryName, "US");
82
83 client_params.key_usages = vec![
84 KeyUsagePurpose::DigitalSignature,
85 KeyUsagePurpose::KeyEncipherment,
86 ];
87
88 client_params.not_before = OffsetDateTime::now_utc() - time::Duration::seconds(60);
89 client_params.not_after = OffsetDateTime::now_utc() + time::Duration::days(365);
90
91 let client_cert = rcgen::Certificate::from_params(client_params)?;
92
93 Ok(CertificateChain {
94 ca_cert,
95 server_cert,
96 client_cert,
97 })
98 }
99
100 pub fn create_server_config(&self) -> Result<ServerConfig> {
101 let server_cert_der = self.server_cert.serialize_der_with_signer(&self.ca_cert)?;
103 let server_key_der = self.server_cert.serialize_private_key_der();
104
105 let server_cert_chain = vec![CertificateDer::from(server_cert_der)];
106 let server_private_key = PrivateKeyDer::try_from(server_key_der)
107 .map_err(|e| anyhow::anyhow!("Failed to convert server private key: {:?}", e))?;
108
109 let mut root_cert_store = RootCertStore::empty();
111 let ca_cert_der = self.ca_cert.serialize_der()?;
112 root_cert_store.add(CertificateDer::from(ca_cert_der))?;
113
114 let client_cert_verifier =
116 rustls::server::WebPkiClientVerifier::builder(Arc::new(root_cert_store))
117 .build()
118 .map_err(|e| anyhow::anyhow!("Failed to build client cert verifier: {}", e))?;
119
120 let config = ServerConfig::builder()
121 .with_client_cert_verifier(client_cert_verifier)
122 .with_single_cert(server_cert_chain, server_private_key)?;
123
124 Ok(config)
125 }
126
127 pub fn create_client_config(&self) -> Result<ClientConfig> {
128 let client_cert_der = self.client_cert.serialize_der_with_signer(&self.ca_cert)?;
130 let client_key_der = self.client_cert.serialize_private_key_der();
131
132 let client_cert_chain = vec![CertificateDer::from(client_cert_der)];
133 let client_private_key = PrivateKeyDer::try_from(client_key_der)
134 .map_err(|e| anyhow::anyhow!("Failed to convert client private key: {:?}", e))?;
135
136 let mut root_cert_store = RootCertStore::empty();
138 let ca_cert_der = self.ca_cert.serialize_der()?;
139 root_cert_store.add(CertificateDer::from(ca_cert_der))?;
140
141 let config = ClientConfig::builder()
142 .with_root_certificates(root_cert_store)
143 .with_client_auth_cert(client_cert_chain, client_private_key)?;
144
145 Ok(config)
146 }
147
148 pub fn get_ca_cert_pem(&self) -> Result<String> {
149 Ok(self.ca_cert.serialize_pem()?)
150 }
151
152 pub fn get_server_cert_pem(&self) -> Result<String> {
153 Ok(self.server_cert.serialize_pem_with_signer(&self.ca_cert)?)
154 }
155
156 pub fn get_client_cert_pem(&self) -> Result<String> {
157 Ok(self.client_cert.serialize_pem_with_signer(&self.ca_cert)?)
158 }
159
160 pub fn get_server_key_pem(&self) -> Result<String> {
161 Ok(self.server_cert.serialize_private_key_pem())
162 }
163
164 pub fn get_client_key_pem(&self) -> Result<String> {
165 Ok(self.client_cert.serialize_private_key_pem())
166 }
167}
168
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{Router, response::Json, routing::get};
    use axum_server::tls_rustls::RustlsConfig;
    use reqwest::Client;
    use serde_json::json;
    use std::net::SocketAddr;
    use std::sync::Arc;
    use tokio::net::{TcpListener, TcpStream};
    use tokio::task::JoinHandle;
    use tokio::time::{Duration, timeout};

    /// Installs the aws-lc-rs crypto provider as the rustls default, at most
    /// once per call site.
    fn init_crypto_provider() {
        use std::sync::Once;
        static INIT: Once = Once::new();
        INIT.call_once(|| {
            // An `Err` here means some other code in this process already
            // installed a default provider. That is fine for these tests, and
            // panicking inside `call_once` would poison `INIT` for every
            // later test.
            let _ = rustls::crypto::aws_lc_rs::default_provider().install_default();
        });
    }

    /// Router exposing a single `/test` endpoint that reports success.
    fn status_app() -> Router {
        Router::new().route(
            "/test",
            get(|| async { Json(json!({"status": "success"})) }),
        )
    }

    /// Starts `app` behind `server_config` on an ephemeral loopback port and
    /// waits until the port accepts TCP connections.
    ///
    /// Returns the bound address and the handle of the server task; callers
    /// abort the handle when they are done.
    ///
    /// # Panics
    ///
    /// Panics if no free port can be obtained or if the server does not start
    /// listening, so that a test never runs against a server that is not there.
    async fn spawn_mtls_server(
        server_config: rustls::ServerConfig,
        app: Router,
    ) -> (SocketAddr, JoinHandle<()>) {
        // Ask the OS for a free port, then release it *before* the server
        // binds. Keeping the probe listener open would make the server's bind
        // fail while clients silently connect to the idle probe socket.
        let probe = TcpListener::bind("127.0.0.1:0")
            .await
            .expect("Failed to bind listener");
        let server_addr = probe.local_addr().expect("Failed to get local address");
        drop(probe);

        let rustls_config = RustlsConfig::from_config(Arc::new(server_config));
        let server_handle = tokio::spawn(async move {
            if let Err(err) = axum_server::bind_rustls(server_addr, rustls_config)
                .serve(app.into_make_service())
                .await
            {
                eprintln!("mTLS test server exited with error: {:?}", err);
            }
        });

        // Poll for readiness instead of sleeping for a fixed time. The probe
        // connection is dropped immediately without starting a TLS handshake.
        for _ in 0..100 {
            if server_handle.is_finished() {
                break;
            }
            if TcpStream::connect(server_addr).await.is_ok() {
                return (server_addr, server_handle);
            }
            tokio::time::sleep(Duration::from_millis(20)).await;
        }
        panic!("mTLS test server never started listening on {}", server_addr);
    }

    #[tokio::test]
    async fn test_mtls_handshake_success() {
        init_crypto_provider();
        let cert_chain =
            CertificateChain::generate().expect("Failed to generate certificate chain");

        let server_config = cert_chain
            .create_server_config()
            .expect("Failed to create server config");
        let client_config = cert_chain
            .create_client_config()
            .expect("Failed to create client config");

        let (server_addr, server_handle) = spawn_mtls_server(server_config, status_app()).await;

        let client = Client::builder()
            .use_preconfigured_tls(client_config)
            .build()
            .expect("Failed to build client");

        let url = format!("https://127.0.0.1:{}/test", server_addr.port());
        println!("Testing mTLS connection to: {}", url);

        let response = timeout(Duration::from_secs(10), client.get(&url).send())
            .await
            .expect("Request timed out")
            .expect("Failed to send request");

        assert!(
            response.status().is_success(),
            "Request should succeed with valid mTLS"
        );

        let body: serde_json::Value = response.json().await.expect("Failed to parse JSON");
        assert_eq!(body["status"], "success");

        server_handle.abort();
    }

    #[tokio::test]
    async fn test_mtls_handshake_failure_no_client_cert() {
        init_crypto_provider();
        let cert_chain =
            CertificateChain::generate().expect("Failed to generate certificate chain");

        let server_config = cert_chain
            .create_server_config()
            .expect("Failed to create server config");

        let (server_addr, server_handle) = spawn_mtls_server(server_config, status_app()).await;

        // The client accepts any server certificate but presents none itself,
        // so only the server's client-certificate requirement can reject it.
        let client = Client::builder()
            .danger_accept_invalid_certs(true)
            .build()
            .expect("Failed to build client");

        let result = timeout(
            Duration::from_secs(5),
            client
                .get(format!("https://127.0.0.1:{}/test", server_addr.port()))
                .send(),
        )
        .await;

        assert!(
            !matches!(result, Ok(Ok(_))),
            "Request should fail without client certificate"
        );

        server_handle.abort();
    }

    #[tokio::test]
    async fn test_mtls_handshake_failure_wrong_ca() {
        init_crypto_provider();
        let cert_chain1 =
            CertificateChain::generate().expect("Failed to generate certificate chain 1");
        let cert_chain2 =
            CertificateChain::generate().expect("Failed to generate certificate chain 2");

        let server_config = cert_chain1
            .create_server_config()
            .expect("Failed to create server config");
        let client_config = cert_chain2
            .create_client_config()
            .expect("Failed to create client config");

        let (server_addr, server_handle) = spawn_mtls_server(server_config, status_app()).await;

        let client = Client::builder()
            .use_preconfigured_tls(client_config)
            .build()
            .expect("Failed to build client");

        let result = timeout(
            Duration::from_secs(5),
            client
                .get(format!("https://127.0.0.1:{}/test", server_addr.port()))
                .send(),
        )
        .await;

        assert!(
            !matches!(result, Ok(Ok(_))),
            "Request should fail with wrong CA certificates"
        );

        server_handle.abort();
    }

    #[tokio::test]
    async fn test_certificate_chain_generation() {
        init_crypto_provider();
        let cert_chain =
            CertificateChain::generate().expect("Failed to generate certificate chain");

        let ca_pem = cert_chain.get_ca_cert_pem().expect("Failed to get CA PEM");
        let server_pem = cert_chain
            .get_server_cert_pem()
            .expect("Failed to get server PEM");
        let client_pem = cert_chain
            .get_client_cert_pem()
            .expect("Failed to get client PEM");
        let server_key_pem = cert_chain
            .get_server_key_pem()
            .expect("Failed to get server key PEM");
        let client_key_pem = cert_chain
            .get_client_key_pem()
            .expect("Failed to get client key PEM");

        for cert_pem in [&ca_pem, &server_pem, &client_pem] {
            assert!(cert_pem.contains("-----BEGIN CERTIFICATE-----"));
            assert!(cert_pem.contains("-----END CERTIFICATE-----"));
        }
        for key_pem in [&server_key_pem, &client_key_pem] {
            assert!(key_pem.contains("-----BEGIN PRIVATE KEY-----"));
            assert!(key_pem.contains("-----END PRIVATE KEY-----"));
        }
    }

    #[tokio::test]
    async fn test_server_config_creation() {
        init_crypto_provider();
        let cert_chain =
            CertificateChain::generate().expect("Failed to generate certificate chain");

        // Successful construction is the assertion: `expect` fails the test otherwise.
        let _server_config = cert_chain
            .create_server_config()
            .expect("Failed to create server config");
    }

    #[tokio::test]
    async fn test_client_config_creation() {
        init_crypto_provider();
        let cert_chain =
            CertificateChain::generate().expect("Failed to generate certificate chain");

        // Successful construction is the assertion: `expect` fails the test otherwise.
        let _client_config = cert_chain
            .create_client_config()
            .expect("Failed to create client config");
    }

    #[tokio::test]
    async fn test_mtls_multiple_requests() {
        init_crypto_provider();
        let cert_chain =
            CertificateChain::generate().expect("Failed to generate certificate chain");

        let server_config = cert_chain
            .create_server_config()
            .expect("Failed to create server config");
        let client_config = cert_chain
            .create_client_config()
            .expect("Failed to create client config");

        let app = Router::new()
            .route(
                "/test1",
                get(|| async { Json(json!({"endpoint": "test1"})) }),
            )
            .route(
                "/test2",
                get(|| async { Json(json!({"endpoint": "test2"})) }),
            )
            .route(
                "/test3",
                get(|| async { Json(json!({"endpoint": "test3"})) }),
            );

        let (server_addr, server_handle) = spawn_mtls_server(server_config, app).await;

        let client = Client::builder()
            .use_preconfigured_tls(client_config)
            .build()
            .expect("Failed to build client");

        for endpoint in ["test1", "test2", "test3"] {
            let response = timeout(
                Duration::from_secs(10),
                client
                    .get(format!(
                        "https://127.0.0.1:{}/{}",
                        server_addr.port(),
                        endpoint
                    ))
                    .send(),
            )
            .await
            .expect("Request timed out")
            .expect("Failed to send request");

            assert!(
                response.status().is_success(),
                "Request to {} should succeed",
                endpoint
            );

            let body: serde_json::Value = response.json().await.expect("Failed to parse JSON");
            assert_eq!(body["endpoint"], endpoint);
        }

        server_handle.abort();
    }

    #[tokio::test]
    async fn test_mtls_configuration_compatibility() {
        init_crypto_provider();

        let cert_chain =
            CertificateChain::generate().expect("Failed to generate certificate chain");

        let server_config = cert_chain
            .create_server_config()
            .expect("Failed to create server config");
        let client_config = cert_chain
            .create_client_config()
            .expect("Failed to create client config");

        let _client = Client::builder()
            .use_preconfigured_tls(client_config)
            .build()
            .expect("Failed to build reqwest client with mTLS config");

        let _rustls_config = RustlsConfig::from_config(Arc::new(server_config));

        assert!(cert_chain.get_ca_cert_pem().is_ok());
        assert!(cert_chain.get_server_cert_pem().is_ok());
        assert!(cert_chain.get_client_cert_pem().is_ok());
        assert!(cert_chain.get_server_key_pem().is_ok());
        assert!(cert_chain.get_client_key_pem().is_ok());

        println!("✅ mTLS configuration successfully created");
        println!("✅ Reqwest client can be configured with client certificates");
        println!("✅ Axum server can be configured with server certificates");
        println!("✅ Certificate chain includes CA, server, and client certificates");
    }

    #[tokio::test]
    async fn test_mtls_certificate_validation() {
        init_crypto_provider();

        let cert_chain1 =
            CertificateChain::generate().expect("Failed to generate certificate chain 1");
        let cert_chain2 =
            CertificateChain::generate().expect("Failed to generate certificate chain 2");

        let server_config1 = cert_chain1
            .create_server_config()
            .expect("Failed to create server config 1");
        let client_config2 = cert_chain2
            .create_client_config()
            .expect("Failed to create client config 2");

        // Only construction is checked here; the handshake between mismatched
        // chains is exercised by `test_mtls_handshake_failure_wrong_ca`.
        let _client = Client::builder()
            .use_preconfigured_tls(client_config2)
            .build()
            .expect("Failed to build client with different CA");

        let _rustls_config = RustlsConfig::from_config(Arc::new(server_config1));

        println!("✅ Different certificate chains create valid but incompatible configurations");
    }
}