rustls_post_quantum/
lib.rs

//! This crate provides a [`CryptoProvider`] built on the aws-lc-rs default provider.
2//!
3//! Features:
4//!
//! - `aws-lc-rs-unstable`: adds support for the experimental ML-DSA signature algorithm in its
//!   three parameter sets (ML-DSA-44, ML-DSA-65 and ML-DSA-87).
7//!
//! Before rustls 0.23.22, this crate additionally provided support for the ML-KEM key exchange
//! (both "pure" and hybrid variants), but these have been moved to the rustls crate itself.
//! In rustls 0.23.22 and later, you can use rustls' `prefer-post-quantum` feature to control
//! whether the ML-KEM key exchange is preferred over non-post-quantum key exchanges. The
//! `MLKEM768` and `X25519MLKEM768` key exchange groups are still re-exported from this crate.
12
13#[cfg(feature = "aws-lc-rs-unstable")]
14use rustls::SignatureScheme;
15use rustls::crypto::CryptoProvider;
16#[cfg(feature = "aws-lc-rs-unstable")]
17use rustls::crypto::WebPkiSupportedAlgorithms;
18pub use rustls::crypto::aws_lc_rs::kx_group::{MLKEM768, X25519MLKEM768};
19#[cfg(feature = "aws-lc-rs-unstable")]
20use webpki::aws_lc_rs as webpki_algs;
21
22pub fn provider() -> CryptoProvider {
23    #[cfg_attr(not(feature = "aws-lc-rs-unstable"), allow(unused_mut))]
24    let mut provider = rustls::crypto::aws_lc_rs::default_provider();
25    #[cfg(feature = "aws-lc-rs-unstable")]
26    {
27        provider.signature_verification_algorithms = SUPPORTED_SIG_ALGS;
28        provider.key_provider = &key_provider::PqAwsLcRs;
29    }
30    provider
31}
32
#[cfg(feature = "aws-lc-rs-unstable")]
mod key_provider {
    //! ML-DSA-aware private key loading and signing, built on aws-lc-rs' unstable
    //! post-quantum signature API.

    use std::fmt::{self, Debug, Formatter};
    use std::sync::Arc;

    use aws_lc_rs::signature::KeyPair;
    use aws_lc_rs::unstable::signature::{
        ML_DSA_44_SIGNING, ML_DSA_65_SIGNING, ML_DSA_87_SIGNING, PqdsaKeyPair,
        PqdsaSigningAlgorithm,
    };
    use rustls::crypto::KeyProvider;
    use rustls::crypto::aws_lc_rs::sign;
    use rustls::pki_types::{AlgorithmIdentifier, PrivateKeyDer, SubjectPublicKeyInfoDer, alg_id};
    use rustls::sign::{Signer, SigningKey, public_key_to_spki};
    use rustls::{Error, SignatureAlgorithm, SignatureScheme};

    /// [`KeyProvider`] that recognises ML-DSA PKCS#8 private keys and otherwise defers to
    /// `rustls::crypto::aws_lc_rs::sign::any_supported_type`.
    #[derive(Debug)]
    pub(super) struct PqAwsLcRs;

    impl KeyProvider for PqAwsLcRs {
        /// Loads `key_der`, trying ML-DSA first and then RSA/ECDSA/EdDSA.
        ///
        /// ML-DSA keys are only recognised when `key_der` is PKCS#8. The ML-DSA-44, ML-DSA-65
        /// and ML-DSA-87 parameter sets are tried in that order, and the first that parses wins.
        ///
        /// # Errors
        ///
        /// Returns [`Error::General`] if the key parses neither as ML-DSA nor as any key type
        /// accepted by `sign::any_supported_type`.
        fn load_private_key(
            &self,
            key_der: PrivateKeyDer<'static>,
        ) -> Result<Arc<dyn SigningKey>, Error> {
            // TODO: support `PqdsaKeyPair::from_raw_private_key()`?
            let pkcs8_der = match &key_der {
                PrivateKeyDer::Pkcs8(pkcs8) => Some(pkcs8.secret_pkcs8_der()),
                _ => None,
            };
            if let Some(key) = pkcs8_der.and_then(PqdsaSigningKey::from_pkcs8) {
                return Ok(Arc::new(key));
            }

            // The underlying parse error is replaced with a message naming every key family
            // that was attempted.
            sign::any_supported_type(&key_der).map_err(|_| {
                Error::General(
                    "failed to parse private key as ML-DSA, RSA, ECDSA, or EdDSA".into(),
                )
            })
        }

        fn fips(&self) -> bool {
            false
        }
    }

    /// An ML-DSA key pair together with the parameter set it was parsed as.
    struct PqdsaSigningKey {
        kind: PqdsaKeyKind,
        inner: Arc<PqdsaKeyPair>,
    }

