x_wing/
lib.rs

1#![cfg_attr(not(test), no_std)]
2#![cfg_attr(docsrs, feature(doc_cfg))]
3#![doc = include_str!("../README.md")]
4#![doc(
5    html_logo_url = "https://raw.githubusercontent.com/RustCrypto/meta/master/logo.svg",
6    html_favicon_url = "https://raw.githubusercontent.com/RustCrypto/meta/master/logo.svg"
7)]
8#![deny(missing_docs)]
9#![warn(clippy::pedantic)]
10
11//! # Usage
12//!
//! This crate implements the X-Wing Key Encapsulation Mechanism (X-Wing-KEM) algorithm.
//! X-Wing-KEM is a KEM in the sense that it creates a (decapsulation key, encapsulation key) pair,
//! such that anyone can use the encapsulation key to establish a shared key with the holder of the
//! decapsulation key. X-Wing-KEM is a general-purpose hybrid post-quantum KEM, combining X25519 and
//! ML-KEM-768.
17#![cfg_attr(feature = "getrandom", doc = "```")]
18#![cfg_attr(not(feature = "getrandom"), doc = "```ignore")]
//! // NOTE: requires the `getrandom` feature to be enabled
20//! use kem::{Decapsulate, Encapsulate};
21//!
22//! let (sk, pk) = x_wing::generate_key_pair();
23//! let (ct, ss_sender) = pk.encapsulate().unwrap();
24//! let ss_receiver = sk.decapsulate(&ct).unwrap();
25//! assert_eq!(ss_sender, ss_receiver);
26//! ```
27
28pub use kem::{self, Decapsulate, Encapsulate, Generate};
29
30use core::convert::Infallible;
31use ml_kem::{
32    B32, EncodedSizeUser, Error, KemCore, MlKem768, MlKem768Params,
33    array::{ArrayN, typenum::consts::U32},
34};
35use rand_core::{CryptoRng, TryCryptoRng, TryRngCore};
36use sha3::{
37    Sha3_256, Shake256, Shake256Reader,
38    digest::{ExtendableOutput, XofReader},
39};
40use x25519_dalek::{EphemeralSecret, PublicKey, StaticSecret};
41
42#[cfg(feature = "zeroize")]
43use zeroize::{Zeroize, ZeroizeOnDrop};
44
// ML-KEM-768 decapsulation key type, as produced by `MlKem768::from_seed` in `expand_key`.
type MlKem768DecapsulationKey = ml_kem::kem::DecapsulationKey<MlKem768Params>;
// ML-KEM-768 encapsulation key type stored inside `EncapsulationKey`.
type MlKem768EncapsulationKey = ml_kem::kem::EncapsulationKey<MlKem768Params>;

// Six-byte label that `combiner` hashes last when deriving the shared secret.
// Written as a raw byte string so the backslashes are taken literally.
const X_WING_LABEL: &[u8; 6] = br"\.//^\";

/// Size in bytes of the `EncapsulationKey`.
///
/// This is the 1184-byte ML-KEM-768 public key followed by the 32-byte X25519 public key.
pub const ENCAPSULATION_KEY_SIZE: usize = 1216;
/// Size in bytes of the `DecapsulationKey`.
///
/// Only the seed is stored; the ML-KEM-768 and X25519 keys are expanded from it on demand.
pub const DECAPSULATION_KEY_SIZE: usize = 32;
/// Size in bytes of the `Ciphertext`.
///
/// This is the 1088-byte ML-KEM-768 ciphertext followed by the 32-byte X25519 ciphertext.
pub const CIPHERTEXT_SIZE: usize = 1120;

