twenty_first/math/lattice.rs

1use std::ops::Add;
2use std::ops::AddAssign;
3use std::ops::Mul;
4use std::ops::Sub;
5
6use itertools::Itertools;
7use num_traits::ConstZero;
8use num_traits::Zero;
9use rayon::prelude::IntoParallelIterator;
10use rayon::prelude::ParallelIterator;
11use serde_big_array::BigArray;
12use serde_derive::Deserialize;
13use serde_derive::Serialize;
14
15use super::b_field_element::BFieldElement;
16
17pub fn coset_intt_noswap_64(array: &mut [BFieldElement; 64]) {
18    const N: usize = 64;
19    const N_INV: BFieldElement = BFieldElement::new(18158513693329981441);
20    let powers_of_psi_inv_bitreversed = [
21        BFieldElement::new(1),
22        BFieldElement::new(18446462594437873665),
23        BFieldElement::new(18446742969902956801),
24        BFieldElement::new(18446744069397807105),
25        BFieldElement::new(18442240469788262401),
26        BFieldElement::new(18446744000695107585),
27        BFieldElement::new(17293822564807737345),
28        BFieldElement::new(18446744069414580225),
29        BFieldElement::new(18158513693329981441),
30        BFieldElement::new(18446739671368073217),
31        BFieldElement::new(18446744052234715141),
32        BFieldElement::new(18446744069414322177),
33        BFieldElement::new(18446673700670423041),
34        BFieldElement::new(18446744068340842497),
35        BFieldElement::new(18428729670905102337),
36        BFieldElement::new(18446744069414584257),
37        BFieldElement::new(16140901060737761281),
38        BFieldElement::new(18446708885042495489),
39        BFieldElement::new(18446743931975630881),
40        BFieldElement::new(18446744069412487169),
41        BFieldElement::new(18446181119461294081),
42        BFieldElement::new(18446744060824649729),
43        BFieldElement::new(18302628881338728449),
44        BFieldElement::new(18446744069414583809),
45        BFieldElement::new(18410715272404008961),
46        BFieldElement::new(18446743519658770433),
47        BFieldElement::new(9223372032559808513),
48        BFieldElement::new(18446744069414551553),
49        BFieldElement::new(18446735273321564161),
50        BFieldElement::new(18446744069280366593),
51        BFieldElement::new(18444492269600899073),
52        BFieldElement::new(18446744069414584313),
53        BFieldElement::new(274873712576),
54        BFieldElement::new(274882101184),
55        BFieldElement::new(4611756386097823744),
56        BFieldElement::new(13835128420805115905),
57        BFieldElement::new(288230376151710720),
58        BFieldElement::new(288230376151712768),
59        BFieldElement::new(1125917086449664),
60        BFieldElement::new(18445618186687873025),
61        BFieldElement::new(4294901759),
62        BFieldElement::new(4295032831),
63        BFieldElement::new(72058693532778496),
64        BFieldElement::new(18374687574905061377),
65        BFieldElement::new(4503599627370480),
66        BFieldElement::new(4503599627370512),
67        BFieldElement::new(17592454475776),
68        BFieldElement::new(18446726477496979457),
69        BFieldElement::new(34359214072),
70        BFieldElement::new(34360262648),
71        BFieldElement::new(576469548262227968),
72        BFieldElement::new(17870292113338400769),
73        BFieldElement::new(36028797018963840),
74        BFieldElement::new(36028797018964096),
75        BFieldElement::new(140739635806208),
76        BFieldElement::new(18446603334073745409),
77        BFieldElement::new(2305843009213685760),
78        BFieldElement::new(2305843009213702144),
79        BFieldElement::new(9007336691597312),
80        BFieldElement::new(18437737007600893953),
81        BFieldElement::new(562949953421310),
82        BFieldElement::new(562949953421314),
83        BFieldElement::new(2199056809472),
84        BFieldElement::new(18446741870424883713),
85    ];
86    const LOGN: usize = 6;
87
88    let mut t = 1;
89    let mut h = N / 2;
90    for _ in 0..LOGN {
91        let mut k = 0;
92        for i in 0..h {
93            let zeta = powers_of_psi_inv_bitreversed[h + i];
94            for j in k..(k + t) {
95                let u = array[j];
96                let v = array[j + t];
97                array[j] = u + v;
98                array[j + t] = (u - v) * zeta;
99            }
100
101            k += 2 * t;
102        }
103
104        t *= 2;
105        h >>= 1;
106    }
107
108    for a in array.iter_mut() {
109        *a *= N_INV;
110    }
111}
112
113pub fn coset_ntt_noswap_64(array: &mut [BFieldElement; 64]) {
114    const N: usize = 64;
115
116    let powers_of_psi_bitreversed = [
117        BFieldElement::new(1),
118        BFieldElement::new(281474976710656),
119        BFieldElement::new(16777216),
120        BFieldElement::new(1099511627520),
121        BFieldElement::new(4096),
122        BFieldElement::new(1152921504606846976),
123        BFieldElement::new(68719476736),
124        BFieldElement::new(4503599626321920),
125        BFieldElement::new(64),
126        BFieldElement::new(18014398509481984),
127        BFieldElement::new(1073741824),
128        BFieldElement::new(70368744161280),
129        BFieldElement::new(262144),
130        BFieldElement::new(17179869180),
131        BFieldElement::new(4398046511104),
132        BFieldElement::new(288230376084602880),
133        BFieldElement::new(8),
134        BFieldElement::new(2251799813685248),
135        BFieldElement::new(134217728),
136        BFieldElement::new(8796093020160),
137        BFieldElement::new(32768),
138        BFieldElement::new(9223372036854775808),
139        BFieldElement::new(549755813888),
140        BFieldElement::new(36028797010575360),
141        BFieldElement::new(512),
142        BFieldElement::new(144115188075855872),
143        BFieldElement::new(8589934592),
144        BFieldElement::new(562949953290240),
145        BFieldElement::new(2097152),
146        BFieldElement::new(137438953440),
147        BFieldElement::new(35184372088832),
148        BFieldElement::new(2305843008676823040),
149        BFieldElement::new(2198989700608),
150        BFieldElement::new(18446741870357774849),
151        BFieldElement::new(18446181119461163007),
152        BFieldElement::new(18446181119461163011),
153        BFieldElement::new(9007061813690368),
154        BFieldElement::new(18437736732722987009),
155        BFieldElement::new(16140901060200882177),
156        BFieldElement::new(16140901060200898561),
157        BFieldElement::new(140735340838912),
158        BFieldElement::new(18446603329778778113),
159        BFieldElement::new(18410715272395620225),
160        BFieldElement::new(18410715272395620481),
161        BFieldElement::new(576451956076183552),
162        BFieldElement::new(17870274521152356353),
163        BFieldElement::new(18446744035054321673),
164        BFieldElement::new(18446744035055370249),
165        BFieldElement::new(17591917604864),
166        BFieldElement::new(18446726476960108545),
167        BFieldElement::new(18442240469787213809),
168        BFieldElement::new(18442240469787213841),
169        BFieldElement::new(72056494509522944),
170        BFieldElement::new(18374685375881805825),
171        BFieldElement::new(18446744065119551490),
172        BFieldElement::new(18446744065119682562),
173        BFieldElement::new(1125882726711296),
174        BFieldElement::new(18445618152328134657),
175        BFieldElement::new(18158513693262871553),
176        BFieldElement::new(18158513693262873601),
177        BFieldElement::new(4611615648609468416),
178        BFieldElement::new(13834987683316760577),
179        BFieldElement::new(18446743794532483137),
180        BFieldElement::new(18446743794540871745),
181    ];
182
183    let mut m: usize = 1;
184    let mut t: usize = N;
185    while m < N {
186        t >>= 1;
187
188        for i in 0..m {
189            let s = i * t * 2;
190            let zeta = powers_of_psi_bitreversed[m + i];
191            for j in s..(s + t) {
192                let u = array[j];
193                let v = array[j + t] * zeta;
194                array[j] = u + v;
195                array[j + t] = u - v;
196            }
197        }
198
199        m *= 2;
200    }
201}
202
203pub const CYCLOTOMIC_RING_ELEMENT_SIZE_IN_BFES: usize = 64;
204
205#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
206pub struct CyclotomicRingElement {
207    #[serde(with = "BigArray")]
208    coefficients: [BFieldElement; CYCLOTOMIC_RING_ELEMENT_SIZE_IN_BFES],
209}
210
211impl From<[BFieldElement; CYCLOTOMIC_RING_ELEMENT_SIZE_IN_BFES]> for CyclotomicRingElement {
212    fn from(value: [BFieldElement; CYCLOTOMIC_RING_ELEMENT_SIZE_IN_BFES]) -> Self {
213        Self {
214            coefficients: value,
215        }
216    }
217}
218
219impl From<CyclotomicRingElement> for [BFieldElement; CYCLOTOMIC_RING_ELEMENT_SIZE_IN_BFES] {
220    fn from(value: CyclotomicRingElement) -> Self {
221        value.coefficients
222    }
223}
224
225impl CyclotomicRingElement {
226    pub fn sample_short(randomness: &[u8]) -> CyclotomicRingElement {
227        debug_assert!(randomness.len() >= 8 * 64);
228        CyclotomicRingElement {
229            coefficients: randomness
230                .chunks(8)
231                .map(|r| TryInto::<[u8; 8]>::try_into(r).unwrap())
232                .map(|r| sample_short_bfield_element(&r))
233                .collect_vec()
234                .try_into()
235                .unwrap(),
236        }
237    }
238
239    pub fn sample_uniform(randomness: &[u8]) -> CyclotomicRingElement {
240        debug_assert!(randomness.len() >= 9 * 64);
241        let mut coefficients = [BFieldElement::ZERO; 64];
242        for i in 0..64 {
243            let mut acc = 0u128;
244            for j in 0..9 {
245                acc = acc * 256 + randomness[i * 9 + j] as u128;
246            }
247            acc %= BFieldElement::P as u128;
248            coefficients[i] = BFieldElement::new(acc as u64);
249        }
250        CyclotomicRingElement { coefficients }
251    }
252
253    pub fn hadamard(a: CyclotomicRingElement, b: CyclotomicRingElement) -> CyclotomicRingElement {
254        let mut c = CyclotomicRingElement::zero();
255        for i in 0..64 {
256            c.coefficients[i] = a.coefficients[i] * b.coefficients[i];
257        }
258        c
259    }
260}
261
262impl Add for CyclotomicRingElement {
263    type Output = CyclotomicRingElement;
264
265    fn add(self, rhs: Self) -> Self::Output {
266        CyclotomicRingElement {
267            coefficients: (0..64)
268                .map(|i| self.coefficients[i] + rhs.coefficients[i])
269                .collect_vec()
270                .try_into()
271                .unwrap(),
272        }
273    }
274}
275
276impl AddAssign for CyclotomicRingElement {
277    fn add_assign(&mut self, rhs: Self) {
278        self.coefficients
279            .iter_mut()
280            .zip(rhs.coefficients.iter())
281            .for_each(|(l, r)| *l += *r);
282    }
283}
284
285impl Sub for CyclotomicRingElement {
286    type Output = CyclotomicRingElement;
287
288    fn sub(self, rhs: Self) -> Self::Output {
289        CyclotomicRingElement {
290            coefficients: (0..64)
291                .map(|i| self.coefficients[i] - rhs.coefficients[i])
292                .collect_vec()
293                .try_into()
294                .unwrap(),
295        }
296    }
297}
298
299impl Mul for CyclotomicRingElement {
300    type Output = CyclotomicRingElement;
301
302    /// Multiply two polynomials in the ring
303    /// `Fp[X] / (X^64 + 1)`
304    /// using `coset-NTT`.
305    fn mul(self, rhs: Self) -> Self::Output {
306        let mut lhs_coeffs = self.coefficients;
307        let mut rhs_coeffs = rhs.coefficients;
308        coset_ntt_noswap_64(&mut lhs_coeffs);
309        coset_ntt_noswap_64(&mut rhs_coeffs);
310        let mut out_coeffs = [BFieldElement::ZERO; 64];
311        for i in 0..64 {
312            out_coeffs[i] = lhs_coeffs[i] * rhs_coeffs[i];
313        }
314        coset_intt_noswap_64(&mut out_coeffs);
315        CyclotomicRingElement {
316            coefficients: out_coeffs,
317        }
318    }
319}
320
321impl Zero for CyclotomicRingElement {
322    fn zero() -> Self {
323        CyclotomicRingElement {
324            coefficients: [BFieldElement::ZERO; 64],
325        }
326    }
327
328    fn is_zero(&self) -> bool {
329        self.coefficients == [BFieldElement::ZERO; 64]
330    }
331}
332
333pub fn embed_msg(msg: [u8; 32]) -> CyclotomicRingElement {
334    let mut embedding: [BFieldElement; 64] = [BFieldElement::ZERO; 64];
335    for i in 0..msg.len() {
336        let mut integer = 0u64;
337        for j in 0..4 {
338            let bit = (msg[i] >> j) & 1;
339            integer += (bit as u64) << (15 + 16 * j);
340        }
341        embedding[2 * i] = BFieldElement::new(integer);
342
343        integer = 0;
344        for j in 0..4 {
345            let bit = (msg[i] >> (4 + j)) & 1;
346            integer += (bit as u64) << (15 + 16 * j);
347        }
348        embedding[2 * i + 1] = BFieldElement::new(integer);
349    }
350    CyclotomicRingElement {
351        coefficients: embedding,
352    }
353}
354
355pub fn extract_msg(embedding: CyclotomicRingElement) -> [u8; 32] {
356    let mut msg = [0u8; 32];
357    for (ctr, pair) in embedding.coefficients.chunks(2).enumerate() {
358        let mut byte = 0u8;
359        let mut value = pair[0].value();
360        for j in 0..4 {
361            let chunk = value & 0xffff;
362            value >>= 16;
363
364            let bit = if chunk < (1 << 14) || (1 << 16) - chunk < (1 << 14) {
365                0
366            } else {
367                1
368            };
369            byte |= bit << j;
370        }
371
372        value = pair[1].value();
373        for j in 0..4 {
374            let chunk = value & 0xffff;
375            value >>= 16;
376
377            let bit = if chunk < (1 << 14) || (1 << 16) - chunk < (1 << 14) {
378                0
379            } else {
380                1
381            };
382            byte |= bit << (4 + j);
383        }
384        msg[ctr] = byte;
385    }
386    msg
387}
388
/// Number of set bits (population count) of `a`; usable in const contexts.
const fn num_set_bits(a: u8) -> u8 {
    // `count_ones` is a `const fn` and returns at most 8, so the cast is lossless.
    a.count_ones() as u8
}
399
400const fn num_set_bits_table() -> [u8; 256] {
401    let mut table: [u8; 256] = [0u8; 256];
402    let mut i = 1;
403    while i < 256 {
404        table[i] = num_set_bits(i as u8);
405        i += 1;
406    }
407    table
408}
409
410pub fn sample_short_bfield_element(randomness: &[u8; 8]) -> BFieldElement {
411    const NUM_SET_BITS: [u8; 256] = num_set_bits_table();
412    let left = ((NUM_SET_BITS[randomness[0] as usize] as u64) << (3 * 16))
413        + ((NUM_SET_BITS[randomness[1] as usize] as u64) << (2 * 16))
414        + ((NUM_SET_BITS[randomness[2] as usize] as u64) << 16)
415        + (NUM_SET_BITS[randomness[3] as usize] as u64);
416    let right = ((NUM_SET_BITS[randomness[4] as usize] as u64) << (3 * 16))
417        + ((NUM_SET_BITS[randomness[5] as usize] as u64) << (2 * 16))
418        + ((NUM_SET_BITS[randomness[6] as usize] as u64) << 16)
419        + (NUM_SET_BITS[randomness[7] as usize] as u64);
420    BFieldElement::new(left) - BFieldElement::new(right)
421}
422
423/// The Module is a matrix over the cyclotomic ring (i.e., the ring
424/// of residue classes of polynomials modulo X^64+1). The matrix
425/// contains N cyclotomic ring elements in total.
426#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
427pub struct ModuleElement<const N: usize> {
428    #[serde(with = "BigArray")]
429    elements: [CyclotomicRingElement; N],
430}
431
432impl<const N: usize> ModuleElement<N> {
433    pub fn sample_short(randomness: &[u8]) -> Self {
434        debug_assert!(randomness.len() >= 8 * 64 * N);
435        let mut elements = [CyclotomicRingElement::zero(); N];
436        for n in 0..N {
437            elements[n] =
438                CyclotomicRingElement::sample_short(&randomness[8 * 64 * n..8 * 64 * (n + 1)]);
439        }
440        Self { elements }
441    }
442
443    pub fn sample_uniform(randomness: &[u8]) -> Self {
444        debug_assert!(randomness.len() >= N * 9 * 64);
445        ModuleElement {
446            elements: (0..N)
447                .map(|i| {
448                    CyclotomicRingElement::sample_uniform(&randomness[i * 9 * 64..(i + 1) * 9 * 64])
449                })
450                .collect_vec()
451                .try_into()
452                .unwrap(),
453        }
454    }
455
456    pub fn ntt(&self) -> Self {
457        let mut copy = *self;
458        for n in 0..N {
459            coset_ntt_noswap_64(&mut copy.elements[n].coefficients);
460        }
461        copy
462    }
463
464    pub fn intt(&self) -> Self {
465        let mut copy = *self;
466        for n in 0..N {
467            coset_intt_noswap_64(&mut copy.elements[n].coefficients);
468        }
469        copy
470    }
471
472    /// Multiply two module elements from a pair of matrix-
473    /// multiplication-compatible modules. This method uses
474    /// hadamard multiplication for cyclotomic ring elements, which
475    /// is useful for avoiding the repeated conversion to and from
476    /// NTT domain.
477    ///  - `N` counts the total number of elements in the matrix;
478    ///  - `H` counts the number of rows of the left hand side (and of
479    ///    the output) matrix;
480    ///  - `W` counts the number of columns of the right hand side (and
481    ///    of the output) matrix;
482    ///  - `INNER` counts the number of columns of the left hand side,
483    ///    as well as the number of rows of the right hand side.
484    pub fn multiply_hadamard<
485        const LHS_H: usize,
486        const LHS_N: usize,
487        const RHS_W: usize,
488        const RHS_N: usize,
489        const INNER: usize,
490        const OUT_N: usize,
491    >(
492        lhs: ModuleElement<LHS_N>,
493        rhs: ModuleElement<RHS_N>,
494    ) -> ModuleElement<OUT_N> {
495        debug_assert_eq!(LHS_H * INNER, LHS_N);
496        debug_assert_eq!(INNER * RHS_W, RHS_N);
497        debug_assert_eq!(LHS_H * RHS_W, OUT_N);
498
499        let mut elements = [CyclotomicRingElement::zero(); OUT_N];
500        for h in 0..LHS_H {
501            for w in 0..RHS_W {
502                for i in 0..INNER {
503                    elements[h * RHS_W + w] += CyclotomicRingElement::hadamard(
504                        lhs.elements[h * INNER + i],
505                        rhs.elements[i * RHS_W + w],
506                    );
507                }
508            }
509        }
510
511        ModuleElement { elements }
512    }
513
514    /// Multiply two module elements from a pair of matrix-
515    /// multiplication-compatible modules. This method uses the
516    /// multiplication defined for cyclotomic ring elements
517    /// abstractly. For a faster method that computes the entire
518    /// matrix multiplication in the NTT domain, use `fast_multiply`.
519    ///  - `N` counts the total number of elements in the matrix;
520    ///  - `H` counts the number of rows of the left hand side (and of
521    ///    the output) matrix;
522    ///  - `W` counts the number of columns of the right hand side (and
523    ///    of the output) matrix;
524    ///  - `INNER` counts the number of columns of the left hand side,
525    ///    as well as the number of rows of the right hand side.
526    pub fn multiply<
527        const LHS_H: usize,
528        const LHS_N: usize,
529        const RHS_W: usize,
530        const RHS_N: usize,
531        const INNER: usize,
532        const OUT_N: usize,
533    >(
534        lhs: ModuleElement<LHS_N>,
535        rhs: ModuleElement<RHS_N>,
536    ) -> ModuleElement<OUT_N> {
537        debug_assert_eq!(LHS_H * INNER, LHS_N);
538        debug_assert_eq!(INNER * RHS_W, RHS_N);
539        debug_assert_eq!(LHS_H * RHS_W, OUT_N);
540
541        let mut out = ModuleElement {
542            elements: [CyclotomicRingElement::zero(); OUT_N],
543        };
544        for h in 0..LHS_H {
545            for w in 0..RHS_W {
546                for i in 0..INNER {
547                    out.elements[h * RHS_W + w] +=
548                        lhs.elements[h * INNER + i] * rhs.elements[i * RHS_W + w];
549                }
550            }
551        }
552
553        out
554    }
555
556    /// Multiply two module elements from a pair of matrix-
557    /// multiplication-compatible modules, by converting everything
558    /// into the NTT domain, performing the matrix multiplication,
559    /// and converting back.
560    ///  - `N` counts the total number of elements in the matrix;
561    ///  - `H` counts the number of rows of the left hand side (and of
562    ///    the output) matrix;
563    ///  - `W` counts the number of columns of the right hand side (and
564    ///    of the output) matrix;
565    ///  - `INNER` counts the number of columns of the left hand side,
566    ///    as well as the number of rows of the right hand side.
567    pub fn fast_multiply<
568        const LHS_H: usize,
569        const LHS_N: usize,
570        const RHS_W: usize,
571        const RHS_N: usize,
572        const INNER: usize,
573        const OUT_N: usize,
574    >(
575        lhs: ModuleElement<LHS_N>,
576        rhs: ModuleElement<RHS_N>,
577    ) -> ModuleElement<OUT_N> {
578        debug_assert_eq!(LHS_H * INNER, LHS_N);
579        debug_assert_eq!(INNER * RHS_W, RHS_N);
580        debug_assert_eq!(LHS_H * RHS_W, OUT_N);
581
582        let lhs_ntt = lhs.ntt();
583        let rhs_ntt = rhs.ntt();
584
585        let out_ntt =
586            Self::multiply_hadamard::<LHS_H, LHS_N, RHS_W, RHS_N, INNER, OUT_N>(lhs_ntt, rhs_ntt);
587
588        out_ntt.intt()
589    }
590}
591
592impl<const N: usize> Add for ModuleElement<N> {
593    type Output = ModuleElement<N>;
594
595    fn add(self, rhs: Self) -> Self::Output {
596        let elements: [CyclotomicRingElement; N] = (0..N)
597            .into_par_iter()
598            .map(|i| self.elements[i] + rhs.elements[i])
599            .collect::<Vec<_>>()
600            .try_into()
601            .unwrap();
602        ModuleElement::<N> { elements }
603    }
604}
605
606impl<const N: usize> Sub for ModuleElement<N> {
607    type Output = ModuleElement<N>;
608
609    fn sub(self, rhs: Self) -> Self::Output {
610        let elements: [CyclotomicRingElement; N] = (0..N)
611            .into_par_iter()
612            .map(|i| self.elements[i] - rhs.elements[i])
613            .collect::<Vec<_>>()
614            .try_into()
615            .unwrap();
616        ModuleElement::<N> { elements }
617    }
618}
619
620impl<const N: usize> Zero for ModuleElement<N> {
621    fn zero() -> Self {
622        Self {
623            elements: [CyclotomicRingElement::zero(); N],
624        }
625    }
626
627    fn is_zero(&self) -> bool {
628        *self == Self::zero()
629    }
630}
631
632pub mod kem {
633    use itertools::Itertools;
634    use serde_derive::Deserialize;
635    use serde_derive::Serialize;
636    use sha3::Digest as Sha3Digest;
637    use sha3::Sha3_256;
638    use sha3::Shake256;
639    use sha3::digest::ExtendableOutput;
640    use sha3::digest::Update;
641    use zeroize::Zeroize;
642
643    use super::CYCLOTOMIC_RING_ELEMENT_SIZE_IN_BFES;
644    use super::CyclotomicRingElement;
645    use super::ModuleElement;
646    use super::embed_msg;
647    use super::extract_msg;
648    use crate::math::b_field_element::BFieldElement;
649
650    #[derive(PartialEq, Eq, Copy, Clone, Debug, Serialize, Deserialize, Zeroize)]
651    pub struct SecretKey {
652        key: [u8; 32],
653        seed: [u8; 32],
654    }
655
656    #[derive(PartialEq, Eq, Copy, Clone, Debug, Serialize, Deserialize)]
657    pub struct PublicKey {
658        seed: [u8; 32],
659        ga: ModuleElement<4>,
660    }
661
662    #[derive(PartialEq, Eq, Copy, Clone, Debug, Serialize, Deserialize)]
663    pub struct Ciphertext {
664        bg: ModuleElement<4>,
665        bga_m: ModuleElement<1>,
666    }
667
668    pub const CIPHERTEXT_SIZE_IN_BFES: usize = CYCLOTOMIC_RING_ELEMENT_SIZE_IN_BFES * 5;
669
670    impl From<[BFieldElement; CIPHERTEXT_SIZE_IN_BFES]> for Ciphertext {
671        fn from(value: [BFieldElement; CIPHERTEXT_SIZE_IN_BFES]) -> Self {
672            let (bg_slice, bga_m_slice) = value.split_at(4 * CYCLOTOMIC_RING_ELEMENT_SIZE_IN_BFES);
673
674            let bg_array: [BFieldElement; 4 * CYCLOTOMIC_RING_ELEMENT_SIZE_IN_BFES] =
675                bg_slice.try_into().unwrap();
676            let bga_m_array: [BFieldElement; CYCLOTOMIC_RING_ELEMENT_SIZE_IN_BFES] =
677                bga_m_slice.try_into().unwrap();
678
679            let bg_module = ModuleElement {
680                elements: bg_array
681                    .chunks(CYCLOTOMIC_RING_ELEMENT_SIZE_IN_BFES)
682                    .map(|sl| {
683                        CyclotomicRingElement::from(
684                            std::convert::TryInto::<
685                                [BFieldElement; CYCLOTOMIC_RING_ELEMENT_SIZE_IN_BFES],
686                            >::try_into(sl)
687                            .unwrap(),
688                        )
689                    })
690                    .collect_vec()
691                    .try_into()
692                    .unwrap(),
693            };
694            let bga_m_module = ModuleElement {
695                elements: [CyclotomicRingElement::from(bga_m_array); 1],
696            };
697
698            Self {
699                bg: bg_module,
700                bga_m: bga_m_module,
701            }
702        }
703    }
704
705    impl From<Ciphertext> for [BFieldElement; CIPHERTEXT_SIZE_IN_BFES] {
706        fn from(value: Ciphertext) -> Self {
707            let bg_slice = value
708                .bg
709                .elements
710                .iter()
711                .flat_map(|e| e.coefficients)
712                .collect_vec();
713            let bga_m_slice = value
714                .bga_m
715                .elements
716                .iter()
717                .flat_map(|e| e.coefficients)
718                .collect_vec();
719            [bg_slice, bga_m_slice].concat().try_into().unwrap()
720        }
721    }
722
723    /// randomness extension
724    pub(super) fn shake256<const NUM_OUT_BYTES: usize>(
725        randomness: impl AsRef<[u8]>,
726    ) -> [u8; NUM_OUT_BYTES] {
727        let mut hasher = Shake256::default();
728        hasher.update(randomness.as_ref());
729
730        let mut result = [0u8; NUM_OUT_BYTES];
731        hasher.finalize_xof_into(&mut result);
732        result
733    }
734
735    fn derive_public_matrix(seed: &[u8; 32]) -> ModuleElement<16> {
736        const NUM_BYTES: usize = 9 * 64 * 16;
737        let randomness = shake256::<NUM_BYTES>(seed);
738        ModuleElement::<16>::sample_uniform(&randomness)
739    }
740
741    fn derive_secret_vectors(seed: &[u8; 32]) -> (ModuleElement<4>, ModuleElement<4>) {
742        const NUM_BYTES: usize = 2 * 4 * 64 * 8;
743        let randomness = shake256::<NUM_BYTES>(seed);
744        let a = ModuleElement::<4>::sample_short(&randomness[0..(NUM_BYTES / 2)]);
745        let b = ModuleElement::<4>::sample_short(&randomness[(NUM_BYTES / 2)..]);
746        (a, b)
747    }
748
749    /// Generate a public-secret key pair for key encapsulation.
750    pub fn keygen(randomness: [u8; 32]) -> (SecretKey, PublicKey) {
751        const OUTPUT_LENGTH: usize = 32;
752        let seed: [u8; OUTPUT_LENGTH] = shake256([randomness.to_vec(), vec![0u8]].concat());
753        let key: [u8; OUTPUT_LENGTH] = shake256([randomness.to_vec(), vec![1u8]].concat());
754
755        let sk = SecretKey { key, seed };
756
757        let pk = derive_public_key(&key, &seed);
758        (sk, pk)
759    }
760
761    fn derive_public_key(key: &[u8; 32], seed: &[u8; 32]) -> PublicKey {
762        let (a, c) = derive_secret_vectors(key);
763        let g = derive_public_matrix(seed);
764        let ga = ModuleElement::<16>::multiply_hadamard::<4, 16, 1, 4, 4, 4>(g, a.ntt()) + c.ntt();
765
766        PublicKey { seed: *seed, ga }
767    }
768
769    /// Generate a ciphertext with the given seed (`payload`) from
770    /// which to derive all randomness.
771    fn generate_ciphertext_derandomized(pk: PublicKey, payload: [u8; 32]) -> Ciphertext {
772        let (b, d) = derive_secret_vectors(&payload);
773        let b_ntt = b.ntt();
774        let d_ntt = d.ntt();
775        let g = derive_public_matrix(&pk.seed);
776        let bg = ModuleElement::<9>::multiply_hadamard::<1, 4, 4, 16, 4, 4>(b_ntt, g) + d_ntt;
777
778        let m = embed_msg(payload);
779        let bga_m = ModuleElement::<3>::multiply_hadamard::<1, 4, 1, 4, 4, 1>(b_ntt, pk.ga)
780            + ModuleElement::<1> { elements: [m] }.ntt();
781
782        Ciphertext { bg, bga_m }
783    }
784
785    /// Encapsulate: generate a ciphertext and an associated shared
786    /// symmetric key.
787    pub fn enc(pk: PublicKey, randomness: [u8; 32]) -> ([u8; 32], Ciphertext) {
788        const OUTPUT_LENGTH: usize = 32;
789        let payload: [u8; OUTPUT_LENGTH] = shake256(randomness);
790        let ciphertext = generate_ciphertext_derandomized(pk, payload);
791        let shared_key: [u8; 32] = Sha3_256::digest(payload).into();
792
793        (shared_key, ciphertext)
794    }
795
796    /// Decapsulate: use the secret key to extract the corresponding
797    /// shared symmetric key from a ciphertext (if successful).
798    pub fn dec(sk: SecretKey, ctxt: Ciphertext) -> Option<[u8; 32]> {
799        let (a, _) = derive_secret_vectors(&sk.key);
800        let bga = ModuleElement::<3>::multiply_hadamard::<1, 4, 1, 4, 4, 1>(ctxt.bg, a.ntt());
801        let m = (ctxt.bga_m - bga).intt();
802        let payload = extract_msg(m.elements[0]);
803
804        let pk = derive_public_key(&sk.key, &sk.seed);
805        let regenerated_ciphertext = generate_ciphertext_derandomized(pk, payload);
806
807        if regenerated_ciphertext != ctxt {
808            return None;
809        }
810
811        let shared_key = Sha3_256::digest(payload).into();
812        Some(shared_key)
813    }
814
815    #[cfg(test)]
816    #[cfg_attr(coverage_nightly, coverage(off))]
817    mod tests {
818        use super::*;
819
820        #[test]
821        fn secret_key_zeroize_test() {
822            let mut secret_key = SecretKey {
823                key: rand::random(),
824                seed: rand::random(),
825            };
826
827            secret_key.zeroize();
828
829            assert_eq!([0; 32], secret_key.key);
830            assert_eq!([0; 32], secret_key.seed);
831        }
832    }
833}
834
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use itertools::Itertools;
    use num_traits::ConstOne;
    use num_traits::Zero;
    use rand::RngCore;
    use rand::random;
    use sha3::Digest as Sha3Digest;
    use sha3::Sha3_256;

    use super::kem::CIPHERTEXT_SIZE_IN_BFES;
    use super::kem::SecretKey;
    use super::kem::shake256;
    use crate::math::b_field_element::BFieldElement;
    use crate::math::lattice::kem::Ciphertext;
    use crate::math::lattice::kem::PublicKey;
    use crate::math::lattice::*;

    /// Checks `shake256` and `Sha3_256` against fixed known-answer test vectors.
    #[test]
    fn test_kats() {
        // KATs lifted from
        // https://github.com/XKCP/XKCP/blob/master/tests/UnitTests/main.c
        // starting at line 446.
        let input = b"\x21\xF1\x34\xAC\x57";
        let expected_output_shake256 = b"\xBB\x8A\x84\x47\x51\x7B\xA9\xCA\x7F\xA3\x4E\xC9\x9A\x80\
        \x00\x4F\x22\x8A\xB2\x82\x47\x28\x41\xEB\x3D\x3A\x76\x22\x5C\x9D\xBE\x77\xF7\xE4\x0A\x06\
        \x67\x76\xD3\x2C\x74\x94\x12\x02\xF9\xF4\xAA\x43\xD1\x2C\x62\x64\xAF\xA5\x96\x39\xC4\x4E\
        \x11\xF5\xE1\x4F\x1E\x56";
        let expected_output_sha3_256 = b"\x55\xBD\x92\x24\xAF\x4E\xED\x0D\x12\x11\x49\xE3\x7F\xF4\
        \xD7\xDD\x5B\xE2\x4B\xD9\xFB\xE5\x6E\x01\x71\xE8\x7D\xB7\xA6\xF4\xE0\x6D";

        assert_eq!(*expected_output_shake256, shake256(input));

        let sha3_out = Sha3_256::digest(input).to_vec();
        assert_eq!(expected_output_sha3_256, &*sha3_out);
    }

    /// Compares `CyclotomicRingElement` multiplication with a schoolbook product in which
    /// every term whose degree reaches 64 or more wraps around with a sign flip (i.e. the
    /// product is reduced modulo `X^64 + 1`).
    #[test]
    fn test_fast_mul() {
        let a: [BFieldElement; 64] = random();
        let b: [BFieldElement; 64] = random();

        let mut c_schoolbook = [BFieldElement::ZERO; 64];
        for i in 0..64 {
            for j in 0..64 {
                if i + j >= 64 {
                    c_schoolbook[i + j - 64] -= a[i] * b[j];
                } else {
                    c_schoolbook[i + j] += a[i] * b[j];
                }
            }
        }

        let c_fast = (CyclotomicRingElement { coefficients: a }
            * CyclotomicRingElement { coefficients: b })
        .coefficients;

        assert_eq!(c_fast, c_schoolbook);
    }

    /// Embedding a random 32-byte message and extracting it again must round-trip exactly.
    #[test]
    fn test_embedding() {
        let msg: [u8; 32] = random();
        let embedding = embed_msg(msg);
        let extracted = extract_msg(embedding);

        assert_eq!(msg, extracted);
    }

    /// Module multiplication must distribute over addition: `(a + b) * c == a * c + b * c`.
    #[test]
    fn test_module_distributivity() {
        let mut rng = rand::rng();
        // Exactly enough randomness for `a` (6 ring elements), `b` (6) and `c` (3).
        let randomness = (0..(2 * 3 + 2 * 3 + 3) * 64 * 9)
            .map(|_| (rng.next_u32() % 256) as u8)
            .collect_vec();
        let mut start = 0;
        let mut stop = 2 * 3 * 9 * 64;
        let a = ModuleElement::<{ 2 * 3 }>::sample_uniform(&randomness[start..stop]);
        start = stop;
        stop += 2 * 3 * 9 * 64;
        let b = ModuleElement::<{ 2 * 3 }>::sample_uniform(&randomness[start..stop]);
        start = stop;
        stop += 3 * 9 * 64;
        let c = ModuleElement::<3>::sample_uniform(&randomness[start..stop]);

        let sumprod = ModuleElement::<1>::multiply::<2, 6, 1, 3, 3, 2>(a + b, c);
        let prodsum = ModuleElement::<1>::multiply::<2, 6, 1, 3, 3, 2>(a, c)
            + ModuleElement::<1>::multiply::<2, 6, 1, 3, 3, 2>(b, c);

        assert_eq!(sumprod, prodsum);
    }

    /// `fast_multiply` must agree with the reference `multiply` on random inputs.
    #[test]
    fn test_module_multiply() {
        let mut rng = rand::rng();
        // Exactly enough randomness for the two 6-ring-element operands sampled below.
        let randomness = (0..(2 * 3 + 2 * 3) * 64 * 9)
            .map(|_| (rng.next_u32() % 256) as u8)
            .collect_vec();
        let mut start = 0;
        let mut stop = 2 * 3 * 9 * 64;
        let a = ModuleElement::<{ 2 * 3 }>::sample_uniform(&randomness[start..stop]);
        start = stop;
        stop += 2 * 3 * 9 * 64;
        let b = ModuleElement::<{ 2 * 3 }>::sample_uniform(&randomness[start..stop]);

        assert_eq!(
            ModuleElement::<1>::fast_multiply::<2, 6, 2, 6, 3, 4>(a, b),
            ModuleElement::<1>::multiply::<2, 6, 2, 6, 3, 4>(a, b)
        );
    }

    /// End-to-end KEM check. Decapsulating with the matching secret key must reproduce
    /// the encapsulated shared key. Decapsulating with an unrelated key must be rejected.
    #[test]
    fn test_kem() {
        let mut rng = rand::rng();
        let mut key_randomness: [u8; 32] = [0u8; 32];
        rng.fill_bytes(&mut key_randomness);
        let mut ctxt_randomness: [u8; 32] = [0u8; 32];
        rng.fill_bytes(&mut ctxt_randomness);

        // correctness
        let (sk, pk) = kem::keygen(key_randomness);
        let (alice_key, ctxt) = kem::enc(pk, ctxt_randomness);
        let bob_key =
            kem::dec(sk, ctxt).expect("decapsulation with the matching secret key must succeed");
        assert_eq!(alice_key, bob_key);

        // sanity
        rng.fill_bytes(&mut key_randomness);
        let (other_sk, _) = kem::keygen(key_randomness);
        assert!(kem::dec(other_sk, ctxt).is_none());
    }

    /// Converting a ciphertext to and from its `BFieldElement` array form must be lossless
    /// in both directions.
    #[test]
    fn test_ciphertext_conversion() {
        let bfes: [BFieldElement; CIPHERTEXT_SIZE_IN_BFES] = random();
        let ciphertext: Ciphertext = bfes.into();
        let bfes_again: [BFieldElement; CIPHERTEXT_SIZE_IN_BFES] = ciphertext.into();
        let ciphertext_again: Ciphertext = bfes_again.into();

        assert_eq!(bfes, bfes_again);
        assert_eq!(ciphertext, ciphertext_again);
    }

    /// `ModuleElement::zero()` must report `is_zero`. An element whose coefficients are
    /// all one must not.
    #[test]
    fn zero_test() {
        let zero_me = ModuleElement::<4>::zero();
        assert!(zero_me.is_zero(), "zero must be zero");
        let not_zero = ModuleElement {
            elements: [CyclotomicRingElement {
                coefficients: [BFieldElement::ONE; 64],
            }; 4],
        };
        assert!(!not_zero.is_zero(), "not-zero must not be zero");
    }

    /// Secret keys, public keys, ciphertexts and shared keys must all survive a JSON
    /// serialization round trip.
    #[test]
    fn serialization_deserialization_test() {
        // This is tested here since the serialization for these objects is a bit more complicated
        // than the standard serde stuff. So to be sure that it works, we just run this test here.
        let mut rng = rand::rng();
        let mut key_randomness: [u8; 32] = [0u8; 32];
        rng.fill_bytes(&mut key_randomness);
        let mut ctxt_randomness: [u8; 32] = [0u8; 32];
        rng.fill_bytes(&mut ctxt_randomness);
        let (sk, pk) = kem::keygen(key_randomness);
        let (alice_key, ctxt) = kem::enc(pk, ctxt_randomness);

        let sk_as_json: String = serde_json::to_string(&sk).unwrap();
        let sk_again = serde_json::from_str::<SecretKey>(&sk_as_json).unwrap();
        assert_eq!(sk, sk_again);

        let pk_as_json: String = serde_json::to_string(&pk).unwrap();
        let pk_again = serde_json::from_str::<PublicKey>(&pk_as_json).unwrap();
        assert_eq!(pk, pk_again);

        let ctxt_as_json: String = serde_json::to_string(&ctxt).unwrap();
        let ctxt_again = serde_json::from_str::<Ciphertext>(&ctxt_as_json).unwrap();
        assert_eq!(ctxt, ctxt_again);

        let alice_key_as_json: String = serde_json::to_string(&alice_key).unwrap();
        let alice_key_again = serde_json::from_str::<[u8; 32]>(&alice_key_as_json).unwrap();
        assert_eq!(alice_key, alice_key_again);
    }
}