    impl PqdsaSigningKey {
        /// Parses `pkcs8_der` against each ML-DSA parameter set in turn.
        ///
        /// Returns the first successful parse, or `None` if no parameter set accepts the key.
        fn from_pkcs8(pkcs8_der: &[u8]) -> Option<Self> {
            PqdsaKeyKind::iter().find_map(|kind| {
                let key_pair = PqdsaKeyPair::from_pkcs8(kind.to_alg(), pkcs8_der).ok()?;
                Some(Self {
                    kind,
                    inner: Arc::new(key_pair),
                })
            })
        }
    }

    impl SigningKey for PqdsaSigningKey {
        /// Returns a signer only if the peer offered this key's exact ML-DSA scheme.
        fn choose_scheme(&self, offered: &[SignatureScheme]) -> Option<Box<dyn Signer>> {
            if !offered.contains(&self.kind.scheme()) {
                return None;
            }

            Some(Box::new(PqdsaSigner {
                key: Arc::clone(&self.inner),
                kind: self.kind,
            }))
        }

        fn public_key(&self) -> Option<SubjectPublicKeyInfoDer<'_>> {
            Some(public_key_to_spki(
                &self.kind.alg_id(),
                self.inner.public_key(),
            ))
        }

        // [`SignatureAlgorithm`] is for TLS 1.2, for which ML-DSA is not specified.
        // Pick a "Reserved for Private Use" value.
        fn algorithm(&self) -> SignatureAlgorithm {
            SignatureAlgorithm::Unknown(255)
        }
    }

    impl Debug for PqdsaSigningKey {
        fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
            f.debug_struct("PqdsaSigningKey")
                .field("scheme", &self.kind.scheme())
                .finish_non_exhaustive()
        }
    }

    /// Signer handed out by [`PqdsaSigningKey::choose_scheme`]; shares the key pair via `Arc`.
    struct PqdsaSigner {
        key: Arc<PqdsaKeyPair>,
        kind: PqdsaKeyKind,
    }

    impl Signer for PqdsaSigner {
        /// Signs `message`, returning a signature of the parameter set's fixed length.
        ///
        /// # Errors
        ///
        /// Returns [`Error::General`] if aws-lc fails to sign, or reports a signature length
        /// other than the one advertised by the algorithm.
        fn sign(&self, message: &[u8]) -> Result<Vec<u8>, Error> {
            let expected_sig_len = self.key.algorithm().signature_len();
            let mut sig = vec![0; expected_sig_len];
            let actual_sig_len = self
                .key
                .sign(message, &mut sig)
                .map_err(|_| Error::General("signing failed".into()))?;

            if actual_sig_len != expected_sig_len {
                return Err(Error::General("unexpected signature length".into()));
            }

            Ok(sig)
        }

        fn scheme(&self) -> SignatureScheme {
            self.kind.scheme()
        }
    }

    impl Debug for PqdsaSigner {
        fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
            f.debug_struct("PqdsaSigner")
                .field("scheme", &self.kind.scheme())
                .finish_non_exhaustive()
        }
    }

    /// The ML-DSA parameter sets this provider can load and sign with.
    #[derive(Clone, Copy, Debug)]
    enum PqdsaKeyKind {
        MlDsa44,
        MlDsa65,
        MlDsa87,
    }

    impl PqdsaKeyKind {
        /// All parameter sets, in the order `load_private_key` tries them.
        fn iter() -> impl Iterator<Item = Self> {
            [Self::MlDsa44, Self::MlDsa65, Self::MlDsa87].into_iter()
        }

        /// The aws-lc-rs signing algorithm for this parameter set.
        fn to_alg(self) -> &'static PqdsaSigningAlgorithm {
            match self {
                Self::MlDsa44 => &ML_DSA_44_SIGNING,
                Self::MlDsa65 => &ML_DSA_65_SIGNING,
                Self::MlDsa87 => &ML_DSA_87_SIGNING,
            }
        }

        /// The TLS signature scheme for this parameter set.
        fn scheme(self) -> SignatureScheme {
            match self {
                Self::MlDsa44 => SignatureScheme::ML_DSA_44,
                Self::MlDsa65 => SignatureScheme::ML_DSA_65,
                Self::MlDsa87 => SignatureScheme::ML_DSA_87,
            }
        }

        /// The X.509 algorithm identifier used when encoding the SubjectPublicKeyInfo.
        fn alg_id(self) -> AlgorithmIdentifier {
            match self {
                Self::MlDsa44 => alg_id::ML_DSA_44,
                Self::MlDsa65 => alg_id::ML_DSA_65,
                Self::MlDsa87 => alg_id::ML_DSA_87,
            }
        }
    }
}
195
/// Signature verification algorithms installed by [`provider()`]: the aws-lc-rs set plus ML-DSA.
///
/// Keep in sync with the `SUPPORTED_SIG_ALGS` in `rustls::crypto::aws_lc_rs`.
///
/// Every algorithm referenced from `mapping` is also listed in `all`. rustls consults `all` when
/// verifying certificate chains, so omitting one would accept a key/hash combination in
/// handshake signatures but reject it in certificates.
///
/// The whole static is gated on `aws-lc-rs-unstable`, so the ML-DSA entries need no cfg of their
/// own.
#[cfg(feature = "aws-lc-rs-unstable")]
static SUPPORTED_SIG_ALGS: WebPkiSupportedAlgorithms = WebPkiSupportedAlgorithms {
    all: &[
        webpki_algs::ECDSA_P256_SHA256,
        webpki_algs::ECDSA_P256_SHA384,
        webpki_algs::ECDSA_P256_SHA512,
        webpki_algs::ECDSA_P384_SHA256,
        webpki_algs::ECDSA_P384_SHA384,
        webpki_algs::ECDSA_P384_SHA512,
        webpki_algs::ECDSA_P521_SHA256,
        webpki_algs::ECDSA_P521_SHA384,
        webpki_algs::ECDSA_P521_SHA512,
        webpki_algs::ED25519,
        webpki_algs::RSA_PSS_2048_8192_SHA256_LEGACY_KEY,
        webpki_algs::RSA_PSS_2048_8192_SHA384_LEGACY_KEY,
        webpki_algs::RSA_PSS_2048_8192_SHA512_LEGACY_KEY,
        webpki_algs::RSA_PKCS1_2048_8192_SHA256,
        webpki_algs::RSA_PKCS1_2048_8192_SHA384,
        webpki_algs::RSA_PKCS1_2048_8192_SHA512,
        webpki_algs::RSA_PKCS1_2048_8192_SHA256_ABSENT_PARAMS,
        webpki_algs::RSA_PKCS1_2048_8192_SHA384_ABSENT_PARAMS,
        webpki_algs::RSA_PKCS1_2048_8192_SHA512_ABSENT_PARAMS,
        webpki_algs::ML_DSA_44,
        webpki_algs::ML_DSA_65,
        webpki_algs::ML_DSA_87,
    ],
    mapping: &[
        // Note: for TLS1.2 the curve is not fixed by SignatureScheme. For TLS1.3 it is.
        (
            SignatureScheme::ECDSA_NISTP384_SHA384,
            &[
                webpki_algs::ECDSA_P384_SHA384,
                webpki_algs::ECDSA_P256_SHA384,
                webpki_algs::ECDSA_P521_SHA384,
            ],
        ),
        (
            SignatureScheme::ECDSA_NISTP256_SHA256,
            &[
                webpki_algs::ECDSA_P256_SHA256,
                webpki_algs::ECDSA_P384_SHA256,
                webpki_algs::ECDSA_P521_SHA256,
            ],
        ),
        (
            SignatureScheme::ECDSA_NISTP521_SHA512,
            &[
                webpki_algs::ECDSA_P521_SHA512,
                webpki_algs::ECDSA_P384_SHA512,
                webpki_algs::ECDSA_P256_SHA512,
            ],
        ),
        (SignatureScheme::ED25519, &[webpki_algs::ED25519]),
        (
            SignatureScheme::RSA_PSS_SHA512,
            &[webpki_algs::RSA_PSS_2048_8192_SHA512_LEGACY_KEY],
        ),
        (
            SignatureScheme::RSA_PSS_SHA384,
            &[webpki_algs::RSA_PSS_2048_8192_SHA384_LEGACY_KEY],
        ),
        (
            SignatureScheme::RSA_PSS_SHA256,
            &[webpki_algs::RSA_PSS_2048_8192_SHA256_LEGACY_KEY],
        ),
        (
            SignatureScheme::RSA_PKCS1_SHA512,
            &[webpki_algs::RSA_PKCS1_2048_8192_SHA512],
        ),
        (
            SignatureScheme::RSA_PKCS1_SHA384,
            &[webpki_algs::RSA_PKCS1_2048_8192_SHA384],
        ),
        (
            SignatureScheme::RSA_PKCS1_SHA256,
            &[webpki_algs::RSA_PKCS1_2048_8192_SHA256],
        ),
        (SignatureScheme::ML_DSA_44, &[webpki_algs::ML_DSA_44]),
        (SignatureScheme::ML_DSA_65, &[webpki_algs::ML_DSA_65]),
        (SignatureScheme::ML_DSA_87, &[webpki_algs::ML_DSA_87]),
    ],
};
283
#[cfg(all(test, feature = "aws-lc-rs-unstable"))]
mod tests {
    use std::io;
    use std::ops::DerefMut;
    use std::sync::Arc;