/// Shared secret key.
///
/// This is the 32-byte output of the X-Wing combiner.
pub type SharedSecret = [u8; 32];
59
60// The naming convention of variables matches the RFC.
61// ss -> Shared Secret
62// ct -> Cipher Text
63// ek -> Ephemeral Key
64// pk -> Public Key
65// sk -> Secret Key
// Suffixes:
// _m -> ML-KEM related key
// _x -> X25519 related key
69
/// X-Wing encapsulation or public key.
///
/// Bundles the ML-KEM-768 encapsulation key with the X25519 public key. Both halves are
/// derived from the same `DecapsulationKey`.
#[derive(Clone, PartialEq)]
pub struct EncapsulationKey {
    // ML-KEM-768 half of the public key.
    pk_m: MlKem768EncapsulationKey,
    // X25519 half of the public key.
    pk_x: PublicKey,
}
76
77impl Encapsulate<Ciphertext, SharedSecret> for EncapsulationKey {
78    type Error = Error;
79
80    fn encapsulate_with_rng<R: TryCryptoRng + ?Sized>(
81        &self,
82        rng: &mut R,
83    ) -> Result<(Ciphertext, SharedSecret), Self::Error> {
84        // Swapped order of operations compared to RFC, so that usage of the rng matches the RFC
85        let (ct_m, ss_m) = self.pk_m.encapsulate_with_rng(rng)?;
86
87        let ek_x = EphemeralSecret::random_from_rng(&mut rng.unwrap_mut());
88        // Equal to ct_x = x25519(ek_x, BASE_POINT)
89        let ct_x = PublicKey::from(&ek_x);
90        // Equal to ss_x = x25519(ek_x, pk_x)
91        let ss_x = ek_x.diffie_hellman(&self.pk_x);
92
93        let ss = combiner(&ss_m, &ss_x, &ct_x, &self.pk_x);
94        let ct = Ciphertext { ct_m, ct_x };
95        Ok((ct, ss))
96    }
97}
98
99impl EncapsulationKey {
100    /// Convert the key to the following format:
101    /// ML-KEM-768 public key(1184 bytes) || X25519 public key(32 bytes).
102    #[must_use]
103    pub fn to_bytes(&self) -> [u8; ENCAPSULATION_KEY_SIZE] {
104        let mut buffer = [0u8; ENCAPSULATION_KEY_SIZE];
105        buffer[0..1184].copy_from_slice(&self.pk_m.to_bytes());
106        buffer[1184..1216].copy_from_slice(self.pk_x.as_bytes());
107        buffer
108    }
109}
110
111impl TryFrom<&[u8; ENCAPSULATION_KEY_SIZE]> for EncapsulationKey {
112    type Error = ml_kem::Error;
113
114    fn try_from(value: &[u8; ENCAPSULATION_KEY_SIZE]) -> Result<Self, ml_kem::Error> {
115        let mut pk_m = [0; 1184];
116        pk_m.copy_from_slice(&value[0..1184]);
117        let pk_m = MlKem768EncapsulationKey::from_bytes(&pk_m.into())?;
118
119        let mut pk_x = [0; 32];
120        pk_x.copy_from_slice(&value[1184..]);
121        let pk_x = PublicKey::from(pk_x);
122        Ok(EncapsulationKey { pk_m, pk_x })
123    }
124}
125
126/// X-Wing decapsulation key or private key.
127#[derive(Clone)]
128#[cfg_attr(feature = "zeroize", derive(Zeroize, ZeroizeOnDrop))]
129#[cfg_attr(test, derive(PartialEq, Eq))]
130pub struct DecapsulationKey {
131    sk: [u8; DECAPSULATION_KEY_SIZE],
132}
133
134impl Decapsulate<Ciphertext, SharedSecret> for DecapsulationKey {
135    type Encapsulator = EncapsulationKey;
136    type Error = Infallible;
137
138    #[allow(clippy::similar_names)] // So we can use the names as in the RFC
139    fn decapsulate(&self, ct: &Ciphertext) -> Result<SharedSecret, Self::Error> {
140        let (sk_m, sk_x, _pk_m, pk_x) = self.expand_key();
141
142        let ss_m = sk_m.decapsulate(&ct.ct_m)?;
143
144        // equal to ss_x = x25519(sk_x, ct_x)
145        let ss_x = sk_x.diffie_hellman(&ct.ct_x);
146
147        let ss = combiner(&ss_m, &ss_x, &ct.ct_x, &pk_x);
148        Ok(ss)
149    }
150
151    fn encapsulator(&self) -> EncapsulationKey {
152        self.encapsulation_key()
153    }
154}
155
// The key handled by `KeyInit` is the 32-byte seed stored in `DecapsulationKey::sk`.
impl ::kem::KeySizeUser for DecapsulationKey {
    type KeySize = U32;
}
159
160impl ::kem::KeyInit for DecapsulationKey {
161    fn new(key: &ArrayN<u8, 32>) -> Self {
162        Self { sk: key.0 }
163    }
164}
165
166impl DecapsulationKey {
167    /// Provide the matching `EncapsulationKey`.
168    #[must_use]
169    pub fn encapsulation_key(&self) -> EncapsulationKey {
170        let (_sk_m, _sk_x, pk_m, pk_x) = self.expand_key();
171        EncapsulationKey { pk_m, pk_x }
172    }
173
174    fn expand_key(
175        &self,
176    ) -> (
177        MlKem768DecapsulationKey,
178        StaticSecret,
179        MlKem768EncapsulationKey,
180        PublicKey,
181    ) {
182        use sha3::digest::Update;
183        let mut shaker = Shake256::default();
184        shaker.update(&self.sk);
185        let mut expanded: Shake256Reader = shaker.finalize_xof();
186
187        let seed = read_from(&mut expanded).into();
188        let (sk_m, pk_m) = MlKem768::from_seed(seed);
189
190        let sk_x = read_from(&mut expanded);
191        let sk_x = StaticSecret::from(sk_x);
192        let pk_x = PublicKey::from(&sk_x);
193
194        (sk_m, sk_x, pk_m, pk_x)
195    }
196
197    /// Private key as bytes.
198    #[must_use]
199    pub fn as_bytes(&self) -> &[u8; DECAPSULATION_KEY_SIZE] {
200        &self.sk
201    }
202}
203
204impl Generate for DecapsulationKey {
205    fn try_generate_from_rng<R>(rng: &mut R) -> Result<Self, <R as TryRngCore>::Error>
206    where
207        R: TryCryptoRng + ?Sized,
208    {
209        <[u8; DECAPSULATION_KEY_SIZE]>::try_generate_from_rng(rng).map(Into::into)
210    }
211}
212
213impl From<[u8; DECAPSULATION_KEY_SIZE]> for DecapsulationKey {
214    fn from(sk: [u8; DECAPSULATION_KEY_SIZE]) -> Self {
215        DecapsulationKey { sk }
216    }
217}
218
219/// X-Wing ciphertext.
220#[derive(Clone, PartialEq, Eq)]
221#[cfg_attr(feature = "zeroize", derive(Zeroize, ZeroizeOnDrop))]
222pub struct Ciphertext {
223    ct_m: ArrayN<u8, 1088>,
224    ct_x: PublicKey,
225}
226
227impl Ciphertext {
228    /// Convert the ciphertext to the following format:
229    /// ML-KEM-768 ciphertext(1088 bytes) || X25519 ciphertext(32 bytes).
230    #[must_use]
231    pub fn to_bytes(&self) -> [u8; CIPHERTEXT_SIZE] {
232        let mut buffer = [0; CIPHERTEXT_SIZE];
233        buffer[0..1088].copy_from_slice(&self.ct_m);
234        buffer[1088..].copy_from_slice(self.ct_x.as_bytes());
235        buffer
236    }
237}
238
239impl From<&[u8; CIPHERTEXT_SIZE]> for Ciphertext {
240    fn from(value: &[u8; CIPHERTEXT_SIZE]) -> Self {
241        let mut ct_m = [0; 1088];
242        ct_m.copy_from_slice(&value[0..1088]);
243        let mut ct_x = [0; 32];
244        ct_x.copy_from_slice(&value[1088..]);
245
246        Ciphertext {
247            ct_m: ct_m.into(),
248            ct_x: ct_x.into(),
249        }
250    }
251}
252
/// Generate an X-Wing key pair.
///
/// The seed comes from `DecapsulationKey::generate`, so no RNG has to be supplied.
/// Only available when the `getrandom` feature is enabled.
#[cfg(feature = "getrandom")]
#[must_use]
pub fn generate_key_pair() -> (DecapsulationKey, EncapsulationKey) {
    let secret = DecapsulationKey::generate();
    let public = secret.encapsulation_key();
    (secret, public)
}
261
262/// Generate a X-Wing key pair using the provided rng.
263pub fn generate_key_pair_from_rng<R: CryptoRng + ?Sized>(
264    rng: &mut R,
265) -> (DecapsulationKey, EncapsulationKey) {
266    let sk = DecapsulationKey::generate_from_rng(rng);
267    let pk = sk.encapsulation_key();
268    (sk, pk)
269}
270
271fn combiner(
272    ss_m: &B32,
273    ss_x: &x25519_dalek::SharedSecret,
274    ct_x: &PublicKey,
275    pk_x: &PublicKey,
276) -> SharedSecret {
277    use sha3::Digest;
278
279    let mut hasher = Sha3_256::new();
280    hasher.update(ss_m);
281    hasher.update(ss_x);
282    hasher.update(ct_x);
283    hasher.update(pk_x.as_bytes());
284    hasher.update(X_WING_LABEL);
285    hasher.finalize().into()
286}
287
288fn read_from<const N: usize>(reader: &mut Shake256Reader) -> [u8; N] {
289    let mut data = [0; N];
290    reader.read(&mut data);
291    data
292}
293
#[cfg(test)]
mod tests {
    use getrandom::SysRng;
    use ml_kem::array::Array;
    use rand_core::{CryptoRng, RngCore, TryRngCore, utils};
    use serde::Deserialize;

