x_wing/
lib.rs

1#![cfg_attr(not(test), no_std)]
2#![cfg_attr(docsrs, feature(doc_cfg))]
3#![doc = include_str!("../README.md")]
4#![doc(
5    html_logo_url = "https://raw.githubusercontent.com/RustCrypto/meta/master/logo.svg",
6    html_favicon_url = "https://raw.githubusercontent.com/RustCrypto/meta/master/logo.svg"
7)]
8#![deny(missing_docs)]
9#![warn(clippy::pedantic)]
10
11//! # Usage
12//!
//! This crate implements the X-Wing Key Encapsulation Mechanism (X-Wing-KEM) algorithm.
//! X-Wing-KEM is a KEM in the sense that it creates a (decapsulation key, encapsulation key) pair,
//! such that anyone can use the encapsulation key to establish a shared key with the holder of the
//! decapsulation key. X-Wing-KEM is a general-purpose hybrid post-quantum KEM, combining X25519 and ML-KEM-768.
17#![cfg_attr(feature = "getrandom", doc = "```")]
18#![cfg_attr(not(feature = "getrandom"), doc = "```ignore")]
//! // NOTE: requires the `getrandom` feature to be enabled
20//! use kem::{Decapsulate, Encapsulate};
21//!
22//! let (sk, pk) = x_wing::generate_key_pair();
23//! let (ct, ss_sender) = pk.encapsulate();
24//! let ss_receiver = sk.decapsulate(&ct);
25//! assert_eq!(ss_sender, ss_receiver);
26//! ```
27
28pub use kem::{
29    self, Decapsulate, Decapsulator, Encapsulate, Generate, InvalidKey, KemParams, Key, KeyExport,
30    KeyInit, KeySizeUser, TryKeyInit,
31};
32
33use ml_kem::{
34    B32, EncodedSizeUser, KemCore, MlKem768, MlKem768Params,
35    array::{
36        Array, ArrayN, AsArrayRef,
37        sizes::{U32, U1120, U1216},
38    },
39};
40use rand_core::{CryptoRng, TryCryptoRng, TryRngCore};
41use sha3::{
42    Sha3_256, Shake256, Shake256Reader,
43    digest::{ExtendableOutput, XofReader},
44};
45use x25519_dalek::{EphemeralSecret, PublicKey, StaticSecret};
46
47#[cfg(feature = "zeroize")]
48use zeroize::{Zeroize, ZeroizeOnDrop};
49
50type MlKem768DecapsulationKey = ml_kem::kem::DecapsulationKey<MlKem768Params>;
51type MlKem768EncapsulationKey = ml_kem::kem::EncapsulationKey<MlKem768Params>;
52
53const X_WING_LABEL: &[u8; 6] = br"\.//^\";
54
55/// Size in bytes of the `EncapsulationKey`.
56pub const ENCAPSULATION_KEY_SIZE: usize = 1216;
57/// Size in bytes of the `DecapsulationKey`.
58pub const DECAPSULATION_KEY_SIZE: usize = 32;
59/// Size in bytes of the `Ciphertext`.
60pub const CIPHERTEXT_SIZE: usize = 1120;
61
62/// Serialized ciphertext.
63pub type Ciphertext = Array<u8, U1120>;
64/// Shared secret key.
65pub type SharedSecret = Array<u8, U32>;
66
// The naming convention of variables matches the RFC.
// ss -> Shared Secret
// ct -> Cipher Text
// ek -> Ephemeral Key
// pk -> Public Key
// sk -> Secret Key
// Suffixes:
// _m -> ML-KEM related key
// _x -> X25519 related key
76
77/// X-Wing encapsulation or public key.
78#[derive(Clone, Eq, PartialEq)]
79pub struct EncapsulationKey {
80    pk_m: MlKem768EncapsulationKey,
81    pk_x: PublicKey,
82}
83
84impl Encapsulate for EncapsulationKey {
85    fn encapsulate_with_rng<R: TryCryptoRng + ?Sized>(
86        &self,
87        rng: &mut R,
88    ) -> Result<(Ciphertext, SharedSecret), R::Error> {
89        // Swapped order of operations compared to RFC, so that usage of the rng matches the RFC
90        let (ct_m, ss_m) = self.pk_m.encapsulate_with_rng(rng)?;
91
92        let ek_x = EphemeralSecret::random_from_rng(&mut rng.unwrap_err());
93        // Equal to ct_x = x25519(ek_x, BASE_POINT)
94        let ct_x = PublicKey::from(&ek_x);
95        // Equal to ss_x = x25519(ek_x, pk_x)
96        let ss_x = ek_x.diffie_hellman(&self.pk_x);
97
98        let ss = combiner(&ss_m, &ss_x, &ct_x, &self.pk_x);
99        let ct = CiphertextMessage { ct_m, ct_x };
100        Ok((ct.into(), ss))
101    }
102}
103
104impl KemParams for EncapsulationKey {
105    type CiphertextSize = U1120;
106    type SharedSecretSize = U32;
107}
108
109impl KeySizeUser for EncapsulationKey {
110    type KeySize = U1216;
111}
112
113impl KeyExport for EncapsulationKey {
114    fn to_bytes(&self) -> Key<Self> {
115        let mut key_bytes = Key::<Self>::default();
116        let (m, x) = key_bytes.split_at_mut(1184);
117        m.copy_from_slice(&self.pk_m.to_encoded_bytes());
118        x.copy_from_slice(self.pk_x.as_bytes());
119        key_bytes
120    }
121}
122
123impl TryKeyInit for EncapsulationKey {
124    fn new(key_bytes: &Key<Self>) -> Result<Self, InvalidKey> {
125        let mut pk_m = [0; 1184];
126        pk_m.copy_from_slice(&key_bytes[0..1184]);
127        let pk_m =
128            MlKem768EncapsulationKey::from_encoded_bytes(&pk_m.into()).map_err(|_| InvalidKey)?;
129
130        let mut pk_x = [0; 32];
131        pk_x.copy_from_slice(&key_bytes[1184..]);
132        let pk_x = PublicKey::from(pk_x);
133        Ok(EncapsulationKey { pk_m, pk_x })
134    }
135}
136
137impl TryFrom<&[u8]> for EncapsulationKey {
138    type Error = InvalidKey;
139
140    fn try_from(key_bytes: &[u8]) -> Result<Self, InvalidKey> {
141        Self::new_from_slice(key_bytes)
142    }
143}
144
145/// X-Wing decapsulation key or private key.
146#[derive(Clone)]
147pub struct DecapsulationKey {
148    sk: [u8; DECAPSULATION_KEY_SIZE],
149    ek: EncapsulationKey,
150}
151
152impl DecapsulationKey {
153    /// Private key as bytes.
154    #[must_use]
155    pub fn as_bytes(&self) -> &[u8; DECAPSULATION_KEY_SIZE] {
156        &self.sk
157    }
158}
159
160impl Decapsulate for DecapsulationKey {
161    #[allow(clippy::similar_names)] // So we can use the names as in the RFC
162    fn decapsulate(&self, ct: &Ciphertext) -> SharedSecret {
163        let ct = CiphertextMessage::from(ct);
164        let (sk_m, sk_x, _pk_m, pk_x) = expand_key(&self.sk);
165
166        let ss_m = sk_m.decapsulate(&ct.ct_m);
167
168        // equal to ss_x = x25519(sk_x, ct_x)
169        let ss_x = sk_x.diffie_hellman(&ct.ct_x);
170
171        combiner(&ss_m, &ss_x, &ct.ct_x, &pk_x)
172    }
173}
174
175impl Decapsulator for DecapsulationKey {
176    type Encapsulator = EncapsulationKey;
177
178    fn encapsulator(&self) -> &EncapsulationKey {
179        &self.ek
180    }
181}
182
183impl Drop for DecapsulationKey {
184    fn drop(&mut self) {
185        #[cfg(feature = "zeroize")]
186        self.sk.zeroize();
187    }
188}
189
190impl From<[u8; DECAPSULATION_KEY_SIZE]> for DecapsulationKey {
191    fn from(sk: [u8; DECAPSULATION_KEY_SIZE]) -> Self {
192        DecapsulationKey::new(sk.as_array_ref())
193    }
194}
195
196impl Generate for DecapsulationKey {
197    fn try_generate_from_rng<R>(rng: &mut R) -> Result<Self, <R as TryRngCore>::Error>
198    where
199        R: TryCryptoRng + ?Sized,
200    {
201        <[u8; DECAPSULATION_KEY_SIZE]>::try_generate_from_rng(rng).map(Into::into)
202    }
203}
204
205impl KeySizeUser for DecapsulationKey {
206    type KeySize = U32;
207}
208
209impl KeyInit for DecapsulationKey {
210    fn new(key: &ArrayN<u8, 32>) -> Self {
211        let (_sk_m, _sk_x, pk_m, pk_x) = expand_key(key.as_ref());
212        let ek = EncapsulationKey { pk_m, pk_x };
213        Self { sk: key.0, ek }
214    }
215}
216
// Marker impl: the `Drop` impl zeroizes the secret seed when this feature is on.
#[cfg(feature = "zeroize")]
impl ZeroizeOnDrop for DecapsulationKey {}
219
220fn expand_key(
221    sk: &[u8; DECAPSULATION_KEY_SIZE],
222) -> (
223    MlKem768DecapsulationKey,
224    StaticSecret,
225    MlKem768EncapsulationKey,
226    PublicKey,
227) {
228    use sha3::digest::Update;
229    let mut shaker = Shake256::default();
230    shaker.update(sk);
231    let mut expanded: Shake256Reader = shaker.finalize_xof();
232
233    let seed = read_from(&mut expanded).into();
234    let (sk_m, pk_m) = MlKem768::from_seed(seed);
235
236    let sk_x = read_from(&mut expanded);
237    let sk_x = StaticSecret::from(sk_x);
238    let pk_x = PublicKey::from(&sk_x);
239
240    (sk_m, sk_x, pk_m, pk_x)
241}
242
243/// X-Wing ciphertext.
244#[derive(Clone, PartialEq, Eq)]
245pub struct CiphertextMessage {
246    ct_m: ArrayN<u8, 1088>,
247    ct_x: PublicKey,
248}
249
250impl CiphertextMessage {
251    /// Convert the ciphertext to the following format:
252    /// ML-KEM-768 ciphertext(1088 bytes) || X25519 ciphertext(32 bytes).
253    #[must_use]
254    pub fn to_bytes(&self) -> Ciphertext {
255        let mut buffer = Ciphertext::default();
256        buffer[0..1088].copy_from_slice(&self.ct_m);
257        buffer[1088..].copy_from_slice(self.ct_x.as_bytes());
258        buffer
259    }
260}
261
262impl From<&Ciphertext> for CiphertextMessage {
263    fn from(value: &Ciphertext) -> Self {
264        let mut ct_m = [0; 1088];
265        ct_m.copy_from_slice(&value[0..1088]);
266        let mut ct_x = [0; 32];
267        ct_x.copy_from_slice(&value[1088..]);
268
269        CiphertextMessage {
270            ct_m: ct_m.into(),
271            ct_x: ct_x.into(),
272        }
273    }
274}
275
276impl From<&CiphertextMessage> for Ciphertext {
277    #[inline]
278    fn from(msg: &CiphertextMessage) -> Self {
279        msg.to_bytes()
280    }
281}
282
283impl From<CiphertextMessage> for Ciphertext {
284    #[inline]
285    fn from(msg: CiphertextMessage) -> Self {
286        Self::from(&msg)
287    }
288}
289
/// Generate an X-Wing key pair using the operating system's random number generator.
///
/// Only available when the `getrandom` feature is enabled.
#[cfg(feature = "getrandom")]
#[must_use]
pub fn generate_key_pair() -> (DecapsulationKey, EncapsulationKey) {
    let decapsulation_key = DecapsulationKey::generate();
    let encapsulation_key = decapsulation_key.encapsulator().clone();
    (decapsulation_key, encapsulation_key)
}
298
299/// Generate a X-Wing key pair using the provided rng.
300pub fn generate_key_pair_from_rng<R: CryptoRng + ?Sized>(
301    rng: &mut R,
302) -> (DecapsulationKey, EncapsulationKey) {
303    let sk = DecapsulationKey::generate_from_rng(rng);
304    let pk = sk.encapsulator().clone();
305    (sk, pk)
306}
307
308fn combiner(
309    ss_m: &B32,
310    ss_x: &x25519_dalek::SharedSecret,
311    ct_x: &PublicKey,
312    pk_x: &PublicKey,
313) -> SharedSecret {
314    use sha3::Digest;
315
316    let mut hasher = Sha3_256::new();
317    hasher.update(ss_m);
318    hasher.update(ss_x);
319    hasher.update(ct_x);
320    hasher.update(pk_x.as_bytes());
321    hasher.update(X_WING_LABEL);
322    hasher.finalize()
323}
324
325fn read_from<const N: usize>(reader: &mut Shake256Reader) -> [u8; N] {
326    let mut data = [0; N];
327    reader.read(&mut data);
328    data
329}
330
#[cfg(test)]
mod tests {
    use core::convert::Infallible;
    use getrandom::SysRng;
    use ml_kem::array::Array;
    use rand_core::{TryCryptoRng, TryRngCore, utils};
    use serde::Deserialize;

