1#![cfg_attr(not(test), no_std)]
2#![cfg_attr(docsrs, feature(doc_cfg))]
3#![doc = include_str!("../README.md")]
4#![doc(
5 html_logo_url = "https://raw.githubusercontent.com/RustCrypto/meta/master/logo.svg",
6 html_favicon_url = "https://raw.githubusercontent.com/RustCrypto/meta/master/logo.svg"
7)]
8#![deny(missing_docs)]
9#![warn(clippy::pedantic)]
10
11#![cfg_attr(feature = "getrandom", doc = "```")]
18#![cfg_attr(not(feature = "getrandom"), doc = "```ignore")]
19pub use kem::{
29 self, Decapsulate, Decapsulator, Encapsulate, Generate, InvalidKey, KemParams, Key, KeyExport,
30 KeyInit, KeySizeUser, TryKeyInit,
31};
32
33use ml_kem::{
34 B32, EncodedSizeUser, KemCore, MlKem768, MlKem768Params,
35 array::{
36 Array, ArrayN, AsArrayRef,
37 sizes::{U32, U1120, U1216},
38 },
39};
40use rand_core::{CryptoRng, TryCryptoRng, TryRngCore};
41use sha3::{
42 Sha3_256, Shake256, Shake256Reader,
43 digest::{ExtendableOutput, XofReader},
44};
45use x25519_dalek::{EphemeralSecret, PublicKey, StaticSecret};
46
47#[cfg(feature = "zeroize")]
48use zeroize::{Zeroize, ZeroizeOnDrop};
49
// ML-KEM-768 key types from the `ml-kem` crate, aliased for brevity.
type MlKem768DecapsulationKey = ml_kem::kem::DecapsulationKey<MlKem768Params>;
type MlKem768EncapsulationKey = ml_kem::kem::EncapsulationKey<MlKem768Params>;

// Domain-separation label fed last into the SHA3-256 input in `combiner`.
// As a raw byte string this is the six ASCII bytes `\.//^\`.
const X_WING_LABEL: &[u8; 6] = br"\.//^\";

/// Size in bytes of a serialized [`EncapsulationKey`]: a 1184-byte ML-KEM-768
/// encapsulation key followed by a 32-byte X25519 public key.
pub const ENCAPSULATION_KEY_SIZE: usize = 1216;
/// Size in bytes of a serialized [`DecapsulationKey`], i.e. its secret seed.
pub const DECAPSULATION_KEY_SIZE: usize = 32;
/// Size in bytes of a serialized [`Ciphertext`]: a 1088-byte ML-KEM-768
/// ciphertext followed by a 32-byte X25519 ephemeral public key.
pub const CIPHERTEXT_SIZE: usize = 1120;

/// Serialized X-Wing ciphertext ([`CIPHERTEXT_SIZE`] bytes).
pub type Ciphertext = Array<u8, U1120>;
/// 32-byte shared secret produced by encapsulation and decapsulation.
pub type SharedSecret = Array<u8, U32>;
66
/// X-Wing encapsulation (public) key.
///
/// Pairs an ML-KEM-768 encapsulation key with an X25519 public key. Its
/// serialized form (`KeyExport::to_bytes`) is [`ENCAPSULATION_KEY_SIZE`] bytes:
/// the 1184-byte ML-KEM key followed by the 32-byte X25519 key.
#[derive(Clone, Eq, PartialEq)]
pub struct EncapsulationKey {
    // ML-KEM-768 encapsulation key.
    pk_m: MlKem768EncapsulationKey,
    // X25519 public key.
    pk_x: PublicKey,
}
83
84impl Encapsulate for EncapsulationKey {
85 fn encapsulate_with_rng<R: TryCryptoRng + ?Sized>(
86 &self,
87 rng: &mut R,
88 ) -> Result<(Ciphertext, SharedSecret), R::Error> {
89 let (ct_m, ss_m) = self.pk_m.encapsulate_with_rng(rng)?;
91
92 let ek_x = EphemeralSecret::random_from_rng(&mut rng.unwrap_err());
93 let ct_x = PublicKey::from(&ek_x);
95 let ss_x = ek_x.diffie_hellman(&self.pk_x);
97
98 let ss = combiner(&ss_m, &ss_x, &ct_x, &self.pk_x);
99 let ct = CiphertextMessage { ct_m, ct_x };
100 Ok((ct.into(), ss))
101 }
102}
103
// X-Wing ciphertexts are 1120 bytes (`CIPHERTEXT_SIZE`); shared secrets are 32 bytes.
impl KemParams for EncapsulationKey {
    type CiphertextSize = U1120;
    type SharedSecretSize = U32;
}

