1#![cfg_attr(not(test), no_std)]
2#![cfg_attr(docsrs, feature(doc_cfg))]
3#![doc = include_str!("../README.md")]
4#![doc(
5 html_logo_url = "https://raw.githubusercontent.com/RustCrypto/meta/master/logo.svg",
6 html_favicon_url = "https://raw.githubusercontent.com/RustCrypto/meta/master/logo.svg"
7)]
8
9#![cfg_attr(feature = "getrandom", doc = "```")]
16#![cfg_attr(not(feature = "getrandom"), doc = "```ignore")]
17pub use kem::{
30 self, Decapsulate, Decapsulator, Encapsulate, Generate, InvalidKey, Kem, Key, KeyExport,
31 KeyInit, KeySizeUser, TryKeyInit,
32};
33
34use core::fmt::{self, Debug};
35use ml_kem::{
36 FromSeed, MlKem768,
37 array::{
38 Array, ArrayN, AsArrayRef,
39 sizes::{U32, U1120, U1184, U1216},
40 },
41 ml_kem_768,
42};
43use rand_core::{CryptoRng, TryCryptoRng, TryRng};
44use sha3::{
45 Sha3_256, Shake256, Shake256Reader,
46 digest::{ExtendableOutput, XofReader},
47};
48use x25519_dalek::{PublicKey, StaticSecret};
49
50#[cfg(feature = "zeroize")]
51use zeroize::{Zeroize, ZeroizeOnDrop};
52
/// ML-KEM-768 private (decapsulation) component of an X-Wing key.
type MlKem768DecapsulationKey = ml_kem_768::DecapsulationKey;
/// ML-KEM-768 public (encapsulation) component of an X-Wing key.
type MlKem768EncapsulationKey = ml_kem_768::EncapsulationKey;
55
/// Domain-separation label hashed last by the combiner: the six ASCII bytes
/// `\.//^\` (0x5c 0x2e 0x2f 0x2f 0x5e 0x5c).
const X_WING_LABEL: &[u8; 6] = b"\\.//^\\";
57
/// Size in bytes of a serialized encapsulation key: the 1184-byte ML-KEM-768
/// encapsulation key followed by the 32-byte X25519 public key.
pub const ENCAPSULATION_KEY_SIZE: usize = 1184 + 32;

/// Size in bytes of a serialized decapsulation key (the 32-byte seed).
pub const DECAPSULATION_KEY_SIZE: usize = 32;

/// Size in bytes of a ciphertext: the 1088-byte ML-KEM-768 ciphertext followed
/// by the 32-byte X25519 ephemeral public key.
pub const CIPHERTEXT_SIZE: usize = 1088 + 32;

