1#[cfg(feature = "aws-lc-rs-unstable")]
14use rustls::SignatureScheme;
15use rustls::crypto::CryptoProvider;
16#[cfg(feature = "aws-lc-rs-unstable")]
17use rustls::crypto::WebPkiSupportedAlgorithms;
18pub use rustls::crypto::aws_lc_rs::kx_group::{MLKEM768, X25519MLKEM768};
19#[cfg(feature = "aws-lc-rs-unstable")]
20use webpki::aws_lc_rs as webpki_algs;
21
22pub fn provider() -> CryptoProvider {
23 #[cfg_attr(not(feature = "aws-lc-rs-unstable"), allow(unused_mut))]
24 let mut provider = rustls::crypto::aws_lc_rs::default_provider();
25 #[cfg(feature = "aws-lc-rs-unstable")]
26 {
27 provider.signature_verification_algorithms = SUPPORTED_SIG_ALGS;
28 provider.key_provider = &key_provider::PqAwsLcRs;
29 }
30 provider
31}
32
#[cfg(feature = "aws-lc-rs-unstable")]
mod key_provider {
    //! A [`KeyProvider`] that loads ML-DSA PKCS#8 private keys (via aws-lc-rs's
    //! `unstable` API) and otherwise defers to the stock aws-lc-rs key loader.

    use std::fmt::{self, Debug, Formatter};
    use std::sync::Arc;

    use aws_lc_rs::signature::KeyPair;
    use aws_lc_rs::unstable::signature::{
        ML_DSA_44_SIGNING, ML_DSA_65_SIGNING, ML_DSA_87_SIGNING, PqdsaKeyPair,
        PqdsaSigningAlgorithm,
    };
    use rustls::crypto::KeyProvider;
    use rustls::crypto::aws_lc_rs::sign;
    use rustls::pki_types::{AlgorithmIdentifier, PrivateKeyDer, SubjectPublicKeyInfoDer, alg_id};
    use rustls::sign::{Signer, SigningKey, public_key_to_spki};
    use rustls::{Error, SignatureAlgorithm, SignatureScheme};

    /// Key provider that tries every ML-DSA parameter set first and then falls
    /// back to [`sign::any_supported_type`] for RSA, ECDSA and EdDSA keys.
    #[derive(Debug)]
    pub(super) struct PqAwsLcRs;

    impl KeyProvider for PqAwsLcRs {
        /// Loads `key_der` as an ML-DSA signing key when it is PKCS#8 and parses
        /// under one of the supported parameter sets; otherwise hands it to
        /// [`sign::any_supported_type`].
        ///
        /// # Errors
        ///
        /// Returns [`Error::General`] when neither path can parse the key.
        fn load_private_key(
            &self,
            key_der: PrivateKeyDer<'static>,
        ) -> Result<Arc<dyn SigningKey>, Error> {
            if let PrivateKeyDer::Pkcs8(pkcs8) = &key_der {
                let der = pkcs8.secret_pkcs8_der();
                // Try each parameter set in turn; a parse failure is treated as
                // "not this parameter set" and the next one is tried.
                let found = PqdsaKeyKind::iter().find_map(|kind| {
                    PqdsaKeyPair::from_pkcs8(kind.to_alg(), der)
                        .ok()
                        .map(|key_pair| (kind, key_pair))
                });
                if let Some((kind, key_pair)) = found {
                    return Ok(Arc::new(PqdsaSigningKey {
                        kind,
                        inner: Arc::new(key_pair),
                    }));
                }
            }

            sign::any_supported_type(&key_der).map_err(|_| {
                Error::General(
                    "failed to parse private key as ML-DSA, RSA, ECDSA, or EdDSA".into(),
                )
            })
        }

        /// Always reports non-FIPS; the ML-DSA path relies on aws-lc-rs's
        /// `unstable` module.
        fn fips(&self) -> bool {
            false
        }
    }

    /// An ML-DSA private key together with the parameter set it was parsed as.
    struct PqdsaSigningKey {
        kind: PqdsaKeyKind,
        inner: Arc<PqdsaKeyPair>,
    }

    impl SigningKey for PqdsaSigningKey {
        /// Returns a signer only if `offered` contains this key's exact ML-DSA
        /// scheme; ML-DSA keys cannot sign under any other scheme.
        fn choose_scheme(&self, offered: &[SignatureScheme]) -> Option<Box<dyn Signer>> {
            if !offered.contains(&self.kind.scheme()) {
                return None;
            }

            Some(Box::new(PqdsaSigner {
                key: Arc::clone(&self.inner),
                kind: self.kind,
            }))
        }

        /// Wraps the raw ML-DSA public key in a SubjectPublicKeyInfo carrying the
        /// algorithm identifier of this key's parameter set.
        fn public_key(&self) -> Option<SubjectPublicKeyInfoDer<'_>> {
            Some(public_key_to_spki(
                &self.kind.alg_id(),
                self.inner.public_key(),
            ))
        }

