Skip to main content

uselesskey_rustls/
config.rs

1//! Convenience builders for `rustls::ServerConfig` and `rustls::ClientConfig`.
2
3use std::sync::Arc;
4
5use rustls::crypto::CryptoProvider;
6
7#[cfg(feature = "x509")]
8use crate::RustlsCertExt;
9#[cfg(feature = "x509")]
10use crate::RustlsChainExt;
11#[cfg(feature = "server-config")]
12use crate::RustlsPrivateKeyExt;
13
14// ---------------------------------------------------------------------------
15// ServerConfig
16// ---------------------------------------------------------------------------
17
/// Extension trait that builds a `rustls::ServerConfig` from uselesskey fixtures.
///
/// The resulting config presents the fixture's certificate(s) and private key
/// and does not request client certificates. For mutual TLS see
/// `RustlsMtlsExt`.
///
/// # Panics
///
/// Implementations in this crate panic (via `expect`) if rustls rejects the
/// certificate/key pair. Fixture data is expected to be well-formed, so such
/// a failure is treated as a bug rather than returned as an error.
#[cfg(feature = "server-config")]
pub trait RustlsServerConfigExt {
    /// Build a `ServerConfig` using the process-default `CryptoProvider`.
    ///
    /// This goes through `rustls::ServerConfig::builder()`. Rustls panics there
    /// if no process-default provider has been installed and it cannot infer
    /// one from its enabled crate features. Install a default first with
    /// `CryptoProvider::install_default`, or use
    /// [`Self::server_config_rustls_with_provider`].
    fn server_config_rustls(&self) -> rustls::ServerConfig;

    /// Build a `ServerConfig` with an explicit `CryptoProvider`.
    ///
    /// The config uses rustls's safe default protocol versions.
    ///
    /// # Panics
    ///
    /// Panics if `provider` does not support the safe default protocol
    /// versions, or if the certificate/key pair is rejected.
    fn server_config_rustls_with_provider(
        &self,
        provider: Arc<CryptoProvider>,
    ) -> rustls::ServerConfig;
}
30
#[cfg(all(feature = "x509", feature = "server-config"))]
impl RustlsServerConfigExt for uselesskey_x509::X509Chain {
    /// Serves the certificates returned by `chain_der_rustls()` together with
    /// the chain's private key, without client authentication.
    fn server_config_rustls(&self) -> rustls::ServerConfig {
        let key = self.private_key_der_rustls();
        let certs = self.chain_der_rustls();
        let builder = rustls::ServerConfig::builder().with_no_client_auth();
        builder
            .with_single_cert(certs, key)
            .expect("valid server config")
    }

    /// Same as `server_config_rustls`, but the handshake runs on `provider`
    /// with rustls's safe default protocol versions.
    fn server_config_rustls_with_provider(
        &self,
        provider: Arc<CryptoProvider>,
    ) -> rustls::ServerConfig {
        let key = self.private_key_der_rustls();
        let certs = self.chain_der_rustls();
        let versioned = rustls::ServerConfig::builder_with_provider(provider)
            .with_safe_default_protocol_versions()
            .expect("valid protocol versions");
        versioned
            .with_no_client_auth()
            .with_single_cert(certs, key)
            .expect("valid server config")
    }
}
56
#[cfg(all(feature = "x509", feature = "server-config"))]
impl RustlsServerConfigExt for uselesskey_x509::X509Cert {
    /// Serves this single certificate (no intermediates) with its private
    /// key, without client authentication.
    fn server_config_rustls(&self) -> rustls::ServerConfig {
        let key = self.private_key_der_rustls();
        let certs = Vec::from([self.certificate_der_rustls()]);
        let builder = rustls::ServerConfig::builder().with_no_client_auth();
        builder
            .with_single_cert(certs, key)
            .expect("valid server config")
    }