    use rcgen::{
        CertificateParams, CertifiedIssuer, ExtendedKeyUsagePurpose, IsCa, KeyPair, KeyUsagePurpose,
    };
    use rustls::pki_types::PrivateKeyDer;
    use rustls::{
        ClientConfig, ClientConnection, ConnectionCommon, RootCertStore, ServerConfig,
        ServerConnection, SideData,
    };

    /// Full in-memory handshake with `provider()` on both sides.
    ///
    /// The root CA uses ML-DSA-44 and the server's end-entity certificate uses ML-DSA-87.
    #[test]
    fn ml_dsa() {
        // Self-signed ML-DSA-44 root.
        let root_key = KeyPair::generate_for(&rcgen::PKCS_ML_DSA_44).unwrap();
        let mut root_params = CertificateParams::new(vec!["Test CA".into()]).unwrap();
        root_params.is_ca = IsCa::Ca(rcgen::BasicConstraints::Unconstrained);
        root_params.key_usages = vec![
            KeyUsagePurpose::DigitalSignature,
            KeyUsagePurpose::KeyCertSign,
        ];
        root_params.extended_key_usages = vec![ExtendedKeyUsagePurpose::ServerAuth];
        let issuer = CertifiedIssuer::self_signed(root_params, root_key).unwrap();

        // ML-DSA-87 leaf for "localhost", signed by the root above.
        let leaf_key = KeyPair::generate_for(&rcgen::PKCS_ML_DSA_87).unwrap();
        let leaf_cert = CertificateParams::new(vec!["localhost".into()])
            .unwrap()
            .signed_by(&leaf_key, &issuer)
            .unwrap();

        let shared_provider = Arc::new(super::provider());

        let server_config = ServerConfig::builder_with_provider(Arc::clone(&shared_provider))
            .with_safe_default_protocol_versions()
            .unwrap()
            .with_no_client_auth()
            .with_single_cert(
                vec![leaf_cert.der().clone()],
                PrivateKeyDer::try_from(leaf_key.serialize_der()).unwrap(),
            )
            .unwrap();

        let mut trust_anchors = RootCertStore::empty();
        trust_anchors.add(issuer.der().clone()).unwrap();
        let client_config = ClientConfig::builder_with_provider(shared_provider)
            .with_safe_default_protocol_versions()
            .unwrap()
            .with_root_certificates(trust_anchors)
            .with_no_client_auth();

        let mut client =
            ClientConnection::new(Arc::new(client_config), "localhost".try_into().unwrap())
                .unwrap();
        let mut server = ServerConnection::new(Arc::new(server_config)).unwrap();
        do_handshake(&mut client, &mut server);
    }