        /// Placeholder value: `SignatureAlgorithm` has no ML-DSA variant, so
        /// `Unknown(255)` is returned. NOTE(review): presumably this is intended
        /// never to match a concrete (TLS 1.2-era) algorithm — confirm against how
        /// rustls consults `SigningKey::algorithm`.
        fn algorithm(&self) -> SignatureAlgorithm {
            SignatureAlgorithm::Unknown(255)
        }
    }

    impl Debug for PqdsaSigningKey {
        fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
            f.debug_struct("PqdsaSigningKey")
                .field("scheme", &self.kind.scheme())
                .finish_non_exhaustive()
        }
    }

    /// Signer handed out by [`PqdsaSigningKey::choose_scheme`]; shares the key pair.
    struct PqdsaSigner {
        key: Arc<PqdsaKeyPair>,
        kind: PqdsaKeyKind,
    }

    impl Signer for PqdsaSigner {
        /// Signs `message` into a buffer sized to the algorithm's signature
        /// length.
        ///
        /// # Errors
        ///
        /// Returns [`Error::General`] if aws-lc-rs fails to sign, or if it reports
        /// writing a different number of bytes than the expected length.
        fn sign(&self, message: &[u8]) -> Result<Vec<u8>, Error> {
            let expected_sig_len = self.key.algorithm().signature_len();
            let mut sig = vec![0; expected_sig_len];
            let actual_sig_len = self
                .key
                .sign(message, &mut sig)
                .map_err(|_| Error::General("signing failed".into()))?;

            // A short write would leave trailing zero bytes in `sig`; reject it
            // rather than emit a malformed signature.
            if actual_sig_len != expected_sig_len {
                return Err(Error::General("unexpected signature length".into()));
            }

            Ok(sig)
        }

        fn scheme(&self) -> SignatureScheme {
            self.kind.scheme()
        }
    }

    impl Debug for PqdsaSigner {
        fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
            f.debug_struct("PqdsaSigner")
                .field("scheme", &self.kind.scheme())
                .finish_non_exhaustive()
        }
    }

    /// The ML-DSA parameter sets this provider can load and sign with.
    #[derive(Clone, Copy, Debug, PartialEq, Eq)]
    enum PqdsaKeyKind {
        MlDsa44,
        MlDsa65,
        MlDsa87,
    }

    impl PqdsaKeyKind {
        /// Every supported parameter set, in the order `load_private_key` tries them.
        fn iter() -> impl Iterator<Item = Self> {
            [Self::MlDsa44, Self::MlDsa65, Self::MlDsa87].into_iter()
        }

        /// The aws-lc-rs signing algorithm for this parameter set.
        fn to_alg(self) -> &'static PqdsaSigningAlgorithm {
            match self {
                Self::MlDsa44 => &ML_DSA_44_SIGNING,
                Self::MlDsa65 => &ML_DSA_65_SIGNING,
                Self::MlDsa87 => &ML_DSA_87_SIGNING,
            }
        }

        /// The TLS signature scheme for this parameter set.
        fn scheme(self) -> SignatureScheme {
            match self {
                Self::MlDsa44 => SignatureScheme::ML_DSA_44,
                Self::MlDsa65 => SignatureScheme::ML_DSA_65,
                Self::MlDsa87 => SignatureScheme::ML_DSA_87,
            }
        }

        /// The X.509 algorithm identifier used when building this key's SPKI.
        fn alg_id(self) -> AlgorithmIdentifier {
            match self {
                Self::MlDsa44 => alg_id::ML_DSA_44,
                Self::MlDsa65 => alg_id::ML_DSA_65,
                Self::MlDsa87 => alg_id::ML_DSA_87,
            }
        }
    }
}
195
/// Signature verification algorithms installed by [`provider`] when the
/// `aws-lc-rs-unstable` feature is enabled.
///
/// Covers ECDSA (P-256/P-384/P-521), Ed25519, RSA-PSS, RSA PKCS#1 v1.5 and the
/// three ML-DSA parameter sets (44, 65, 87). The whole table is already gated on
/// `aws-lc-rs-unstable`, so individual entries need no `#[cfg]` of their own.
#[cfg(feature = "aws-lc-rs-unstable")]
static SUPPORTED_SIG_ALGS: WebPkiSupportedAlgorithms = WebPkiSupportedAlgorithms {
    // Every supported verifier (per rustls's docs, consulted when verifying
    // certificate chains).
    all: &[
        webpki_algs::ECDSA_P256_SHA256,
        webpki_algs::ECDSA_P256_SHA384,
        webpki_algs::ECDSA_P384_SHA256,
        webpki_algs::ECDSA_P384_SHA384,
        webpki_algs::ECDSA_P521_SHA256,
        webpki_algs::ECDSA_P521_SHA384,
        webpki_algs::ECDSA_P521_SHA512,
        webpki_algs::ED25519,
        webpki_algs::RSA_PSS_2048_8192_SHA256_LEGACY_KEY,
        webpki_algs::RSA_PSS_2048_8192_SHA384_LEGACY_KEY,
        webpki_algs::RSA_PSS_2048_8192_SHA512_LEGACY_KEY,
        webpki_algs::RSA_PKCS1_2048_8192_SHA256,
        webpki_algs::RSA_PKCS1_2048_8192_SHA384,
        webpki_algs::RSA_PKCS1_2048_8192_SHA512,
        webpki_algs::RSA_PKCS1_2048_8192_SHA256_ABSENT_PARAMS,
        webpki_algs::RSA_PKCS1_2048_8192_SHA384_ABSENT_PARAMS,
        webpki_algs::RSA_PKCS1_2048_8192_SHA512_ABSENT_PARAMS,
        webpki_algs::ML_DSA_44,
        webpki_algs::ML_DSA_65,
        webpki_algs::ML_DSA_87,
    ],
    // TLS `SignatureScheme` -> webpki verifiers that may check a signature made
    // under that scheme. Entry order is preserved as-is; it presumably sets the
    // advertised preference order — confirm before reordering.
    mapping: &[
        // Each ECDSA scheme lists verifiers for all three NIST curves with the
        // scheme's hash, because TLS 1.2 does not bind the curve to the scheme.
        //
        // NOTE(review): `ECDSA_P384_SHA512` and `ECDSA_P256_SHA512` appear below
        // but not in `all` above — confirm this asymmetry is intentional.
        (
            SignatureScheme::ECDSA_NISTP384_SHA384,
            &[
                webpki_algs::ECDSA_P384_SHA384,
                webpki_algs::ECDSA_P256_SHA384,
                webpki_algs::ECDSA_P521_SHA384,
            ],
        ),
        (
            SignatureScheme::ECDSA_NISTP256_SHA256,
            &[
                webpki_algs::ECDSA_P256_SHA256,
                webpki_algs::ECDSA_P384_SHA256,
                webpki_algs::ECDSA_P521_SHA256,
            ],
        ),
        (
            SignatureScheme::ECDSA_NISTP521_SHA512,
            &[
                webpki_algs::ECDSA_P521_SHA512,
                webpki_algs::ECDSA_P384_SHA512,
                webpki_algs::ECDSA_P256_SHA512,
            ],
        ),
        (SignatureScheme::ED25519, &[webpki_algs::ED25519]),
        (
            SignatureScheme::RSA_PSS_SHA512,
            &[webpki_algs::RSA_PSS_2048_8192_SHA512_LEGACY_KEY],
        ),
        (
            SignatureScheme::RSA_PSS_SHA384,
            &[webpki_algs::RSA_PSS_2048_8192_SHA384_LEGACY_KEY],
        ),
        (
            SignatureScheme::RSA_PSS_SHA256,
            &[webpki_algs::RSA_PSS_2048_8192_SHA256_LEGACY_KEY],
        ),
        (
            SignatureScheme::RSA_PKCS1_SHA512,
            &[webpki_algs::RSA_PKCS1_2048_8192_SHA512],
        ),
        (
            SignatureScheme::RSA_PKCS1_SHA384,
            &[webpki_algs::RSA_PKCS1_2048_8192_SHA384],
        ),
        (
            SignatureScheme::RSA_PKCS1_SHA256,
            &[webpki_algs::RSA_PKCS1_2048_8192_SHA256],
        ),
        // ML-DSA keys have exactly one scheme per parameter set.
        (SignatureScheme::ML_DSA_44, &[webpki_algs::ML_DSA_44]),
        (SignatureScheme::ML_DSA_65, &[webpki_algs::ML_DSA_65]),
        (SignatureScheme::ML_DSA_87, &[webpki_algs::ML_DSA_87]),
    ],
};
283
#[cfg(all(test, feature = "aws-lc-rs-unstable"))]
mod tests {
    use std::io;
    use std::ops::DerefMut;
    use std::sync::Arc;

