1#![cfg_attr(not(test), no_std)]
2#![cfg_attr(docsrs, feature(doc_cfg))]
3#![doc = include_str!("../README.md")]
4#![doc(
5 html_logo_url = "https://raw.githubusercontent.com/RustCrypto/meta/master/logo.svg",
6 html_favicon_url = "https://raw.githubusercontent.com/RustCrypto/meta/master/logo.svg"
7)]
8#![deny(missing_docs)]
9#![warn(clippy::pedantic)]
10
11#![cfg_attr(feature = "getrandom", doc = "```")]
18#![cfg_attr(not(feature = "getrandom"), doc = "```ignore")]
19pub use kem::{self, Decapsulate, Encapsulate, Generate};
29
30use core::convert::Infallible;
31use ml_kem::{
32 B32, EncodedSizeUser, Error, KemCore, MlKem768, MlKem768Params,
33 array::{ArrayN, typenum::consts::U32},
34};
35use rand_core::{CryptoRng, TryCryptoRng, TryRngCore};
36use sha3::{
37 Sha3_256, Shake256, Shake256Reader,
38 digest::{ExtendableOutput, XofReader},
39};
40use x25519_dalek::{EphemeralSecret, PublicKey, StaticSecret};
41
42#[cfg(feature = "zeroize")]
43use zeroize::{Zeroize, ZeroizeOnDrop};
44
/// ML-KEM-768 decapsulation key type used for the post-quantum half of X-Wing.
type MlKem768DecapsulationKey = ml_kem::kem::DecapsulationKey<MlKem768Params>;
/// ML-KEM-768 encapsulation key type used for the post-quantum half of X-Wing.
type MlKem768EncapsulationKey = ml_kem::kem::EncapsulationKey<MlKem768Params>;
47
/// Domain-separation label appended last to the combiner input: the six ASCII
/// bytes `\.//^\`.
const X_WING_LABEL: &[u8; 6] = b"\\.//^\\";

/// Size in bytes of a serialized [`EncapsulationKey`]: the 1184-byte
/// ML-KEM-768 encapsulation key followed by the 32-byte X25519 public key.
pub const ENCAPSULATION_KEY_SIZE: usize = 1184 + 32;
/// Size in bytes of a serialized [`DecapsulationKey`], which is stored as a
/// 32-byte seed.
pub const DECAPSULATION_KEY_SIZE: usize = 32;
/// Size in bytes of a serialized [`Ciphertext`]: the 1088-byte ML-KEM-768
/// ciphertext followed by the 32-byte X25519 ephemeral public key.
pub const CIPHERTEXT_SIZE: usize = 1088 + 32;