    /// Same as `server_config_rustls`, but the handshake runs on `provider`
    /// with rustls's safe default protocol versions.
    fn server_config_rustls_with_provider(
        &self,
        provider: Arc<CryptoProvider>,
    ) -> rustls::ServerConfig {
        let key = self.private_key_der_rustls();
        let certs = Vec::from([self.certificate_der_rustls()]);
        let versioned = rustls::ServerConfig::builder_with_provider(provider)
            .with_safe_default_protocol_versions()
            .expect("valid protocol versions");
        versioned
            .with_no_client_auth()
            .with_single_cert(certs, key)
            .expect("valid server config")
    }
}
82
83// ---------------------------------------------------------------------------
84// ClientConfig
85// ---------------------------------------------------------------------------
86
/// Extension trait that builds a `rustls::ClientConfig` from uselesskey fixtures.
///
/// The resulting config trusts only the fixture's root certificate and does
/// not present a client certificate. For mutual TLS see `RustlsMtlsExt`.
///
/// # Panics
///
/// Implementations in this crate panic (via `expect`) if rustls rejects the
/// root certificate. Fixture data is expected to be well-formed, so such a
/// failure is treated as a bug rather than returned as an error.
#[cfg(feature = "client-config")]
pub trait RustlsClientConfigExt {
    /// Build a `ClientConfig` that trusts the root CA, with no client certificate.
    ///
    /// For a chain the trust anchor is the chain's root CA; for a single
    /// certificate it is that certificate itself.
    ///
    /// This goes through `rustls::ClientConfig::builder()`. Rustls panics there
    /// if no process-default `CryptoProvider` has been installed and it cannot
    /// infer one from its enabled crate features. Use
    /// [`Self::client_config_rustls_with_provider`] to avoid that dependency.
    fn client_config_rustls(&self) -> rustls::ClientConfig;

    /// Build a `ClientConfig` with an explicit `CryptoProvider`.
    ///
    /// The trust setup matches [`Self::client_config_rustls`]. The handshake
    /// uses `provider` and rustls's safe default protocol versions.
    ///
    /// # Panics
    ///
    /// Panics if `provider` does not support the safe default protocol
    /// versions, or if the root certificate is rejected.
    fn client_config_rustls_with_provider(
        &self,
        provider: Arc<CryptoProvider>,
    ) -> rustls::ClientConfig;
}
99
#[cfg(all(feature = "x509", feature = "client-config"))]
impl RustlsClientConfigExt for uselesskey_x509::X509Chain {
    /// Trusts exactly one anchor, the chain's root CA, and sends no client
    /// certificate.
    fn client_config_rustls(&self) -> rustls::ClientConfig {
        let mut trusted = rustls::RootCertStore::empty();
        let anchor = self.root_certificate_der_rustls();
        trusted.add(anchor).expect("valid root cert");
        let builder = rustls::ClientConfig::builder();
        builder.with_root_certificates(trusted).with_no_client_auth()
    }

    /// Same trust setup as `client_config_rustls`, but the handshake runs on
    /// `provider` with rustls's safe default protocol versions.
    fn client_config_rustls_with_provider(
        &self,
        provider: Arc<CryptoProvider>,
    ) -> rustls::ClientConfig {
        let mut trusted = rustls::RootCertStore::empty();
        let anchor = self.root_certificate_der_rustls();
        trusted.add(anchor).expect("valid root cert");
        let builder = rustls::ClientConfig::builder_with_provider(provider)
            .with_safe_default_protocol_versions()
            .expect("valid protocol versions");
        builder.with_root_certificates(trusted).with_no_client_auth()
    }
}
127
#[cfg(all(feature = "x509", feature = "client-config"))]
impl RustlsClientConfigExt for uselesskey_x509::X509Cert {
    /// Trusts this certificate itself as the only anchor and sends no client
    /// certificate.
    fn client_config_rustls(&self) -> rustls::ClientConfig {
        let mut trusted = rustls::RootCertStore::empty();
        let anchor = self.certificate_der_rustls();
        trusted.add(anchor).expect("valid root cert");
        let builder = rustls::ClientConfig::builder();
        builder.with_root_certificates(trusted).with_no_client_auth()
    }

    /// Same trust setup as `client_config_rustls`, but the handshake runs on
    /// `provider` with rustls's safe default protocol versions.
    fn client_config_rustls_with_provider(
        &self,
        provider: Arc<CryptoProvider>,
    ) -> rustls::ClientConfig {
        let mut trusted = rustls::RootCertStore::empty();
        let anchor = self.certificate_der_rustls();
        trusted.add(anchor).expect("valid root cert");
        let builder = rustls::ClientConfig::builder_with_provider(provider)
            .with_safe_default_protocol_versions()
            .expect("valid protocol versions");
        builder.with_root_certificates(trusted).with_no_client_auth()
    }
}
155
156// ---------------------------------------------------------------------------
157// mTLS
158// ---------------------------------------------------------------------------
159
/// Extension trait for mutual TLS configurations.
///
/// Both sides trust the chain's root CA. The server requires the client to
/// present a certificate that verifies against that root, and the client
/// presents the fixture's own certificate and key.
///
/// # Panics
///
/// Implementations in this crate panic (via `expect`) if rustls rejects the
/// root certificate, the client-certificate verifier configuration, or the
/// certificate/key pair. Fixture data is expected to be well-formed, so such
/// a failure is treated as a bug rather than returned as an error.
#[cfg(all(feature = "server-config", feature = "client-config"))]
pub trait RustlsMtlsExt {
    /// Build a `ServerConfig` that requires client certificates verified against
    /// the chain's root CA.
    ///
    /// Uses the process-default `CryptoProvider`. Rustls panics if none is
    /// installed and it cannot infer one from its enabled crate features.
    fn server_config_mtls_rustls(&self) -> rustls::ServerConfig;