// Serialized encapsulation keys are 1216 bytes (`ENCAPSULATION_KEY_SIZE`).
impl KeySizeUser for EncapsulationKey {
    type KeySize = U1216;
}
112
113impl KeyExport for EncapsulationKey {
114 fn to_bytes(&self) -> Key<Self> {
115 let mut key_bytes = Key::<Self>::default();
116 let (m, x) = key_bytes.split_at_mut(1184);
117 m.copy_from_slice(&self.pk_m.to_encoded_bytes());
118 x.copy_from_slice(self.pk_x.as_bytes());
119 key_bytes
120 }
121}
122
123impl TryKeyInit for EncapsulationKey {
124 fn new(key_bytes: &Key<Self>) -> Result<Self, InvalidKey> {
125 let mut pk_m = [0; 1184];
126 pk_m.copy_from_slice(&key_bytes[0..1184]);
127 let pk_m =
128 MlKem768EncapsulationKey::from_encoded_bytes(&pk_m.into()).map_err(|_| InvalidKey)?;
129
130 let mut pk_x = [0; 32];
131 pk_x.copy_from_slice(&key_bytes[1184..]);
132 let pk_x = PublicKey::from(pk_x);
133 Ok(EncapsulationKey { pk_m, pk_x })
134 }
135}
136
137impl TryFrom<&[u8]> for EncapsulationKey {
138 type Error = InvalidKey;
139
140 fn try_from(key_bytes: &[u8]) -> Result<Self, InvalidKey> {
141 Self::new_from_slice(key_bytes)
142 }
143}
144
/// X-Wing decapsulation (private) key.
///
/// Only the 32-byte seed is stored. The ML-KEM-768 and X25519 secret keys are
/// re-derived from it on every decapsulation. The matching [`EncapsulationKey`]
/// is derived once at construction and cached.
#[derive(Clone)]
pub struct DecapsulationKey {
    // Secret seed; expanded with SHAKE256 by `expand_key`.
    sk: [u8; DECAPSULATION_KEY_SIZE],
    // Public key derived from `sk` at construction, cached for `encapsulator()`.
    ek: EncapsulationKey,
}

impl DecapsulationKey {
    /// Returns the 32-byte secret seed this key was built from.
    ///
    /// This is secret key material. Converting it back with
    /// `DecapsulationKey::from` reconstructs the same key.
    #[must_use]
    pub fn as_bytes(&self) -> &[u8; DECAPSULATION_KEY_SIZE] {
        &self.sk
    }
}
159
160impl Decapsulate for DecapsulationKey {
161 #[allow(clippy::similar_names)] fn decapsulate(&self, ct: &Ciphertext) -> SharedSecret {
163 let ct = CiphertextMessage::from(ct);
164 let (sk_m, sk_x, _pk_m, pk_x) = expand_key(&self.sk);
165
166 let ss_m = sk_m.decapsulate(&ct.ct_m);
167
168 let ss_x = sk_x.diffie_hellman(&ct.ct_x);
170
171 combiner(&ss_m, &ss_x, &ct.ct_x, &pk_x)
172 }
173}
174
impl Decapsulator for DecapsulationKey {
    type Encapsulator = EncapsulationKey;

    /// Returns the encapsulation key derived from this key's seed at construction.
    fn encapsulator(&self) -> &EncapsulationKey {
        &self.ek
    }
}

