1#![cfg_attr(not(test), no_std)]
2#![cfg_attr(docsrs, feature(doc_cfg))]
3#![doc = include_str!("../README.md")]
4#![doc(
5 html_logo_url = "https://raw.githubusercontent.com/RustCrypto/meta/master/logo.svg",
6 html_favicon_url = "https://raw.githubusercontent.com/RustCrypto/meta/master/logo.svg"
7)]
8#![deny(missing_docs)]
9#![warn(clippy::pedantic)]
10
11pub use kem::{self, Decapsulate, Encapsulate};
30
31use core::convert::Infallible;
32use ml_kem::array::{ArrayN, typenum::consts::U32};
33use ml_kem::{B32, EncodedSizeUser, KemCore, MlKem768, MlKem768Params};
34use rand_core::{CryptoRng, TryCryptoRng};
35#[cfg(feature = "os_rng")]
36use rand_core::{OsRng, TryRngCore};
37use sha3::digest::{ExtendableOutput, XofReader};
38use sha3::{Sha3_256, Shake256, Shake256Reader};
39use x25519_dalek::{EphemeralSecret, PublicKey, StaticSecret};
40
41#[cfg(feature = "zeroize")]
42use zeroize::{Zeroize, ZeroizeOnDrop};
43
44type MlKem768DecapsulationKey = ml_kem::kem::DecapsulationKey<MlKem768Params>;
45type MlKem768EncapsulationKey = ml_kem::kem::EncapsulationKey<MlKem768Params>;
46
/// Domain-separation label appended last to the combiner input (the ASCII art `\.//^\`).
const X_WING_LABEL: &[u8; 6] = br"\.//^\";

/// Size in bytes of a serialized encapsulation key: the 1184-byte ML-KEM-768
/// encapsulation key followed by the 32-byte X25519 public key.
pub const ENCAPSULATION_KEY_SIZE: usize = 1184 + 32;

/// Size in bytes of a serialized decapsulation key: the 32-byte seed from which
/// both component secret keys are expanded.
pub const DECAPSULATION_KEY_SIZE: usize = 32;

/// Size in bytes of a serialized ciphertext: the 1088-byte ML-KEM-768 ciphertext
/// followed by the 32-byte X25519 ephemeral public key.
pub const CIPHERTEXT_SIZE: usize = 1088 + 32;