    /// Build a `ServerConfig` for mTLS with an explicit `CryptoProvider`.
    ///
    /// # Panics
    ///
    /// Also panics if `provider` does not support rustls's safe default
    /// protocol versions.
    fn server_config_mtls_rustls_with_provider(
        &self,
        provider: Arc<CryptoProvider>,
    ) -> rustls::ServerConfig;

    /// Build a `ClientConfig` that presents the leaf certificate as a client
    /// certificate and trusts the root CA.
    ///
    /// Uses the process-default `CryptoProvider`. Rustls panics if none is
    /// installed and it cannot infer one from its enabled crate features.
    fn client_config_mtls_rustls(&self) -> rustls::ClientConfig;

    /// Build a `ClientConfig` for mTLS with an explicit `CryptoProvider`.
    ///
    /// # Panics
    ///
    /// Also panics if `provider` does not support rustls's safe default
    /// protocol versions.
    fn client_config_mtls_rustls_with_provider(
        &self,
        provider: Arc<CryptoProvider>,
    ) -> rustls::ClientConfig;
}
183
#[cfg(all(feature = "x509", feature = "server-config", feature = "client-config"))]
impl RustlsMtlsExt for uselesskey_x509::X509Chain {
    /// Requires client certificates that verify against this chain's root CA.
    ///
    /// Both the handshake and client-certificate verification use the
    /// process-default `CryptoProvider`.
    fn server_config_mtls_rustls(&self) -> rustls::ServerConfig {
        let mut root_store = rustls::RootCertStore::empty();
        root_store
            .add(self.root_certificate_der_rustls())
            .expect("valid root cert");

        // `builder` resolves the process-default provider, matching
        // `ServerConfig::builder()` below.
        let client_verifier = rustls::server::WebPkiClientVerifier::builder(root_store.into())
            .build()
            .expect("valid client verifier");

        let private_key = self.private_key_der_rustls();
        let cert_chain = self.chain_der_rustls();

        rustls::ServerConfig::builder()
            .with_client_cert_verifier(client_verifier)
            .with_single_cert(cert_chain, private_key)
            .expect("valid mTLS server config")
    }

    /// Requires client certificates that verify against this chain's root CA.
    ///
    /// Both the handshake and client-certificate verification use `provider`;
    /// the process-default provider is never consulted.
    fn server_config_mtls_rustls_with_provider(
        &self,
        provider: Arc<CryptoProvider>,
    ) -> rustls::ServerConfig {
        let mut root_store = rustls::RootCertStore::empty();
        root_store
            .add(self.root_certificate_der_rustls())
            .expect("valid root cert");

        // Verify client signatures with the caller's provider as well.
        // `WebPkiClientVerifier::builder` would resolve the process-default
        // provider instead. That panics when none is installed and rustls
        // cannot infer one, and otherwise silently mixes providers.
        let client_verifier = rustls::server::WebPkiClientVerifier::builder_with_provider(
            root_store.into(),
            Arc::clone(&provider),
        )
        .build()
        .expect("valid client verifier");

        let private_key = self.private_key_der_rustls();
        let cert_chain = self.chain_der_rustls();

        rustls::ServerConfig::builder_with_provider(provider)
            .with_safe_default_protocol_versions()
            .expect("valid protocol versions")
            .with_client_cert_verifier(client_verifier)
            .with_single_cert(cert_chain, private_key)
            .expect("valid mTLS server config")
    }