    use super::*;

    /// Deterministic "RNG" that hands out the bytes of a fixed seed, front to back.
    pub(crate) struct SeedRng {
        pub(crate) seed: Vec<u8>,
    }

    impl SeedRng {
        fn new(seed: Vec<u8>) -> Self {
            Self { seed }
        }
    }

    impl RngCore for SeedRng {
        fn next_u32(&mut self) -> u32 {
            utils::next_word_via_fill(self)
        }

        fn next_u64(&mut self) -> u64 {
            utils::next_word_via_fill(self)
        }

        /// Fill `dest` from the front of the seed and consume those bytes.
        /// Panics when the seed holds fewer bytes than requested.
        fn fill_bytes(&mut self, dest: &mut [u8]) {
            let remaining = self.seed.split_off(dest.len());
            dest.copy_from_slice(&self.seed);
            self.seed = remaining;
        }
    }

    // Marker impl so `SeedRng` is accepted wherever a `CryptoRng` is required.
    impl CryptoRng for SeedRng {}

    /// One entry of `test-vectors.json`; every field is hex-encoded in the file.
    #[derive(Deserialize)]
    struct TestVector {
        /// Randomness consumed by key generation.
        #[serde(deserialize_with = "hex::serde::deserialize")]
        seed: Vec<u8>,