/// Bytes of randomness consumed by one encapsulation: 32 bytes for ML-KEM-768
/// plus 32 bytes for the X25519 ephemeral secret.
pub const ENCAPSULATION_RANDOMNESS_SIZE: usize = 32 + 32;
66
/// Serialized X-Wing ciphertext (1120 bytes, as fixed by `XWingKem::CiphertextSize`).
pub type Ciphertext = kem::Ciphertext<XWingKem>;
/// 32-byte shared key produced by encapsulation and recovered by decapsulation.
pub type SharedKey = Array<u8, U32>;
71
72#[derive(Clone, Copy, Debug, Default, Eq, PartialEq, PartialOrd, Ord)]
74pub struct XWingKem;
75
76impl Kem for XWingKem {
77 type DecapsulationKey = DecapsulationKey;
78 type EncapsulationKey = EncapsulationKey;
79 type CiphertextSize = U1120;
80 type SharedKeySize = U32;
81}
82
83#[derive(Clone, Debug, Eq, PartialEq)]
95pub struct EncapsulationKey {
96 pk_m: MlKem768EncapsulationKey,
97 pk_x: PublicKey,
98}
99
100impl EncapsulationKey {
101 #[doc(hidden)]
108 #[cfg_attr(not(feature = "hazmat"), doc(hidden))]
109 #[expect(clippy::must_use_candidate)]
110 pub fn encapsulate_deterministic(
111 &self,
112 randomness: &ArrayN<u8, ENCAPSULATION_RANDOMNESS_SIZE>,
113 ) -> (Ciphertext, SharedKey) {
114 let (rand_m, rand_x) = randomness.split::<U32>();
116
117 let (ct_m, ss_m) = self.pk_m.encapsulate_deterministic(&rand_m);
119
120 let ek_x = StaticSecret::from(rand_x.0);
121 let ct_x = PublicKey::from(&ek_x);
123 let ss_x = ek_x.diffie_hellman(&self.pk_x);
125
126 let ss = combiner(&ss_m, &ss_x, &ct_x, &self.pk_x);
127 let ct = CiphertextMessage { ct_m, ct_x };
128
129 (ct.into(), ss)
130 }
131}
132
133impl Encapsulate for EncapsulationKey {
134 type Kem = XWingKem;
135
136 fn encapsulate_with_rng<R>(&self, rng: &mut R) -> (Ciphertext, SharedKey)
137 where
138 R: CryptoRng + ?Sized,
139 {
140 #[allow(unused_mut)]
141 let mut randomness = Array::generate_from_rng(rng);
142 let res = self.encapsulate_deterministic(&randomness);
143
144 #[cfg(feature = "zeroize")]
145 randomness.zeroize();
146
147 res
148 }
149}
150
151impl KeySizeUser for EncapsulationKey {
152 type KeySize = U1216;
153}
154
155impl KeyExport for EncapsulationKey {
156 fn to_bytes(&self) -> Key<Self> {
157 let mut key_bytes = Key::<Self>::default();
158 let (m, x) = key_bytes.split_at_mut(1184);
159 m.copy_from_slice(&self.pk_m.to_bytes());
160 x.copy_from_slice(self.pk_x.as_bytes());
161 key_bytes
162 }
163}
164
165impl TryKeyInit for EncapsulationKey {
166 fn new(key_bytes: &Key<Self>) -> Result<Self, InvalidKey> {
167 let (m_bytes, x_bytes) = key_bytes.split_ref::<U1184>();
168
169 let pk_m = MlKem768EncapsulationKey::new(m_bytes)?;
170 let pk_x = PublicKey::from(x_bytes.0);
171
172 Ok(EncapsulationKey { pk_m, pk_x })
173 }
174}
175
176impl TryFrom<&[u8]> for EncapsulationKey {
177 type Error = InvalidKey;
178
179 fn try_from(key_bytes: &[u8]) -> Result<Self, InvalidKey> {
180 Self::new_from_slice(key_bytes)
181 }
182}
183
184#[derive(Clone)]
186pub struct DecapsulationKey {
187 sk: [u8; DECAPSULATION_KEY_SIZE],
188 ek: EncapsulationKey,
189}
190
191impl DecapsulationKey {
192 #[must_use]
194 pub fn as_bytes(&self) -> &[u8; DECAPSULATION_KEY_SIZE] {
195 &self.sk
196 }
197}
198
199impl Debug for DecapsulationKey {
200 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
201 f.debug_struct("DecapsulationKey")
202 .field("ek", &self.ek)
203 .finish_non_exhaustive()
204 }
205}
206
207impl Decapsulate for DecapsulationKey {
208 #[allow(clippy::similar_names)] fn decapsulate(&self, ct: &Ciphertext) -> SharedKey {
210 let ct = CiphertextMessage::from(ct);
211 let (sk_m, sk_x, _pk_m, pk_x) = expand_key(&self.sk);
212
213 let ss_m = sk_m.decapsulate(&ct.ct_m);
214
215 let ss_x = sk_x.diffie_hellman(&ct.ct_x);
217
218 combiner(&ss_m, &ss_x, &ct.ct_x, &pk_x)
219 }
220}
221
/// Exposes the public key paired with this decapsulation key.
impl Decapsulator for DecapsulationKey {
    type Kem = XWingKem;