    use rcgen::{
        CertificateParams, CertifiedIssuer, ExtendedKeyUsagePurpose, IsCa, KeyPair, KeyUsagePurpose,
    };
    use rustls::pki_types::PrivateKeyDer;
    use rustls::{
        ClientConfig, ClientConnection, ConnectionCommon, ProtocolVersion, RootCertStore,
        ServerConfig, ServerConnection, SideData,
    };

    /// Runs a full handshake for every combination of ML-DSA parameter sets: the
    /// CA's set exercises certificate-chain verification and the end entity's set
    /// exercises key loading and the handshake signature.
    #[test]
    fn ml_dsa() {
        let algs = [
            &rcgen::PKCS_ML_DSA_44,
            &rcgen::PKCS_ML_DSA_65,
            &rcgen::PKCS_ML_DSA_87,
        ];
        for ca_alg in algs {
            for ee_alg in algs {
                handshake_with(ca_alg, ee_alg);
            }
        }
    }

    /// Issues a `localhost` certificate with an `ee_alg` key from a self-signed
    /// `ca_alg` CA, then completes an in-memory handshake where both client and
    /// server use [`super::provider`].
    fn handshake_with(
        ca_alg: &'static rcgen::SignatureAlgorithm,
        ee_alg: &'static rcgen::SignatureAlgorithm,
    ) {
        let ca_key = KeyPair::generate_for(ca_alg).unwrap();
        let mut ca_params = CertificateParams::new(vec!["Test CA".into()]).unwrap();
        ca_params.is_ca = IsCa::Ca(rcgen::BasicConstraints::Unconstrained);
        ca_params.key_usages = vec![
            KeyUsagePurpose::DigitalSignature,
            KeyUsagePurpose::KeyCertSign,
        ];
        ca_params.extended_key_usages = vec![ExtendedKeyUsagePurpose::ServerAuth];
        let issuer = CertifiedIssuer::self_signed(ca_params, ca_key).unwrap();

        let ee_key = KeyPair::generate_for(ee_alg).unwrap();
        let ee_params = CertificateParams::new(vec!["localhost".into()]).unwrap();
        let ee_cert = ee_params.signed_by(&ee_key, &issuer).unwrap();

        let provider = Arc::new(super::provider());
        let server_config = ServerConfig::builder_with_provider(provider.clone())
            .with_safe_default_protocol_versions()
            .unwrap()
            .with_no_client_auth()
            .with_single_cert(
                vec![ee_cert.der().clone()],
                PrivateKeyDer::try_from(ee_key.serialize_der()).unwrap(),
            )
            .unwrap();

        let mut roots = RootCertStore::empty();
        roots.add(issuer.der().clone()).unwrap();
        let client_config = ClientConfig::builder_with_provider(provider)
            .with_safe_default_protocol_versions()
            .unwrap()
            .with_root_certificates(roots)
            .with_no_client_auth();

        let mut client =
            ClientConnection::new(Arc::new(client_config), "localhost".try_into().unwrap())
                .unwrap();
        let mut server = ServerConnection::new(Arc::new(server_config)).unwrap();
        do_handshake(&mut client, &mut server);

        // ML-DSA signature schemes are expected to negotiate only under TLS 1.3,
        // even though both configs also allow TLS 1.2.
        assert_eq!(client.protocol_version(), Some(ProtocolVersion::TLSv1_3));
    }