/// Shared secret produced by encapsulation and recovered by decapsulation.
pub type SharedSecret = [u8; 32];
58
59#[derive(Clone, PartialEq)]
71pub struct EncapsulationKey {
72 pk_m: MlKem768EncapsulationKey,
73 pk_x: PublicKey,
74}
75
76impl Encapsulate<Ciphertext, SharedSecret> for EncapsulationKey {
77 type Error = Infallible;
78
79 fn encapsulate<R: TryCryptoRng + ?Sized>(
80 &self,
81 rng: &mut R,
82 ) -> Result<(Ciphertext, SharedSecret), Self::Error> {
83 let (ct_m, ss_m) = self.pk_m.encapsulate(rng)?;
85
86 let ek_x = EphemeralSecret::random_from_rng(&mut rng.unwrap_mut());
87 let ct_x = PublicKey::from(&ek_x);
89 let ss_x = ek_x.diffie_hellman(&self.pk_x);
91
92 let ss = combiner(&ss_m, &ss_x, &ct_x, &self.pk_x);
93 let ct = Ciphertext { ct_m, ct_x };
94 Ok((ct, ss))
95 }
96}
97
98impl EncapsulationKey {
99 #[must_use]
102 pub fn to_bytes(&self) -> [u8; ENCAPSULATION_KEY_SIZE] {
103 let mut buffer = [0u8; ENCAPSULATION_KEY_SIZE];
104 buffer[0..1184].copy_from_slice(&self.pk_m.as_bytes());
105 buffer[1184..1216].copy_from_slice(self.pk_x.as_bytes());
106 buffer
107 }
108}
109
110impl From<&[u8; ENCAPSULATION_KEY_SIZE]> for EncapsulationKey {
111 fn from(value: &[u8; ENCAPSULATION_KEY_SIZE]) -> Self {
112 let mut pk_m = [0; 1184];
113 pk_m.copy_from_slice(&value[0..1184]);
114 let pk_m = MlKem768EncapsulationKey::from_bytes(&pk_m.into());
115
116 let mut pk_x = [0; 32];
117 pk_x.copy_from_slice(&value[1184..]);
118 let pk_x = PublicKey::from(pk_x);
119 EncapsulationKey { pk_m, pk_x }
120 }
121}
122
123#[derive(Clone)]
125#[cfg_attr(feature = "zeroize", derive(Zeroize, ZeroizeOnDrop))]
126#[cfg_attr(test, derive(PartialEq, Eq))]
127pub struct DecapsulationKey {
128 sk: [u8; DECAPSULATION_KEY_SIZE],
129}
130
131impl Decapsulate<Ciphertext, SharedSecret> for DecapsulationKey {
132 type Encapsulator = EncapsulationKey;
133 type Error = Infallible;
134
135 #[allow(clippy::similar_names)] fn decapsulate(&self, ct: &Ciphertext) -> Result<SharedSecret, Self::Error> {
137 let (sk_m, sk_x, _pk_m, pk_x) = self.expand_key();
138
139 let ss_m = sk_m.decapsulate(&ct.ct_m)?;
140
141 let ss_x = sk_x.diffie_hellman(&ct.ct_x);
143
144 let ss = combiner(&ss_m, &ss_x, &ct.ct_x, &pk_x);
145 Ok(ss)
146 }
147
148 fn encapsulator(&self) -> EncapsulationKey {
149 self.encapsulation_key()
150 }
151}
152
153impl ::kem::KeySizeUser for DecapsulationKey {
154 type KeySize = U32;
155}
156
157impl ::kem::KeyInit for DecapsulationKey {
158 fn new(key: &ArrayN<u8, 32>) -> Self {
159 Self { sk: key.0 }
160 }
161}
162
163impl DecapsulationKey {
164 #[cfg(feature = "os_rng")]
166 #[must_use]
167 pub fn generate_from_os_rng() -> DecapsulationKey {
168 Self::generate(&mut OsRng.unwrap_err())
169 }
170
171 pub fn generate<R: CryptoRng + ?Sized>(rng: &mut R) -> DecapsulationKey {
173 let sk = generate(rng);
174 DecapsulationKey { sk }
175 }
176
177 #[must_use]
179 pub fn encapsulation_key(&self) -> EncapsulationKey {
180 let (_sk_m, _sk_x, pk_m, pk_x) = self.expand_key();
181 EncapsulationKey { pk_m, pk_x }
182 }
183
184 fn expand_key(
185 &self,
186 ) -> (
187 MlKem768DecapsulationKey,
188 StaticSecret,
189 MlKem768EncapsulationKey,
190 PublicKey,
191 ) {
192 use sha3::digest::Update;
193 let mut shaker = Shake256::default();
194 shaker.update(&self.sk);
195 let mut expanded: Shake256Reader = shaker.finalize_xof();
196
197 let seed = read_from(&mut expanded).into();
198 let (sk_m, pk_m) = MlKem768::from_seed(seed);
199
200 let sk_x = read_from(&mut expanded);
201 let sk_x = StaticSecret::from(sk_x);
202 let pk_x = PublicKey::from(&sk_x);
203
204 (sk_m, sk_x, pk_m, pk_x)
205 }
206
207 #[must_use]
209 pub fn as_bytes(&self) -> &[u8; DECAPSULATION_KEY_SIZE] {
210 &self.sk
211 }
212}
213
214impl From<[u8; DECAPSULATION_KEY_SIZE]> for DecapsulationKey {
215 fn from(sk: [u8; DECAPSULATION_KEY_SIZE]) -> Self {
216 DecapsulationKey { sk }
217 }
218}
219
220#[derive(Clone, PartialEq, Eq)]
222#[cfg_attr(feature = "zeroize", derive(Zeroize, ZeroizeOnDrop))]
223pub struct Ciphertext {
224 ct_m: ArrayN<u8, 1088>,
225 ct_x: PublicKey,
226}
227
228impl Ciphertext {
229 #[must_use]
232 pub fn to_bytes(&self) -> [u8; CIPHERTEXT_SIZE] {
233 let mut buffer = [0; CIPHERTEXT_SIZE];
234 buffer[0..1088].copy_from_slice(&self.ct_m);
235 buffer[1088..].copy_from_slice(self.ct_x.as_bytes());
236 buffer
237 }
238}
239
240impl From<&[u8; CIPHERTEXT_SIZE]> for Ciphertext {
241 fn from(value: &[u8; CIPHERTEXT_SIZE]) -> Self {
242 let mut ct_m = [0; 1088];
243 ct_m.copy_from_slice(&value[0..1088]);
244 let mut ct_x = [0; 32];
245 ct_x.copy_from_slice(&value[1088..]);
246
247 Ciphertext {
248 ct_m: ct_m.into(),
249 ct_x: ct_x.into(),
250 }
251 }
252}
253
/// Generates a fresh X-Wing key pair using the operating system's RNG.
///
/// # Panics
///
/// Panics if the operating system RNG reports an error.
#[cfg(feature = "os_rng")]
#[cfg_attr(docsrs, doc(cfg(feature = "os_rng")))]
#[must_use]
pub fn generate_key_pair_from_os_rng() -> (DecapsulationKey, EncapsulationKey) {
    generate_key_pair(&mut OsRng.unwrap_err())
}
260
261pub fn generate_key_pair<R: CryptoRng + ?Sized>(
263 rng: &mut R,
264) -> (DecapsulationKey, EncapsulationKey) {
265 let sk = DecapsulationKey::generate(rng);
266 let pk = sk.encapsulation_key();
267 (sk, pk)
268}
269
270fn combiner(
271 ss_m: &B32,
272 ss_x: &x25519_dalek::SharedSecret,
273 ct_x: &PublicKey,
274 pk_x: &PublicKey,
275) -> SharedSecret {
276 use sha3::Digest;
277
278 let mut hasher = Sha3_256::new();
279 hasher.update(ss_m);
280 hasher.update(ss_x);
281 hasher.update(ct_x);
282 hasher.update(pk_x.as_bytes());
283 hasher.update(X_WING_LABEL);
284 hasher.finalize().into()
285}
286
287fn read_from<const N: usize>(reader: &mut Shake256Reader) -> [u8; N] {
288 let mut data = [0; N];
289 reader.read(&mut data);
290 data
291}
292
293fn generate<const N: usize, R: CryptoRng + ?Sized>(rng: &mut R) -> [u8; N] {
294 let mut random = [0; N];
295 rng.fill_bytes(&mut random);
296 random
297}
298
#[cfg(test)]
mod tests {
    use rand_core::{CryptoRng, OsRng, RngCore, TryRngCore, impls};
    use serde::Deserialize;