    use super::*;

    /// Deterministic "rng" that replays a fixed byte string from front to back.
    ///
    /// Panics when more bytes are requested than remain in `seed`.
    pub(crate) struct SeedRng {
        pub(crate) seed: Vec<u8>,
    }

    impl SeedRng {
        fn new(seed: Vec<u8>) -> Self {
            Self { seed }
        }
    }

    impl TryRngCore for SeedRng {
        type Error = Infallible;

        fn try_next_u32(&mut self) -> Result<u32, Self::Error> {
            utils::next_word_via_fill(self)
        }

        fn try_next_u64(&mut self) -> Result<u64, Self::Error> {
            utils::next_word_via_fill(self)
        }

        fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Self::Error> {
            // Hand out the next `wanted` bytes, then discard them from the seed.
            let wanted = dest.len();
            dest.copy_from_slice(&self.seed[..wanted]);
            self.seed.drain(..wanted);
            Ok(())
        }
    }

    // Marker impl so that `SeedRng` is accepted where a crypto rng is required.
    impl TryCryptoRng for SeedRng {}

    /// One entry of the test-vector file; every field is hex-encoded in the JSON.
    #[derive(Deserialize)]
    struct TestVector {
        /// Randomness consumed by key generation.
        #[serde(deserialize_with = "hex::serde::deserialize")]
        seed: Vec<u8>,