    /// Trusts the chain's root CA and presents `chain_der_rustls()` plus the
    /// chain's private key as the client certificate, using the
    /// process-default `CryptoProvider`.
    fn client_config_mtls_rustls(&self) -> rustls::ClientConfig {
        let mut root_store = rustls::RootCertStore::empty();
        root_store
            .add(self.root_certificate_der_rustls())
            .expect("valid root cert");

        let private_key = self.private_key_der_rustls();
        let cert_chain = self.chain_der_rustls();

        rustls::ClientConfig::builder()
            .with_root_certificates(root_store)
            .with_client_auth_cert(cert_chain, private_key)
            .expect("valid mTLS client config")
    }

    /// Same as `client_config_mtls_rustls`, but the handshake and server
    /// certificate verification run on `provider`.
    fn client_config_mtls_rustls_with_provider(
        &self,
        provider: Arc<CryptoProvider>,
    ) -> rustls::ClientConfig {
        let mut root_store = rustls::RootCertStore::empty();
        root_store
            .add(self.root_certificate_der_rustls())
            .expect("valid root cert");

        let private_key = self.private_key_der_rustls();
        let cert_chain = self.chain_der_rustls();

        rustls::ClientConfig::builder_with_provider(provider)
            .with_safe_default_protocol_versions()
            .expect("valid protocol versions")
            .with_root_certificates(root_store)
            .with_client_auth_cert(cert_chain, private_key)
            .expect("valid mTLS client config")
    }
}
264
265// ---------------------------------------------------------------------------
266// Tests
267// ---------------------------------------------------------------------------
268
#[cfg(test)]
#[cfg(all(feature = "x509", feature = "server-config", feature = "client-config"))]
mod tests {
    use super::*;
    use uselesskey_x509::{ChainSpec, X509FactoryExt, X509Spec};

    use std::sync::Once;
    static INIT: Once = Once::new();

    /// Install ring as the process-default `CryptoProvider`, at most once per
    /// test process.
    fn install_provider() {
        INIT.call_once(|| {
            // When both `rustls-ring` and `rustls-aws-lc-rs` features are
            // enabled (e.g. via `--all-features`), another provider may
            // already be set as process-default. Ignore the error — the
            // explicit-provider tests cover the critical paths.
            let _ = rustls::crypto::ring::default_provider().install_default();
        });
    }

    /// A fresh ring `CryptoProvider` for the explicit-provider APIs.
    fn ring_provider() -> Arc<CryptoProvider> {
        Arc::new(rustls::crypto::ring::default_provider())
    }

    // Maximum client<->server exchange rounds before a handshake is declared
    // stuck. A normal TLS handshake completes in well under 10 rounds.
    const MAX_HANDSHAKE_ITERATIONS: usize = 10;

    /// Create an in-memory client/server pair for `test.example.com`.
    fn connection_pair(
        server_config: rustls::ServerConfig,
        client_config: rustls::ClientConfig,
    ) -> (rustls::ClientConnection, rustls::ServerConnection) {
        let server_name: rustls::pki_types::ServerName<'static> =
            "test.example.com".try_into().unwrap();
        let server = rustls::ServerConnection::new(Arc::new(server_config)).unwrap();
        let client = rustls::ClientConnection::new(Arc::new(client_config), server_name).unwrap();
        (client, server)
    }

    /// Drive an in-memory handshake by shuttling TLS records between `client`
    /// and `server` until neither side has anything left to send.
    ///
    /// Panics if either side reports an error, stops accepting bytes, or is
    /// still handshaking after `MAX_HANDSHAKE_ITERATIONS` rounds.
    fn complete_handshake(
        client: &mut rustls::ClientConnection,
        server: &mut rustls::ServerConnection,
    ) {
        let mut buf = Vec::new();
        for _ in 0..MAX_HANDSHAKE_ITERATIONS {
            let mut progress = false;

            // client -> server
            buf.clear();
            if client.wants_write() {
                client.write_tls(&mut buf).unwrap();
                // A single `read_tls` call may consume only part of the slice,
                // so keep feeding until every byte has been delivered. A large
                // flight (e.g. a long certificate chain) would otherwise be
                // silently truncated.
                let mut pending = &buf[..];
                while !pending.is_empty() {
                    let read = server.read_tls(&mut pending).unwrap();
                    assert_ne!(read, 0, "server stopped accepting TLS bytes");
                    server.process_new_packets().unwrap();
                    progress = true;
                }
            }

            // server -> client
            buf.clear();
            if server.wants_write() {
                server.write_tls(&mut buf).unwrap();
                let mut pending = &buf[..];
                while !pending.is_empty() {
                    let read = client.read_tls(&mut pending).unwrap();
                    assert_ne!(read, 0, "client stopped accepting TLS bytes");
                    client.process_new_packets().unwrap();
                    progress = true;
                }
            }

            if !progress {
                break;
            }
        }

        assert!(
            !client.is_handshaking(),
            "client still handshaking after {} iterations",
            MAX_HANDSHAKE_ITERATIONS
        );
        assert!(
            !server.is_handshaking(),
            "server still handshaking after {} iterations",
            MAX_HANDSHAKE_ITERATIONS
        );
    }