/// Shared secret produced by X-Wing encapsulation and decapsulation.
pub type SharedSecret = [u8; 32];
59
/// X-Wing encapsulation (public) key.
///
/// Pairs an ML-KEM-768 encapsulation key with an X25519 public key.
/// [`EncapsulationKey::to_bytes`] serializes it as the 1184-byte ML-KEM-768 key
/// followed by the 32-byte X25519 key.
#[derive(Clone, PartialEq)]
pub struct EncapsulationKey {
    /// ML-KEM-768 encapsulation key component.
    pk_m: MlKem768EncapsulationKey,
    /// X25519 public key component.
    pk_x: PublicKey,
}
76
77impl Encapsulate<Ciphertext, SharedSecret> for EncapsulationKey {
78 type Error = Error;
79
80 fn encapsulate_with_rng<R: TryCryptoRng + ?Sized>(
81 &self,
82 rng: &mut R,
83 ) -> Result<(Ciphertext, SharedSecret), Self::Error> {
84 let (ct_m, ss_m) = self.pk_m.encapsulate_with_rng(rng)?;
86
87 let ek_x = EphemeralSecret::random_from_rng(&mut rng.unwrap_mut());
88 let ct_x = PublicKey::from(&ek_x);
90 let ss_x = ek_x.diffie_hellman(&self.pk_x);
92
93 let ss = combiner(&ss_m, &ss_x, &ct_x, &self.pk_x);
94 let ct = Ciphertext { ct_m, ct_x };
95 Ok((ct, ss))
96 }
97}
98
99impl EncapsulationKey {
100 #[must_use]
103 pub fn to_bytes(&self) -> [u8; ENCAPSULATION_KEY_SIZE] {
104 let mut buffer = [0u8; ENCAPSULATION_KEY_SIZE];
105 buffer[0..1184].copy_from_slice(&self.pk_m.to_bytes());
106 buffer[1184..1216].copy_from_slice(self.pk_x.as_bytes());
107 buffer
108 }
109}
110
111impl TryFrom<&[u8; ENCAPSULATION_KEY_SIZE]> for EncapsulationKey {
112 type Error = ml_kem::Error;
113
114 fn try_from(value: &[u8; ENCAPSULATION_KEY_SIZE]) -> Result<Self, ml_kem::Error> {
115 let mut pk_m = [0; 1184];
116 pk_m.copy_from_slice(&value[0..1184]);
117 let pk_m = MlKem768EncapsulationKey::from_bytes(&pk_m.into())?;
118
119 let mut pk_x = [0; 32];
120 pk_x.copy_from_slice(&value[1184..]);
121 let pk_x = PublicKey::from(pk_x);
122 Ok(EncapsulationKey { pk_m, pk_x })
123 }
124}
125
/// X-Wing decapsulation (secret) key.
///
/// The key is stored only as a 32-byte seed. The ML-KEM-768 and X25519 key
/// material is re-derived from it whenever needed (see `expand_key`).
#[derive(Clone)]
#[cfg_attr(feature = "zeroize", derive(Zeroize, ZeroizeOnDrop))]
// Equality is derived only for tests. This presumably avoids exposing a
// non-constant-time `==` on secret keys to library users.
#[cfg_attr(test, derive(PartialEq, Eq))]
pub struct DecapsulationKey {
    /// The 32-byte seed from which all secret key material is expanded.
    sk: [u8; DECAPSULATION_KEY_SIZE],
}
133
134impl Decapsulate<Ciphertext, SharedSecret> for DecapsulationKey {
135 type Encapsulator = EncapsulationKey;
136 type Error = Infallible;
137
138 #[allow(clippy::similar_names)] fn decapsulate(&self, ct: &Ciphertext) -> Result<SharedSecret, Self::Error> {
140 let (sk_m, sk_x, _pk_m, pk_x) = self.expand_key();
141
142 let ss_m = sk_m.decapsulate(&ct.ct_m)?;
143
144 let ss_x = sk_x.diffie_hellman(&ct.ct_x);
146
147 let ss = combiner(&ss_m, &ss_x, &ct.ct_x, &pk_x);
148 Ok(ss)
149 }
150
151 fn encapsulator(&self) -> EncapsulationKey {
152 self.encapsulation_key()
153 }
154}
155
156impl ::kem::KeySizeUser for DecapsulationKey {
157 type KeySize = U32;
158}
159
160impl ::kem::KeyInit for DecapsulationKey {
161 fn new(key: &ArrayN<u8, 32>) -> Self {
162 Self { sk: key.0 }
163 }
164}
165
166impl DecapsulationKey {
167 #[must_use]
169 pub fn encapsulation_key(&self) -> EncapsulationKey {
170 let (_sk_m, _sk_x, pk_m, pk_x) = self.expand_key();
171 EncapsulationKey { pk_m, pk_x }
172 }
173
174 fn expand_key(
175 &self,
176 ) -> (
177 MlKem768DecapsulationKey,
178 StaticSecret,
179 MlKem768EncapsulationKey,
180 PublicKey,
181 ) {
182 use sha3::digest::Update;
183 let mut shaker = Shake256::default();
184 shaker.update(&self.sk);
185 let mut expanded: Shake256Reader = shaker.finalize_xof();
186
187 let seed = read_from(&mut expanded).into();
188 let (sk_m, pk_m) = MlKem768::from_seed(seed);
189
190 let sk_x = read_from(&mut expanded);
191 let sk_x = StaticSecret::from(sk_x);
192 let pk_x = PublicKey::from(&sk_x);
193
194 (sk_m, sk_x, pk_m, pk_x)
195 }
196
197 #[must_use]
199 pub fn as_bytes(&self) -> &[u8; DECAPSULATION_KEY_SIZE] {
200 &self.sk
201 }
202}
203
204impl Generate for DecapsulationKey {
205 fn try_generate_from_rng<R>(rng: &mut R) -> Result<Self, <R as TryRngCore>::Error>
206 where
207 R: TryCryptoRng + ?Sized,
208 {
209 <[u8; DECAPSULATION_KEY_SIZE]>::try_generate_from_rng(rng).map(Into::into)
210 }
211}
212
213impl From<[u8; DECAPSULATION_KEY_SIZE]> for DecapsulationKey {
214 fn from(sk: [u8; DECAPSULATION_KEY_SIZE]) -> Self {
215 DecapsulationKey { sk }
216 }
217}
218
219#[derive(Clone, PartialEq, Eq)]
221#[cfg_attr(feature = "zeroize", derive(Zeroize, ZeroizeOnDrop))]
222pub struct Ciphertext {
223 ct_m: ArrayN<u8, 1088>,
224 ct_x: PublicKey,
225}
226
227impl Ciphertext {
228 #[must_use]
231 pub fn to_bytes(&self) -> [u8; CIPHERTEXT_SIZE] {
232 let mut buffer = [0; CIPHERTEXT_SIZE];
233 buffer[0..1088].copy_from_slice(&self.ct_m);
234 buffer[1088..].copy_from_slice(self.ct_x.as_bytes());
235 buffer
236 }
237}
238
239impl From<&[u8; CIPHERTEXT_SIZE]> for Ciphertext {
240 fn from(value: &[u8; CIPHERTEXT_SIZE]) -> Self {
241 let mut ct_m = [0; 1088];
242 ct_m.copy_from_slice(&value[0..1088]);
243 let mut ct_x = [0; 32];
244 ct_x.copy_from_slice(&value[1088..]);
245
246 Ciphertext {
247 ct_m: ct_m.into(),
248 ct_x: ct_x.into(),
249 }
250 }
251}
252
/// Generate an X-Wing key pair using the RNG behind `Generate::generate`.
///
/// Only available with the `getrandom` feature.
#[cfg(feature = "getrandom")]
#[must_use]
pub fn generate_key_pair() -> (DecapsulationKey, EncapsulationKey) {
    let decapsulation_key = DecapsulationKey::generate();
    let encapsulation_key = decapsulation_key.encapsulation_key();
    (decapsulation_key, encapsulation_key)
}
261
262pub fn generate_key_pair_from_rng<R: CryptoRng + ?Sized>(
264 rng: &mut R,
265) -> (DecapsulationKey, EncapsulationKey) {
266 let sk = DecapsulationKey::generate_from_rng(rng);
267 let pk = sk.encapsulation_key();
268 (sk, pk)
269}
270
271fn combiner(
272 ss_m: &B32,
273 ss_x: &x25519_dalek::SharedSecret,
274 ct_x: &PublicKey,
275 pk_x: &PublicKey,
276) -> SharedSecret {
277 use sha3::Digest;
278
279 let mut hasher = Sha3_256::new();
280 hasher.update(ss_m);
281 hasher.update(ss_x);
282 hasher.update(ct_x);
283 hasher.update(pk_x.as_bytes());
284 hasher.update(X_WING_LABEL);
285 hasher.finalize().into()
286}
287
288fn read_from<const N: usize>(reader: &mut Shake256Reader) -> [u8; N] {
289 let mut data = [0; N];
290 reader.read(&mut data);
291 data
292}
293
#[cfg(test)]
mod tests {
    use getrandom::SysRng;
    use ml_kem::array::Array;
    use rand_core::{CryptoRng, RngCore, TryRngCore, utils};
    use serde::Deserialize;