        /// Randomness consumed by encapsulation.
        #[serde(deserialize_with = "hex::serde::deserialize")]
        eseed: Vec<u8>,

        /// Expected shared secret.
        #[serde(deserialize_with = "hex::serde::deserialize")]
        ss: [u8; 32],

        /// Expected decapsulation key bytes.
        #[serde(deserialize_with = "hex::serde::deserialize")]
        sk: [u8; 32],

        /// Expected encapsulation key bytes (compared against `EncapsulationKey::to_bytes`).
        #[serde(deserialize_with = "hex::serde::deserialize")]
        pk: Vec<u8>,

        /// Expected ciphertext bytes (compared against `Ciphertext::to_bytes`).
        #[serde(deserialize_with = "hex::serde::deserialize")]
        ct: Vec<u8>,
    }

    /// Test with test vectors from: <https://github.com/dconnolly/draft-connolly-cfrg-xwing-kem/blob/main/spec/test-vectors.json>
    #[test]
    fn rfc_test_vectors() {
        let vectors: Vec<TestVector> =
            serde_json::from_str(include_str!("test-vectors.json")).unwrap();

        vectors.into_iter().for_each(run_test);
    }

    /// Check key generation, encapsulation and decapsulation against one test vector.
    fn run_test(test_vector: TestVector) {
        let mut keygen_rng = SeedRng::new(test_vector.seed);
        let (sk, pk) = generate_key_pair_from_rng(&mut keygen_rng);

        assert_eq!(sk.as_bytes(), &test_vector.sk);
        assert_eq!(&pk.to_bytes(), test_vector.pk.as_slice());

        let mut encaps_rng = SeedRng::new(test_vector.eseed);
        let (ct, ss_sender) = pk.encapsulate_with_rng(&mut encaps_rng).unwrap();

        assert_eq!(ss_sender, test_vector.ss);
        assert_eq!(&ct.to_bytes(), test_vector.ct.as_slice());

        let ss_receiver = sk.decapsulate(&ct).unwrap();
        assert_eq!(ss_receiver, test_vector.ss);
    }

    /// A ciphertext survives a `to_bytes` / `from` round trip unchanged.
    #[test]
    fn ciphertext_serialize() {
        let mut rng = SysRng.unwrap_err();

        let original = Ciphertext {
            ct_m: Array::generate_from_rng(&mut rng),
            ct_x: <[u8; 32]>::generate_from_rng(&mut rng).into(),
        };

        let restored = Ciphertext::from(&original.to_bytes());

        assert!(original == restored);
    }

    /// Both key types survive a round trip through their byte encodings.
    #[test]
    fn key_serialize() {
        let sk = DecapsulationKey::generate_from_rng(&mut SysRng.unwrap_err());
        let pk = sk.encapsulation_key();

        let sk_restored = DecapsulationKey::from(*sk.as_bytes());
        let pk_restored = EncapsulationKey::try_from(&pk.to_bytes()).unwrap();

        assert!(sk == sk_restored);
        assert!(pk == pk_restored);
    }
}
406        assert!(pk == pk_b);
407    }
408}