impl Drop for DecapsulationKey {
    /// Wipes the secret seed when the `zeroize` feature is enabled; otherwise this
    /// is a no-op. The cached encapsulation key is public and is not wiped.
    fn drop(&mut self) {
        #[cfg(feature = "zeroize")]
        self.sk.zeroize();
    }
}
189
190impl From<[u8; DECAPSULATION_KEY_SIZE]> for DecapsulationKey {
191 fn from(sk: [u8; DECAPSULATION_KEY_SIZE]) -> Self {
192 DecapsulationKey::new(sk.as_array_ref())
193 }
194}
195
196impl Generate for DecapsulationKey {
197 fn try_generate_from_rng<R>(rng: &mut R) -> Result<Self, <R as TryRngCore>::Error>
198 where
199 R: TryCryptoRng + ?Sized,
200 {
201 <[u8; DECAPSULATION_KEY_SIZE]>::try_generate_from_rng(rng).map(Into::into)
202 }
203}
204
// Decapsulation keys are serialized as their 32-byte seed.
impl KeySizeUser for DecapsulationKey {
    type KeySize = U32;
}
208
209impl KeyInit for DecapsulationKey {
210 fn new(key: &ArrayN<u8, 32>) -> Self {
211 let (_sk_m, _sk_x, pk_m, pk_x) = expand_key(key.as_ref());
212 let ek = EncapsulationKey { pk_m, pk_x };
213 Self { sk: key.0, ek }
214 }
215}
216
// Marker trait: the `Drop` impl wipes the secret seed when `zeroize` is enabled.
#[cfg(feature = "zeroize")]
impl ZeroizeOnDrop for DecapsulationKey {}
219
220fn expand_key(
221 sk: &[u8; DECAPSULATION_KEY_SIZE],
222) -> (
223 MlKem768DecapsulationKey,
224 StaticSecret,
225 MlKem768EncapsulationKey,
226 PublicKey,
227) {
228 use sha3::digest::Update;
229 let mut shaker = Shake256::default();
230 shaker.update(sk);
231 let mut expanded: Shake256Reader = shaker.finalize_xof();
232
233 let seed = read_from(&mut expanded).into();
234 let (sk_m, pk_m) = MlKem768::from_seed(seed);
235
236 let sk_x = read_from(&mut expanded);
237 let sk_x = StaticSecret::from(sk_x);
238 let pk_x = PublicKey::from(&sk_x);
239
240 (sk_m, sk_x, pk_m, pk_x)
241}
242
/// Parsed X-Wing ciphertext: the ML-KEM-768 ciphertext plus the ephemeral
/// X25519 public key.
///
/// Serializes to a [`Ciphertext`] laid out as `ct_m (1088 bytes) || ct_x (32 bytes)`.
#[derive(Clone, PartialEq, Eq)]
pub struct CiphertextMessage {
    // ML-KEM-768 ciphertext.
    ct_m: ArrayN<u8, 1088>,
    // Ephemeral X25519 public key produced during encapsulation.
    ct_x: PublicKey,
}
249
250impl CiphertextMessage {
251 #[must_use]
254 pub fn to_bytes(&self) -> Ciphertext {
255 let mut buffer = Ciphertext::default();
256 buffer[0..1088].copy_from_slice(&self.ct_m);
257 buffer[1088..].copy_from_slice(self.ct_x.as_bytes());
258 buffer
259 }
260}
261
262impl From<&Ciphertext> for CiphertextMessage {
263 fn from(value: &Ciphertext) -> Self {
264 let mut ct_m = [0; 1088];
265 ct_m.copy_from_slice(&value[0..1088]);
266 let mut ct_x = [0; 32];
267 ct_x.copy_from_slice(&value[1088..]);
268
269 CiphertextMessage {
270 ct_m: ct_m.into(),
271 ct_x: ct_x.into(),
272 }
273 }
274}
275
276impl From<&CiphertextMessage> for Ciphertext {
277 #[inline]
278 fn from(msg: &CiphertextMessage) -> Self {
279 msg.to_bytes()
280 }
281}
282
283impl From<CiphertextMessage> for Ciphertext {
284 #[inline]
285 fn from(msg: CiphertextMessage) -> Self {
286 Self::from(&msg)
287 }
288}
289
/// Generates an X-Wing key pair using the default RNG available with the
/// `getrandom` feature (via `Generate::generate`).
///
/// Returns the decapsulation key together with a clone of its encapsulation key.
#[cfg(feature = "getrandom")]
#[must_use]
pub fn generate_key_pair() -> (DecapsulationKey, EncapsulationKey) {
    let decapsulation_key = DecapsulationKey::generate();
    let encapsulation_key = decapsulation_key.encapsulator().clone();
    (decapsulation_key, encapsulation_key)
}
298
299pub fn generate_key_pair_from_rng<R: CryptoRng + ?Sized>(
301 rng: &mut R,
302) -> (DecapsulationKey, EncapsulationKey) {
303 let sk = DecapsulationKey::generate_from_rng(rng);
304 let pk = sk.encapsulator().clone();
305 (sk, pk)
306}
307
308fn combiner(
309 ss_m: &B32,
310 ss_x: &x25519_dalek::SharedSecret,
311 ct_x: &PublicKey,
312 pk_x: &PublicKey,
313) -> SharedSecret {
314 use sha3::Digest;
315
316 let mut hasher = Sha3_256::new();
317 hasher.update(ss_m);
318 hasher.update(ss_x);
319 hasher.update(ct_x);
320 hasher.update(pk_x.as_bytes());
321 hasher.update(X_WING_LABEL);
322 hasher.finalize()
323}
324
325fn read_from<const N: usize>(reader: &mut Shake256Reader) -> [u8; N] {
326 let mut data = [0; N];
327 reader.read(&mut data);
328 data
329}
330
#[cfg(test)]
mod tests {
    use core::convert::Infallible;
    use getrandom::SysRng;
    use ml_kem::array::Array;
    use rand_core::{TryCryptoRng, TryRngCore, utils};
    use serde::Deserialize;