    use super::*;

    /// Deterministic "RNG" that replays a fixed byte string. The derandomized
    /// test vectors use it. It panics if asked for more bytes than remain.
    pub(crate) struct SeedRng {
        pub(crate) seed: Vec<u8>,
    }

    impl SeedRng {
        fn new(seed: Vec<u8>) -> SeedRng {
            SeedRng { seed }
        }
    }

    impl RngCore for SeedRng {
        fn next_u32(&mut self) -> u32 {
            utils::next_word_via_fill(self)
        }

        fn next_u64(&mut self) -> u64 {
            utils::next_word_via_fill(self)
        }

        fn fill_bytes(&mut self, dest: &mut [u8]) {
            // Hand out the next `dest.len()` bytes and drop them from the queue.
            let consumed: Vec<u8> = self.seed.drain(..dest.len()).collect();
            dest.copy_from_slice(&consumed);
        }
    }

    impl CryptoRng for SeedRng {}

    /// One entry of `test-vectors.json`; every field is hex-encoded.
    #[derive(Deserialize)]
    struct TestVector {
        #[serde(deserialize_with = "hex::serde::deserialize")]
        seed: Vec<u8>,

        #[serde(deserialize_with = "hex::serde::deserialize")]
        eseed: Vec<u8>,

        #[serde(deserialize_with = "hex::serde::deserialize")]
        ss: [u8; 32],