    /// Shuttles TLS records between `client` and `server` until neither is still handshaking.
    ///
    /// Returns `(bytes sent client -> server, bytes sent server -> client)`.
    ///
    /// Copied from rustls while rustls-post-quantum depends on an older rustls.
    fn do_handshake(
        client: &mut impl DerefMut<Target = ConnectionCommon<impl SideData>>,
        server: &mut impl DerefMut<Target = ConnectionCommon<impl SideData>>,
    ) -> (usize, usize) {
        let mut client_to_server = 0;
        let mut server_to_client = 0;
        loop {
            if !server.is_handshaking() && !client.is_handshaking() {
                break;
            }
            client_to_server += transfer(client, server);
            server.process_new_packets().unwrap();
            server_to_client += transfer(server, client);
            client.process_new_packets().unwrap();
        }
        (client_to_server, server_to_client)
    }

    /// Drains everything `left` wants to write and feeds it into `right`.
    ///
    /// Returns the total number of bytes moved.
    ///
    /// Copied from rustls-test while rustls-post-quantum depends on an older rustls.
    fn transfer(
        left: &mut impl DerefMut<Target = ConnectionCommon<impl SideData>>,
        right: &mut impl DerefMut<Target = ConnectionCommon<impl SideData>>,
    ) -> usize {
        let mut scratch = [0u8; 256 * 1024];
        let mut moved = 0;

        while left.wants_write() {
            let written = {
                let sink: &mut dyn io::Write = &mut &mut scratch[..];
                left.write_tls(sink).unwrap()
            };
            moved += written;
            // `left` produced nothing despite `wants_write()`: stop instead of retrying.
            if written == 0 {
                break;
            }

            // `read_tls` may accept only part of the chunk, so keep feeding the remainder.
            let mut offset = 0;
            while offset < written {
                let source: &mut dyn io::Read = &mut &scratch[offset..written];
                offset += right.read_tls(source).unwrap();
            }
        }

        moved
    }
}