    use super::*;

    /// Deterministic "RNG" that replays a fixed byte string front to back.
    ///
    /// Feeds the `seed`/`eseed` values of the test vectors into key generation
    /// and encapsulation. Requesting more bytes than remain panics.
    pub(crate) struct SeedRng {
        pub(crate) seed: Vec<u8>,
    }

    impl SeedRng {
        fn new(seed: Vec<u8>) -> SeedRng {
            SeedRng { seed }
        }
    }

    impl TryRngCore for SeedRng {
        type Error = Infallible;

        fn try_next_u32(&mut self) -> Result<u32, Self::Error> {
            utils::next_word_via_fill(self)
        }

        fn try_next_u64(&mut self) -> Result<u64, Self::Error> {
            utils::next_word_via_fill(self)
        }

        fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Self::Error> {
            let count = dest.len();
            dest.copy_from_slice(&self.seed[..count]);
            // Discard what was just handed out so the next call continues after it.
            self.seed.drain(..count);
            Ok(())
        }
    }

    impl TryCryptoRng for SeedRng {}

    /// One entry of `test-vectors.json`; every field is hex-encoded in the file.
    #[derive(Deserialize)]
    struct TestVector {
        #[serde(deserialize_with = "hex::serde::deserialize")]
        seed: Vec<u8>,

        #[serde(deserialize_with = "hex::serde::deserialize")]
        eseed: Vec<u8>,

        #[serde(deserialize_with = "hex::serde::deserialize")]
        ss: [u8; 32],

        #[serde(deserialize_with = "hex::serde::deserialize")]
        sk: [u8; 32],

        #[serde(deserialize_with = "hex::serde::deserialize")]
        pk: Vec<u8>,

        #[serde(deserialize_with = "hex::serde::deserialize")]
        ct: Vec<u8>,
    }

    #[test]
    fn rfc_test_vectors() {
        let vectors: Vec<TestVector> =
            serde_json::from_str(include_str!("test-vectors.json")).unwrap();
        vectors.into_iter().for_each(run_test);
    }

    /// Checks key generation, encapsulation, and decapsulation against one vector.
    fn run_test(tv: TestVector) {
        let (sk, pk) = generate_key_pair_from_rng(&mut SeedRng::new(tv.seed));
        assert_eq!(sk.as_bytes(), &tv.sk);
        assert_eq!(&*pk.to_bytes(), tv.pk.as_slice());

        let (ct, ss) = pk
            .encapsulate_with_rng(&mut SeedRng::new(tv.eseed))
            .unwrap();
        assert_eq!(ss, tv.ss);
        assert_eq!(&*ct, tv.ct.as_slice());

        assert_eq!(sk.decapsulate(&ct), tv.ss);
    }

    #[test]
    fn ciphertext_serialize() {
        let mut rng = SysRng.unwrap_err();

        let original = CiphertextMessage {
            ct_m: Array::generate_from_rng(&mut rng),
            ct_x: <[u8; 32]>::generate_from_rng(&mut rng).into(),
        };
        let round_tripped = CiphertextMessage::from(&original.to_bytes());

        // `assert!` rather than `assert_eq!`: `CiphertextMessage` has no `Debug`.
        assert!(original == round_tripped);
    }

    #[test]
    fn key_serialize() {
        let sk = DecapsulationKey::generate_from_rng(&mut SysRng.unwrap_err());
        let pk = sk.encapsulator().clone();

        let sk_restored = DecapsulationKey::from(*sk.as_bytes());
        let pk_restored = EncapsulationKey::new(&pk.to_bytes()).unwrap();

        assert_eq!(sk.sk, sk_restored.sk);
        // `assert!` rather than `assert_eq!`: `EncapsulationKey` has no `Debug`.
        assert!(pk == pk_restored);
    }
}