    /// Return the encapsulation key cached in `self.ek` (computed from the
    /// seed when the key was built).
    fn encapsulation_key(&self) -> &EncapsulationKey {
        &self.ek
    }
}
229
230impl Drop for DecapsulationKey {
231 fn drop(&mut self) {
232 #[cfg(feature = "zeroize")]
233 self.sk.zeroize();
234 }
235}
236
237impl From<[u8; DECAPSULATION_KEY_SIZE]> for DecapsulationKey {
238 fn from(sk: [u8; DECAPSULATION_KEY_SIZE]) -> Self {
239 DecapsulationKey::new(sk.as_array_ref())
240 }
241}
242
243impl Generate for DecapsulationKey {
244 fn try_generate_from_rng<R>(rng: &mut R) -> Result<Self, <R as TryRng>::Error>
245 where
246 R: TryCryptoRng + ?Sized,
247 {
248 <[u8; DECAPSULATION_KEY_SIZE]>::try_generate_from_rng(rng).map(Into::into)
249 }
250}
251
252impl KeySizeUser for DecapsulationKey {
253 type KeySize = U32;
254}
255
256impl KeyInit for DecapsulationKey {
257 fn new(key: &Key<Self>) -> Self {
258 let (_sk_m, _sk_x, pk_m, pk_x) = expand_key(key.as_ref());
259 let ek = EncapsulationKey { pk_m, pk_x };
260 Self { sk: key.0, ek }
261 }
262}
263
264impl KeyExport for DecapsulationKey {
265 fn to_bytes(&self) -> Key<Self> {
266 self.sk.into()
267 }
268}
269
270#[cfg(feature = "zeroize")]
271impl ZeroizeOnDrop for DecapsulationKey {}
272
273fn expand_key(
274 sk: &[u8; DECAPSULATION_KEY_SIZE],
275) -> (
276 MlKem768DecapsulationKey,
277 StaticSecret,
278 MlKem768EncapsulationKey,
279 PublicKey,
280) {
281 use sha3::digest::Update;
282 let mut shaker = Shake256::default();
283 shaker.update(sk);
284 let mut expanded: Shake256Reader = shaker.finalize_xof();
285
286 let seed = read_from(&mut expanded).into();
287 let (sk_m, pk_m) = MlKem768::from_seed(&seed);
288
289 let sk_x = read_from(&mut expanded);
290 let sk_x = StaticSecret::from(sk_x);
291 let pk_x = PublicKey::from(&sk_x);
292
293 (sk_m, sk_x, pk_m, pk_x)
294}
295
296#[derive(Clone, Debug, PartialEq, Eq)]
298pub struct CiphertextMessage {
299 ct_m: ArrayN<u8, 1088>,
300 ct_x: PublicKey,
301}
302
303impl CiphertextMessage {
304 #[must_use]
307 pub fn to_bytes(&self) -> Ciphertext {
308 let mut buffer = Ciphertext::default();
309 buffer[0..1088].copy_from_slice(&self.ct_m);
310 buffer[1088..].copy_from_slice(self.ct_x.as_bytes());
311 buffer
312 }
313}
314
315impl From<&Ciphertext> for CiphertextMessage {
316 fn from(value: &Ciphertext) -> Self {
317 let mut ct_m = [0; 1088];
318 ct_m.copy_from_slice(&value[0..1088]);
319 let mut ct_x = [0; 32];
320 ct_x.copy_from_slice(&value[1088..]);
321
322 CiphertextMessage {
323 ct_m: ct_m.into(),
324 ct_x: ct_x.into(),
325 }
326 }
327}
328
329impl From<&CiphertextMessage> for Ciphertext {
330 #[inline]
331 fn from(msg: &CiphertextMessage) -> Self {
332 msg.to_bytes()
333 }
334}
335
336impl From<CiphertextMessage> for Ciphertext {
337 #[inline]
338 fn from(msg: CiphertextMessage) -> Self {
339 Self::from(&msg)
340 }
341}
342
343fn combiner(
344 ss_m: &ArrayN<u8, 32>,
345 ss_x: &x25519_dalek::SharedSecret,
346 ct_x: &PublicKey,
347 pk_x: &PublicKey,
348) -> SharedKey {
349 use sha3::Digest;
350
351 let mut hasher = Sha3_256::new();
352 hasher.update(ss_m);
353 hasher.update(ss_x);
354 hasher.update(ct_x);
355 hasher.update(pk_x.as_bytes());
356 hasher.update(X_WING_LABEL);
357 hasher.finalize()
358}
359
360fn read_from<const N: usize>(reader: &mut Shake256Reader) -> [u8; N] {
361 let mut data = [0; N];
362 reader.read(&mut data);
363 data
364}
365
#[cfg(test)]
mod tests {
    use crate::{Kem, XWingKem};
    use core::convert::Infallible;
    use getrandom::SysRng;
    use ml_kem::array::Array;
    use rand_core::{TryCryptoRng, TryRng, UnwrapErr, utils};
    use serde::Deserialize;