        /// Randomness consumed by encapsulation.
        #[serde(deserialize_with = "hex::serde::deserialize")]
        eseed: Vec<u8>,

        /// Expected shared secret.
        #[serde(deserialize_with = "hex::serde::deserialize")]
        ss: [u8; 32],

        /// Expected decapsulation key.
        #[serde(deserialize_with = "hex::serde::deserialize")]
        sk: [u8; 32],

        /// Expected serialized encapsulation key.
        #[serde(deserialize_with = "hex::serde::deserialize")]
        pk: Vec<u8>,

        /// Expected serialized ciphertext.
        #[serde(deserialize_with = "hex::serde::deserialize")]
        ct: Vec<u8>,
    }

    /// Test with test vectors from: <https://github.com/dconnolly/draft-connolly-cfrg-xwing-kem/blob/main/spec/test-vectors.json>
    #[test]
    fn rfc_test_vectors() {
        let vectors: Vec<TestVector> =
            serde_json::from_str(include_str!("test-vectors.json")).unwrap();

        vectors.into_iter().for_each(run_test);
    }

    /// Check key generation, encapsulation and decapsulation against one vector.
    fn run_test(test_vector: TestVector) {
        let TestVector {
            seed,
            eseed,
            ss: want_ss,
            sk: want_sk,
            pk: want_pk,
            ct: want_ct,
        } = test_vector;

        // Key generation.
        let mut keygen_rng = SeedRng::new(seed);
        let (sk, pk) = generate_key_pair_from_rng(&mut keygen_rng);

        assert_eq!(sk.as_bytes(), &want_sk);
        assert_eq!(&*pk.to_bytes(), want_pk.as_slice());

        // Encapsulation.
        let mut encaps_rng = SeedRng::new(eseed);
        let (ct, ss) = pk.encapsulate_with_rng(&mut encaps_rng).unwrap();

        assert_eq!(ss, want_ss);
        assert_eq!(&*ct, want_ct.as_slice());

        // Decapsulation.
        let ss = sk.decapsulate(&ct);
        assert_eq!(ss, want_ss);
    }

    /// A random ciphertext message survives a serialize/parse round trip.
    #[test]
    fn ciphertext_serialize() {
        let mut rng = SysRng.unwrap_err();

        let ct_m = Array::generate_from_rng(&mut rng);
        let ct_x = <[u8; 32]>::generate_from_rng(&mut rng).into();
        let original = CiphertextMessage { ct_m, ct_x };

        let serialized = original.to_bytes();
        let parsed = CiphertextMessage::from(&serialized);

        assert!(original == parsed);
    }

    /// Both key types survive a serialize/parse round trip.
    #[test]
    fn key_serialize() {
        let sk = DecapsulationKey::generate_from_rng(&mut SysRng.unwrap_err());
        let pk = sk.encapsulator().clone();

        let sk_bytes = sk.as_bytes();
        let pk_bytes = pk.to_bytes();

        let sk_b = DecapsulationKey::from(*sk_bytes);
        let pk_b = EncapsulationKey::new(&pk_bytes).unwrap();

        assert_eq!(sk.sk, sk_b.sk);
        assert!(pk == pk_b);
    }
}