        #[serde(deserialize_with = "hex::serde::deserialize")]
        sk: [u8; 32],

        #[serde(deserialize_with = "hex::serde::deserialize")]
        pk: Vec<u8>,

        #[serde(deserialize_with = "hex::serde::deserialize")]
        ct: Vec<u8>,
    }

    #[test]
    fn rfc_test_vectors() {
        let vectors: Vec<TestVector> =
            serde_json::from_str(include_str!("test-vectors.json")).unwrap();
        vectors.into_iter().for_each(run_test);
    }

    /// Check key generation, encapsulation and decapsulation against one vector.
    fn run_test(tv: TestVector) {
        let (sk, pk) = generate_key_pair_from_rng(&mut SeedRng::new(tv.seed));
        assert_eq!(sk.as_bytes(), &tv.sk);
        assert_eq!(&pk.to_bytes(), tv.pk.as_slice());

        let (ct, ss) = pk
            .encapsulate_with_rng(&mut SeedRng::new(tv.eseed))
            .unwrap();
        assert_eq!(ss, tv.ss);
        assert_eq!(&ct.to_bytes(), tv.ct.as_slice());

        assert_eq!(sk.decapsulate(&ct).unwrap(), tv.ss);
    }

    #[test]
    fn ciphertext_serialize() {
        let mut rng = SysRng.unwrap_err();

        let original = Ciphertext {
            ct_m: Array::generate_from_rng(&mut rng),
            ct_x: <[u8; 32]>::generate_from_rng(&mut rng).into(),
        };
        let restored = Ciphertext::from(&original.to_bytes());

        assert!(original == restored);
    }

    #[test]
    fn key_serialize() {
        let sk = DecapsulationKey::generate_from_rng(&mut SysRng.unwrap_err());
        let pk = sk.encapsulation_key();

        let sk_roundtrip = DecapsulationKey::from(*sk.as_bytes());
        let pk_roundtrip = EncapsulationKey::try_from(&pk.to_bytes()).unwrap();

        assert!(sk == sk_roundtrip);
        assert!(pk == pk_roundtrip);
    }
}