    use super::*;

    /// Deterministic byte source backed by a fixed buffer, used to replay the
    /// seeds from the test vectors. Bytes are handed out front to back and
    /// removed once consumed. Asking for more bytes than remain panics.
    pub(crate) struct SeedRng {
        pub(crate) seed: Vec<u8>,
    }

    impl SeedRng {
        fn new(seed: Vec<u8>) -> SeedRng {
            Self { seed }
        }
    }

    impl RngCore for SeedRng {
        fn next_u32(&mut self) -> u32 {
            impls::next_u32_via_fill(self)
        }

        fn next_u64(&mut self) -> u64 {
            impls::next_u64_via_fill(self)
        }

        fn fill_bytes(&mut self, dest: &mut [u8]) {
            let wanted = dest.len();
            dest.copy_from_slice(&self.seed[..wanted]);
            self.seed.drain(..wanted);
        }
    }

    // Marker only, so the replay RNG satisfies the `CryptoRng` bounds of the API under test.
    impl CryptoRng for SeedRng {}

    /// One entry of `test-vectors.json`; every field is hex-encoded in the file.
    #[derive(Deserialize)]
    struct TestVector {
        /// Randomness consumed by key generation.
        #[serde(deserialize_with = "hex::serde::deserialize")]
        seed: Vec<u8>,

        /// Randomness consumed by encapsulation.
        #[serde(deserialize_with = "hex::serde::deserialize")]
        eseed: Vec<u8>,

        /// Expected shared secret.
        #[serde(deserialize_with = "hex::serde::deserialize")]
        ss: [u8; 32],

        /// Expected serialized decapsulation key.
        #[serde(deserialize_with = "hex::serde::deserialize")]
        sk: [u8; 32],

        /// Expected serialized encapsulation key.
        #[serde(deserialize_with = "hex::serde::deserialize")]
        pk: Vec<u8>,

        /// Expected serialized ciphertext.
        #[serde(deserialize_with = "hex::serde::deserialize")]
        ct: Vec<u8>,
    }

    #[test]
    fn rfc_test_vectors() {
        let json = include_str!("test-vectors.json");
        let vectors: Vec<TestVector> = serde_json::from_str(json).unwrap();
        vectors.into_iter().for_each(run_test);
    }

    /// Replays one vector: key generation, then encapsulation, then
    /// decapsulation. Every intermediate value is checked against the expected bytes.
    fn run_test(tv: TestVector) {
        let mut keygen_rng = SeedRng::new(tv.seed);
        let (sk, pk) = generate_key_pair(&mut keygen_rng);

        assert_eq!(sk.as_bytes(), &tv.sk);
        assert_eq!(&pk.to_bytes(), tv.pk.as_slice());

        let mut encap_rng = SeedRng::new(tv.eseed);
        let (ct, encapsulated) = pk.encapsulate(&mut encap_rng).unwrap();

        assert_eq!(encapsulated, tv.ss);
        assert_eq!(&ct.to_bytes(), tv.ct.as_slice());

        let decapsulated = sk.decapsulate(&ct).unwrap();
        assert_eq!(decapsulated, tv.ss);
    }

    #[test]
    fn ciphertext_serialize() {
        let mut rng = OsRng.unwrap_err();

        let original = Ciphertext {
            ct_m: generate(&mut rng).into(),
            ct_x: generate(&mut rng).into(),
        };

        let encoded = original.to_bytes();
        let round_tripped = Ciphertext::from(&encoded);

        assert!(original == round_tripped);
    }

    #[test]
    fn key_serialize() {
        let sk = DecapsulationKey::generate(&mut OsRng.unwrap_err());
        let pk = sk.encapsulation_key();

        let sk_copy = DecapsulationKey::from(*sk.as_bytes());
        let pk_copy = EncapsulationKey::from(&pk.to_bytes());

        assert!(sk == sk_copy);
        assert!(pk == pk_copy);
    }
}