    /// Shuttles records between `client` and `server` until neither is
    /// handshaking; returns `(bytes sent to server, bytes sent to client)`.
    ///
    /// # Panics
    ///
    /// Panics if either side reports a TLS error, or if a whole round moves no
    /// bytes while a handshake is still in progress (it would otherwise spin
    /// forever).
    fn do_handshake(
        client: &mut impl DerefMut<Target = ConnectionCommon<impl SideData>>,
        server: &mut impl DerefMut<Target = ConnectionCommon<impl SideData>>,
    ) -> (usize, usize) {
        let (mut to_client, mut to_server) = (0, 0);
        while server.is_handshaking() || client.is_handshaking() {
            let sent = transfer(client, server);
            server.process_new_packets().unwrap();
            let received = transfer(server, client);
            client.process_new_packets().unwrap();
            assert!(
                sent + received > 0,
                "handshake stalled with no data in flight"
            );
            to_server += sent;
            to_client += received;
        }
        (to_server, to_client)
    }

    /// Moves every pending TLS byte from `left` to `right`; returns the count.
    ///
    /// # Panics
    ///
    /// Panics if `right` accepts zero bytes of a non-empty chunk, which would
    /// otherwise loop forever.
    fn transfer(
        left: &mut impl DerefMut<Target = ConnectionCommon<impl SideData>>,
        right: &mut impl DerefMut<Target = ConnectionCommon<impl SideData>>,
    ) -> usize {
        let mut buf = [0u8; 262144];
        let mut total = 0;

        while left.wants_write() {
            let sz = {
                let into_buf: &mut dyn io::Write = &mut &mut buf[..];
                left.write_tls(into_buf).unwrap()
            };
            total += sz;
            if sz == 0 {
                return total;
            }

            let mut offs = 0;
            while offs < sz {
                let from_buf: &mut dyn io::Read = &mut &buf[offs..sz];
                let read = right.read_tls(from_buf).unwrap();
                assert_ne!(read, 0, "peer accepted no TLS bytes");
                offs += read;
            }
        }

        total
    }
}
388}