1use anyhow::Result;
2use rcgen::{
3 BasicConstraints, CertificateParams, DistinguishedName, DnType, IsCa, KeyUsagePurpose, SanType,
4};
5use rustls::pki_types::{CertificateDer, PrivateKeyDer};
6use rustls::{ClientConfig, RootCertStore, ServerConfig};
7use std::sync::Arc;
8use time::OffsetDateTime;
9
/// A certificate authority plus one server and one client leaf certificate, used to
/// set up mutual TLS via `create_server_config` / `create_client_config`.
///
/// Each `rcgen::Certificate` also carries its key pair (the private keys are exported
/// by the `get_*_key_pem` methods), so values of this type hold secret key material.
pub struct CertificateChain {
    /// Self-signed CA certificate; signs both leaves when they are serialized and is
    /// the only trust anchor in the server and client configurations.
    pub ca_cert: rcgen::Certificate,
    /// Server leaf certificate and key; signed by `ca_cert` at serialization time.
    pub server_cert: rcgen::Certificate,
    /// Client leaf certificate and key; signed by `ca_cert` at serialization time.
    pub client_cert: rcgen::Certificate,
}
15
16impl std::fmt::Debug for CertificateChain {
17 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
18 f.debug_struct("CertificateChain")
19 .field("ca_cert", &"<certificate>")
20 .field("server_cert", &"<certificate>")
21 .field("client_cert", &"<certificate>")
22 .finish()
23 }
24}
25
26impl CertificateChain {
27 pub fn generate() -> Result<Self> {
28 let mut ca_params = CertificateParams::default();
30 ca_params.distinguished_name = DistinguishedName::new();
31 ca_params
32 .distinguished_name
33 .push(DnType::CommonName, "Stakpak MCP CA");
34 ca_params
35 .distinguished_name
36 .push(DnType::OrganizationName, "Stakpak");
37 ca_params.distinguished_name.push(DnType::CountryName, "US");
38
39 ca_params.is_ca = IsCa::Ca(BasicConstraints::Unconstrained);
40 ca_params.key_usages = vec![
41 KeyUsagePurpose::KeyCertSign,
42 KeyUsagePurpose::CrlSign,
43 KeyUsagePurpose::DigitalSignature,
44 ];
45
46 ca_params.not_before = OffsetDateTime::now_utc() - time::Duration::seconds(60);
47 ca_params.not_after = OffsetDateTime::now_utc() + time::Duration::days(365);
48
49 let ca_cert = rcgen::Certificate::from_params(ca_params)?;
50
51 let mut server_params = CertificateParams::default();
53 server_params.distinguished_name = DistinguishedName::new();
54 server_params
55 .distinguished_name
56 .push(DnType::CommonName, "Stakpak MCP Server");
57 server_params
58 .distinguished_name
59 .push(DnType::OrganizationName, "Stakpak");
60 server_params
61 .distinguished_name
62 .push(DnType::CountryName, "US");
63
64 server_params.subject_alt_names = vec![
65 SanType::DnsName("localhost".to_string()),
66 SanType::IpAddress(std::net::IpAddr::V4(std::net::Ipv4Addr::new(0, 0, 0, 0))),
67 SanType::IpAddress(std::net::IpAddr::V4(std::net::Ipv4Addr::new(127, 0, 0, 1))),
68 ];
69
70 server_params.key_usages = vec![
71 KeyUsagePurpose::DigitalSignature,
72 KeyUsagePurpose::KeyEncipherment,
73 ];
74
75 server_params.not_before = OffsetDateTime::now_utc() - time::Duration::seconds(60);
76 server_params.not_after = OffsetDateTime::now_utc() + time::Duration::days(365);
77
78 let server_cert = rcgen::Certificate::from_params(server_params)?;
79
80 let mut client_params = CertificateParams::default();
82 client_params.distinguished_name = DistinguishedName::new();
83 client_params
84 .distinguished_name
85 .push(DnType::CommonName, "Stakpak MCP Client");
86 client_params
87 .distinguished_name
88 .push(DnType::OrganizationName, "Stakpak");
89 client_params
90 .distinguished_name
91 .push(DnType::CountryName, "US");
92
93 client_params.key_usages = vec![
94 KeyUsagePurpose::DigitalSignature,
95 KeyUsagePurpose::KeyEncipherment,
96 ];
97
98 client_params.not_before = OffsetDateTime::now_utc() - time::Duration::seconds(60);
99 client_params.not_after = OffsetDateTime::now_utc() + time::Duration::days(365);
100
101 let client_cert = rcgen::Certificate::from_params(client_params)?;
102
103 Ok(CertificateChain {
104 ca_cert,
105 server_cert,
106 client_cert,
107 })
108 }
109
110 pub fn create_server_config(&self) -> Result<ServerConfig> {
111 let server_cert_der = self.server_cert.serialize_der_with_signer(&self.ca_cert)?;
113 let server_key_der = self.server_cert.serialize_private_key_der();
114
115 let server_cert_chain = vec![CertificateDer::from(server_cert_der)];
116 let server_private_key = PrivateKeyDer::try_from(server_key_der)
117 .map_err(|e| anyhow::anyhow!("Failed to convert server private key: {:?}", e))?;
118
119 let mut root_cert_store = RootCertStore::empty();
121 let ca_cert_der = self.ca_cert.serialize_der()?;
122 root_cert_store.add(CertificateDer::from(ca_cert_der))?;
123
124 let client_cert_verifier =
126 rustls::server::WebPkiClientVerifier::builder(Arc::new(root_cert_store))
127 .build()
128 .map_err(|e| anyhow::anyhow!("Failed to build client cert verifier: {}", e))?;
129
130 let config = ServerConfig::builder()
131 .with_client_cert_verifier(client_cert_verifier)
132 .with_single_cert(server_cert_chain, server_private_key)?;
133
134 Ok(config)
135 }
136
137 pub fn create_client_config(&self) -> Result<ClientConfig> {
138 let client_cert_der = self.client_cert.serialize_der_with_signer(&self.ca_cert)?;
140 let client_key_der = self.client_cert.serialize_private_key_der();
141
142 let client_cert_chain = vec![CertificateDer::from(client_cert_der)];
143 let client_private_key = PrivateKeyDer::try_from(client_key_der)
144 .map_err(|e| anyhow::anyhow!("Failed to convert client private key: {:?}", e))?;
145
146 let mut root_cert_store = RootCertStore::empty();
148 let ca_cert_der = self.ca_cert.serialize_der()?;
149 root_cert_store.add(CertificateDer::from(ca_cert_der))?;
150
151 let config = ClientConfig::builder()
152 .with_root_certificates(root_cert_store)
153 .with_client_auth_cert(client_cert_chain, client_private_key)?;
154
155 Ok(config)
156 }
157
158 pub fn get_ca_cert_pem(&self) -> Result<String> {
159 Ok(self.ca_cert.serialize_pem()?)
160 }
161
162 pub fn get_server_cert_pem(&self) -> Result<String> {
163 Ok(self.server_cert.serialize_pem_with_signer(&self.ca_cert)?)
164 }
165
166 pub fn get_client_cert_pem(&self) -> Result<String> {
167 Ok(self.client_cert.serialize_pem_with_signer(&self.ca_cert)?)
168 }
169
170 pub fn get_server_key_pem(&self) -> Result<String> {
171 Ok(self.server_cert.serialize_private_key_pem())
172 }
173
174 pub fn get_client_key_pem(&self) -> Result<String> {
175 Ok(self.client_cert.serialize_private_key_pem())
176 }
177}
178
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{Router, response::Json, routing::get};
    use axum_server::tls_rustls::RustlsConfig;
    use reqwest::Client;
    use serde_json::json;
    use std::net::SocketAddr;
    use std::sync::Arc;
    use tokio::time::{Duration, timeout};

    /// Installs aws-lc-rs as rustls' process-wide crypto provider exactly once.
    ///
    /// All tests share one process, and `install_default` errors on a second call,
    /// so the installation is guarded by a `Once`.
    fn init_crypto_provider() {
        use std::sync::Once;
        static INIT: Once = Once::new();
        INIT.call_once(|| {
            rustls::crypto::aws_lc_rs::default_provider()
                .install_default()
                .expect("Failed to install crypto provider");
        });
    }

    /// Router with a single `/test` endpoint returning `{"status": "success"}`.
    fn status_app() -> Router {
        Router::new().route(
            "/test",
            get(|| async { Json(json!({"status": "success"})) }),
        )
    }

    /// Spawns an HTTPS server for `app` on an OS-assigned loopback port.
    ///
    /// Waits until the listener is bound, then returns its address and the server
    /// task (abort it to stop the server). Using port 0 avoids collisions between
    /// parallel tests. Panicking on a failed bind stops the negative tests from
    /// passing merely because nothing was listening.
    async fn spawn_tls_server(
        app: Router,
        server_config: rustls::ServerConfig,
    ) -> (SocketAddr, tokio::task::JoinHandle<std::io::Result<()>>) {
        let handle = axum_server::Handle::new();
        let rustls_config = RustlsConfig::from_config(Arc::new(server_config));
        let bind_addr: SocketAddr = "127.0.0.1:0".parse().expect("valid socket address");

        let task = tokio::spawn({
            let handle = handle.clone();
            async move {
                axum_server::bind_rustls(bind_addr, rustls_config)
                    .handle(handle)
                    .serve(app.into_make_service())
                    .await
            }
        });

        let addr = handle
            .listening()
            .await
            .expect("mTLS test server failed to bind");
        (addr, task)
    }

    #[tokio::test]
    async fn test_mtls_handshake_success() {
        init_crypto_provider();
        let cert_chain =
            CertificateChain::generate().expect("Failed to generate certificate chain");
        let server_config = cert_chain
            .create_server_config()
            .expect("Failed to create server config");
        let client_config = cert_chain
            .create_client_config()
            .expect("Failed to create client config");

        let (server_addr, server_task) = spawn_tls_server(status_app(), server_config).await;

        let client = Client::builder()
            .use_preconfigured_tls(client_config)
            .build()
            .expect("Failed to build client");

        let url = format!("https://127.0.0.1:{}/test", server_addr.port());
        println!("Testing mTLS connection to: {}", url);

        let response = timeout(Duration::from_secs(10), client.get(&url).send())
            .await
            .expect("Request timed out")
            .expect("Failed to send request");
        assert!(
            response.status().is_success(),
            "Request should succeed with valid mTLS"
        );

        let body: serde_json::Value = response.json().await.expect("Failed to parse JSON");
        assert_eq!(body["status"], "success");

        server_task.abort();
    }

    #[tokio::test]
    async fn test_mtls_handshake_failure_no_client_cert() {
        init_crypto_provider();
        let cert_chain =
            CertificateChain::generate().expect("Failed to generate certificate chain");
        let server_config = cert_chain
            .create_server_config()
            .expect("Failed to create server config");

        let (server_addr, server_task) = spawn_tls_server(status_app(), server_config).await;

        // Accepts any server certificate but presents no client certificate.
        let client = Client::builder()
            .danger_accept_invalid_certs(true)
            .build()
            .expect("Failed to build client");

        let outcome = timeout(
            Duration::from_secs(5),
            client
                .get(format!("https://127.0.0.1:{}/test", server_addr.port()))
                .send(),
        )
        .await;

        assert!(
            !matches!(outcome, Ok(Ok(_))),
            "Request should fail without client certificate"
        );
        // The failure must come from the TLS layer, not from a dead server.
        assert!(!server_task.is_finished(), "Server stopped unexpectedly");

        server_task.abort();
    }

    #[tokio::test]
    async fn test_mtls_handshake_failure_wrong_ca() {
        init_crypto_provider();
        let cert_chain1 =
            CertificateChain::generate().expect("Failed to generate certificate chain 1");
        let cert_chain2 =
            CertificateChain::generate().expect("Failed to generate certificate chain 2");

        let server_config = cert_chain1
            .create_server_config()
            .expect("Failed to create server config");
        let client_config = cert_chain2
            .create_client_config()
            .expect("Failed to create client config");

        let (server_addr, server_task) = spawn_tls_server(status_app(), server_config).await;

        let client = Client::builder()
            .use_preconfigured_tls(client_config)
            .build()
            .expect("Failed to build client");

        let outcome = timeout(
            Duration::from_secs(5),
            client
                .get(format!("https://127.0.0.1:{}/test", server_addr.port()))
                .send(),
        )
        .await;

        assert!(
            !matches!(outcome, Ok(Ok(_))),
            "Request should fail with wrong CA certificates"
        );
        // The failure must come from the TLS layer, not from a dead server.
        assert!(!server_task.is_finished(), "Server stopped unexpectedly");

        server_task.abort();
    }

    #[tokio::test]
    async fn test_certificate_chain_generation() {
        init_crypto_provider();
        let cert_chain =
            CertificateChain::generate().expect("Failed to generate certificate chain");

        let ca_pem = cert_chain.get_ca_cert_pem().expect("Failed to get CA PEM");
        let server_pem = cert_chain
            .get_server_cert_pem()
            .expect("Failed to get server PEM");
        let client_pem = cert_chain
            .get_client_cert_pem()
            .expect("Failed to get client PEM");
        let server_key_pem = cert_chain
            .get_server_key_pem()
            .expect("Failed to get server key PEM");
        let client_key_pem = cert_chain
            .get_client_key_pem()
            .expect("Failed to get client key PEM");

        for pem in [&ca_pem, &server_pem, &client_pem] {
            assert!(pem.contains("-----BEGIN CERTIFICATE-----"));
            assert!(pem.contains("-----END CERTIFICATE-----"));
        }
        for pem in [&server_key_pem, &client_key_pem] {
            assert!(pem.contains("-----BEGIN PRIVATE KEY-----"));
            assert!(pem.contains("-----END PRIVATE KEY-----"));
        }
    }

    #[tokio::test]
    async fn test_server_config_creation() {
        init_crypto_provider();
        let cert_chain =
            CertificateChain::generate().expect("Failed to generate certificate chain");
        cert_chain
            .create_server_config()
            .expect("Failed to create server config");
    }

    #[tokio::test]
    async fn test_client_config_creation() {
        init_crypto_provider();
        let cert_chain =
            CertificateChain::generate().expect("Failed to generate certificate chain");
        cert_chain
            .create_client_config()
            .expect("Failed to create client config");
    }

    #[tokio::test]
    async fn test_mtls_multiple_requests() {
        init_crypto_provider();
        let cert_chain =
            CertificateChain::generate().expect("Failed to generate certificate chain");
        let server_config = cert_chain
            .create_server_config()
            .expect("Failed to create server config");
        let client_config = cert_chain
            .create_client_config()
            .expect("Failed to create client config");

        let app = Router::new()
            .route(
                "/test1",
                get(|| async { Json(json!({"endpoint": "test1"})) }),
            )
            .route(
                "/test2",
                get(|| async { Json(json!({"endpoint": "test2"})) }),
            )
            .route(
                "/test3",
                get(|| async { Json(json!({"endpoint": "test3"})) }),
            );

        let (server_addr, server_task) = spawn_tls_server(app, server_config).await;

        let client = Client::builder()
            .use_preconfigured_tls(client_config)
            .build()
            .expect("Failed to build client");

        for endpoint in ["test1", "test2", "test3"] {
            let response = timeout(
                Duration::from_secs(10),
                client
                    .get(format!(
                        "https://127.0.0.1:{}/{}",
                        server_addr.port(),
                        endpoint
                    ))
                    .send(),
            )
            .await
            .expect("Request timed out")
            .expect("Failed to send request");

            assert!(
                response.status().is_success(),
                "Request to {} should succeed",
                endpoint
            );

            let body: serde_json::Value = response.json().await.expect("Failed to parse JSON");
            assert_eq!(body["endpoint"], endpoint);
        }

        server_task.abort();
    }

    #[tokio::test]
    async fn test_mtls_configuration_compatibility() {
        init_crypto_provider();
        let cert_chain =
            CertificateChain::generate().expect("Failed to generate certificate chain");
        let server_config = cert_chain
            .create_server_config()
            .expect("Failed to create server config");
        let client_config = cert_chain
            .create_client_config()
            .expect("Failed to create client config");

        let _client = Client::builder()
            .use_preconfigured_tls(client_config)
            .build()
            .expect("Failed to build reqwest client with mTLS config");
        let _rustls_config = RustlsConfig::from_config(Arc::new(server_config));

        assert!(cert_chain.get_ca_cert_pem().is_ok());
        assert!(cert_chain.get_server_cert_pem().is_ok());
        assert!(cert_chain.get_client_cert_pem().is_ok());
        assert!(cert_chain.get_server_key_pem().is_ok());
        assert!(cert_chain.get_client_key_pem().is_ok());
    }

    #[tokio::test]
    async fn test_mtls_certificate_validation() {
        init_crypto_provider();
        let cert_chain1 =
            CertificateChain::generate().expect("Failed to generate certificate chain 1");
        let cert_chain2 =
            CertificateChain::generate().expect("Failed to generate certificate chain 2");

        let server_config1 = cert_chain1
            .create_server_config()
            .expect("Failed to create server config 1");
        let client_config2 = cert_chain2
            .create_client_config()
            .expect("Failed to create client config 2");

        // Configs from independent chains must still build. Their incompatibility
        // at handshake time is covered by `test_mtls_handshake_failure_wrong_ca`.
        let _client = Client::builder()
            .use_preconfigured_tls(client_config2)
            .build()
            .expect("Failed to build client with different CA");
        let _rustls_config = RustlsConfig::from_config(Arc::new(server_config1));
    }
}