    use super::*;

    /// Replays a fixed byte string as RNG output, consuming it front to back.
    ///
    /// Lets a test vector's `seed`/`eseed` drive key generation and
    /// encapsulation deterministically. Panics if more bytes are requested
    /// than remain.
    pub(crate) struct SeedRng {
        pub(crate) seed: Vec<u8>,
    }

    impl SeedRng {
        fn new(seed: Vec<u8>) -> SeedRng {
            SeedRng { seed }
        }
    }

    impl TryRng for SeedRng {
        type Error = Infallible;

        fn try_next_u32(&mut self) -> Result<u32, Self::Error> {
            utils::next_word_via_fill(self)
        }

        fn try_next_u64(&mut self) -> Result<u64, Self::Error> {
            utils::next_word_via_fill(self)
        }

        fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Self::Error> {
            let wanted = dest.len();
            dest.copy_from_slice(&self.seed[..wanted]);
            self.seed.drain(..wanted);
            Ok(())
        }
    }

    // Marker impl. This is only acceptable in tests: the output is fully predictable.
    impl TryCryptoRng for SeedRng {}

    /// One entry of `test-vectors.json`; every field is hex-encoded.
    #[derive(Deserialize)]
    struct TestVector {
        #[serde(deserialize_with = "hex::serde::deserialize")]
        seed: Vec<u8>,
        #[serde(deserialize_with = "hex::serde::deserialize")]
        eseed: Vec<u8>,
        #[serde(deserialize_with = "hex::serde::deserialize")]
        ss: [u8; 32],
        #[serde(deserialize_with = "hex::serde::deserialize")]
        sk: [u8; 32],
        #[serde(deserialize_with = "hex::serde::deserialize")]
        pk: Vec<u8>,
        #[serde(deserialize_with = "hex::serde::deserialize")]
        ct: Vec<u8>,
    }

    #[test]
    fn rfc_test_vectors() {
        let vectors: Vec<TestVector> =
            serde_json::from_str(include_str!("test-vectors.json")).unwrap();
        vectors.into_iter().for_each(run_test);
    }

    /// Check key generation, encapsulation and decapsulation against one vector.
    fn run_test(vector: TestVector) {
        let mut keygen_rng = SeedRng::new(vector.seed);
        let (dk, ek) = XWingKem::generate_keypair_from_rng(&mut keygen_rng);

        assert_eq!(dk.as_bytes(), &vector.sk);
        assert_eq!(&*ek.to_bytes(), vector.pk.as_slice());

        let mut encap_rng = SeedRng::new(vector.eseed);
        let (ct, ss_sender) = ek.encapsulate_with_rng(&mut encap_rng);

        assert_eq!(ss_sender, vector.ss);
        assert_eq!(&*ct, vector.ct.as_slice());

        let ss_receiver = dk.decapsulate(&ct);
        assert_eq!(ss_receiver, vector.ss);
    }

    #[test]
    fn ciphertext_serialize() {
        let mut rng = UnwrapErr(SysRng);

        let original = CiphertextMessage {
            ct_m: Array::generate_from_rng(&mut rng),
            ct_x: <[u8; 32]>::generate_from_rng(&mut rng).into(),
        };

        let round_tripped = CiphertextMessage::from(&original.to_bytes());
        assert!(original == round_tripped);
    }

    #[test]
    #[cfg(feature = "getrandom")]
    fn key_serialize() {
        let (dk, ek) = XWingKem::generate_keypair();

        let dk_restored = DecapsulationKey::from(*dk.as_bytes());
        let ek_restored = EncapsulationKey::new(&ek.to_bytes()).unwrap();

        assert_eq!(dk.sk, dk_restored.sk);
        assert!(ek == ek_restored);
    }
}