    #[test]
    fn server_config_from_chain() {
        install_provider();
        let fx = super::super::testutil::fx();
        let chain = fx.x509_chain("test", ChainSpec::new("test.example.com"));
        // Succeeds without panic = config was built with valid cert/key
        let _cfg = chain.server_config_rustls();
    }

    #[test]
    fn server_config_from_chain_with_provider() {
        install_provider();
        let fx = super::super::testutil::fx();
        let chain = fx.x509_chain("test-provider", ChainSpec::new("test.example.com"));
        let _cfg = chain.server_config_rustls_with_provider(ring_provider());
    }

    #[test]
    fn client_config_from_chain() {
        install_provider();
        let fx = super::super::testutil::fx();
        let chain = fx.x509_chain("test", ChainSpec::new("test.example.com"));
        let _cfg = chain.client_config_rustls();
    }

    #[test]
    fn client_config_from_chain_with_provider() {
        install_provider();
        let fx = super::super::testutil::fx();
        let chain = fx.x509_chain("test-provider", ChainSpec::new("test.example.com"));
        let _cfg = chain.client_config_rustls_with_provider(ring_provider());
    }

    #[test]
    fn server_config_from_self_signed() {
        install_provider();
        let fx = super::super::testutil::fx();
        let cert = fx.x509_self_signed("test", X509Spec::self_signed("test.example.com"));
        let _cfg = cert.server_config_rustls();
    }

    #[test]
    fn server_config_from_self_signed_with_provider() {
        install_provider();
        let fx = super::super::testutil::fx();
        let cert = fx.x509_self_signed("test-provider", X509Spec::self_signed("test.example.com"));
        let _cfg = cert.server_config_rustls_with_provider(ring_provider());
    }

    #[test]
    fn client_config_from_self_signed() {
        install_provider();
        let fx = super::super::testutil::fx();
        let cert = fx.x509_self_signed("test", X509Spec::self_signed("test.example.com"));
        let _cfg = cert.client_config_rustls();
    }

    #[test]
    fn client_config_from_self_signed_with_provider() {
        install_provider();
        let fx = super::super::testutil::fx();
        let cert = fx.x509_self_signed("test-provider", X509Spec::self_signed("test.example.com"));
        let _cfg = cert.client_config_rustls_with_provider(ring_provider());
    }

    #[test]
    fn tls_handshake_roundtrip() {
        let fx = super::super::testutil::fx();
        let chain = fx.x509_chain("tls-test", ChainSpec::new("test.example.com"));

        let provider = ring_provider();
        let (mut client, mut server) = connection_pair(
            chain.server_config_rustls_with_provider(provider.clone()),
            chain.client_config_rustls_with_provider(provider),
        );

        complete_handshake(&mut client, &mut server);
    }

    #[test]
    fn mtls_with_provider_roundtrip() {
        // Deliberately no `install_provider()`: the explicit-provider builders
        // must not depend on a process-default provider.
        let fx = super::super::testutil::fx();
        let chain = fx.x509_chain("mtls-provider-test", ChainSpec::new("test.example.com"));

        let provider = ring_provider();
        let (mut client, mut server) = connection_pair(
            chain.server_config_mtls_rustls_with_provider(provider.clone()),
            chain.client_config_mtls_rustls_with_provider(provider),
        );

        complete_handshake(&mut client, &mut server);
        assert!(
            server.peer_certificates().is_some(),
            "server did not receive a client certificate"
        );
    }

    #[test]
    fn mtls_roundtrip() {
        // Exercises the process-default-provider mTLS builders.
        install_provider();
        let fx = super::super::testutil::fx();
        let chain = fx.x509_chain("mtls-test", ChainSpec::new("test.example.com"));

        let (mut client, mut server) = connection_pair(
            chain.server_config_mtls_rustls(),
            chain.client_config_mtls_rustls(),
        );

        complete_handshake(&mut client, &mut server);
        assert!(
            server.peer_certificates().is_some(),
            "server did not receive a client certificate"
        );
    }
}