1use std::ops::Add;
2use std::ops::AddAssign;
3use std::ops::Mul;
4use std::ops::Sub;
5
6use itertools::Itertools;
7use num_traits::ConstZero;
8use num_traits::Zero;
9use rayon::prelude::IntoParallelIterator;
10use rayon::prelude::ParallelIterator;
11use serde_big_array::BigArray;
12use serde_derive::Deserialize;
13use serde_derive::Serialize;
14
15use super::b_field_element::BFieldElement;
16
17pub fn coset_intt_noswap_64(array: &mut [BFieldElement; 64]) {
18 const N: usize = 64;
19 const N_INV: BFieldElement = BFieldElement::new(18158513693329981441);
20 let powers_of_psi_inv_bitreversed = [
21 BFieldElement::new(1),
22 BFieldElement::new(18446462594437873665),
23 BFieldElement::new(18446742969902956801),
24 BFieldElement::new(18446744069397807105),
25 BFieldElement::new(18442240469788262401),
26 BFieldElement::new(18446744000695107585),
27 BFieldElement::new(17293822564807737345),
28 BFieldElement::new(18446744069414580225),
29 BFieldElement::new(18158513693329981441),
30 BFieldElement::new(18446739671368073217),
31 BFieldElement::new(18446744052234715141),
32 BFieldElement::new(18446744069414322177),
33 BFieldElement::new(18446673700670423041),
34 BFieldElement::new(18446744068340842497),
35 BFieldElement::new(18428729670905102337),
36 BFieldElement::new(18446744069414584257),
37 BFieldElement::new(16140901060737761281),
38 BFieldElement::new(18446708885042495489),
39 BFieldElement::new(18446743931975630881),
40 BFieldElement::new(18446744069412487169),
41 BFieldElement::new(18446181119461294081),
42 BFieldElement::new(18446744060824649729),
43 BFieldElement::new(18302628881338728449),
44 BFieldElement::new(18446744069414583809),
45 BFieldElement::new(18410715272404008961),
46 BFieldElement::new(18446743519658770433),
47 BFieldElement::new(9223372032559808513),
48 BFieldElement::new(18446744069414551553),
49 BFieldElement::new(18446735273321564161),
50 BFieldElement::new(18446744069280366593),
51 BFieldElement::new(18444492269600899073),
52 BFieldElement::new(18446744069414584313),
53 BFieldElement::new(274873712576),
54 BFieldElement::new(274882101184),
55 BFieldElement::new(4611756386097823744),
56 BFieldElement::new(13835128420805115905),
57 BFieldElement::new(288230376151710720),
58 BFieldElement::new(288230376151712768),
59 BFieldElement::new(1125917086449664),
60 BFieldElement::new(18445618186687873025),
61 BFieldElement::new(4294901759),
62 BFieldElement::new(4295032831),
63 BFieldElement::new(72058693532778496),
64 BFieldElement::new(18374687574905061377),
65 BFieldElement::new(4503599627370480),
66 BFieldElement::new(4503599627370512),
67 BFieldElement::new(17592454475776),
68 BFieldElement::new(18446726477496979457),
69 BFieldElement::new(34359214072),
70 BFieldElement::new(34360262648),
71 BFieldElement::new(576469548262227968),
72 BFieldElement::new(17870292113338400769),
73 BFieldElement::new(36028797018963840),
74 BFieldElement::new(36028797018964096),
75 BFieldElement::new(140739635806208),
76 BFieldElement::new(18446603334073745409),
77 BFieldElement::new(2305843009213685760),
78 BFieldElement::new(2305843009213702144),
79 BFieldElement::new(9007336691597312),
80 BFieldElement::new(18437737007600893953),
81 BFieldElement::new(562949953421310),
82 BFieldElement::new(562949953421314),
83 BFieldElement::new(2199056809472),
84 BFieldElement::new(18446741870424883713),
85 ];
86 const LOGN: usize = 6;
87
88 let mut t = 1;
89 let mut h = N / 2;
90 for _ in 0..LOGN {
91 let mut k = 0;
92 for i in 0..h {
93 let zeta = powers_of_psi_inv_bitreversed[h + i];
94 for j in k..(k + t) {
95 let u = array[j];
96 let v = array[j + t];
97 array[j] = u + v;
98 array[j + t] = (u - v) * zeta;
99 }
100
101 k += 2 * t;
102 }
103
104 t *= 2;
105 h >>= 1;
106 }
107
108 for a in array.iter_mut() {
109 *a *= N_INV;
110 }
111}
112
113pub fn coset_ntt_noswap_64(array: &mut [BFieldElement; 64]) {
114 const N: usize = 64;
115
116 let powers_of_psi_bitreversed = [
117 BFieldElement::new(1),
118 BFieldElement::new(281474976710656),
119 BFieldElement::new(16777216),
120 BFieldElement::new(1099511627520),
121 BFieldElement::new(4096),
122 BFieldElement::new(1152921504606846976),
123 BFieldElement::new(68719476736),
124 BFieldElement::new(4503599626321920),
125 BFieldElement::new(64),
126 BFieldElement::new(18014398509481984),
127 BFieldElement::new(1073741824),
128 BFieldElement::new(70368744161280),
129 BFieldElement::new(262144),
130 BFieldElement::new(17179869180),
131 BFieldElement::new(4398046511104),
132 BFieldElement::new(288230376084602880),
133 BFieldElement::new(8),
134 BFieldElement::new(2251799813685248),
135 BFieldElement::new(134217728),
136 BFieldElement::new(8796093020160),
137 BFieldElement::new(32768),
138 BFieldElement::new(9223372036854775808),
139 BFieldElement::new(549755813888),
140 BFieldElement::new(36028797010575360),
141 BFieldElement::new(512),
142 BFieldElement::new(144115188075855872),
143 BFieldElement::new(8589934592),
144 BFieldElement::new(562949953290240),
145 BFieldElement::new(2097152),
146 BFieldElement::new(137438953440),
147 BFieldElement::new(35184372088832),
148 BFieldElement::new(2305843008676823040),
149 BFieldElement::new(2198989700608),
150 BFieldElement::new(18446741870357774849),
151 BFieldElement::new(18446181119461163007),
152 BFieldElement::new(18446181119461163011),
153 BFieldElement::new(9007061813690368),
154 BFieldElement::new(18437736732722987009),
155 BFieldElement::new(16140901060200882177),
156 BFieldElement::new(16140901060200898561),
157 BFieldElement::new(140735340838912),
158 BFieldElement::new(18446603329778778113),
159 BFieldElement::new(18410715272395620225),
160 BFieldElement::new(18410715272395620481),
161 BFieldElement::new(576451956076183552),
162 BFieldElement::new(17870274521152356353),
163 BFieldElement::new(18446744035054321673),
164 BFieldElement::new(18446744035055370249),
165 BFieldElement::new(17591917604864),
166 BFieldElement::new(18446726476960108545),
167 BFieldElement::new(18442240469787213809),
168 BFieldElement::new(18442240469787213841),
169 BFieldElement::new(72056494509522944),
170 BFieldElement::new(18374685375881805825),
171 BFieldElement::new(18446744065119551490),
172 BFieldElement::new(18446744065119682562),
173 BFieldElement::new(1125882726711296),
174 BFieldElement::new(18445618152328134657),
175 BFieldElement::new(18158513693262871553),
176 BFieldElement::new(18158513693262873601),
177 BFieldElement::new(4611615648609468416),
178 BFieldElement::new(13834987683316760577),
179 BFieldElement::new(18446743794532483137),
180 BFieldElement::new(18446743794540871745),
181 ];
182
183 let mut m: usize = 1;
184 let mut t: usize = N;
185 while m < N {
186 t >>= 1;
187
188 for i in 0..m {
189 let s = i * t * 2;
190 let zeta = powers_of_psi_bitreversed[m + i];
191 for j in s..(s + t) {
192 let u = array[j];
193 let v = array[j + t] * zeta;
194 array[j] = u + v;
195 array[j + t] = u - v;
196 }
197 }
198
199 m *= 2;
200 }
201}
202
/// Number of [`BFieldElement`] coefficients stored in one
/// [`CyclotomicRingElement`].
pub const CYCLOTOMIC_RING_ELEMENT_SIZE_IN_BFES: usize = 64;
204
/// A ring element stored as a fixed-size array of
/// [`CYCLOTOMIC_RING_ELEMENT_SIZE_IN_BFES`] field-element coefficients.
///
/// Addition and subtraction act coefficient-wise; multiplication goes through
/// [`coset_ntt_noswap_64`] and [`coset_intt_noswap_64`].
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct CyclotomicRingElement {
    // serde has no built-in support for arrays of this length, hence the
    // `BigArray` adapter.
    #[serde(with = "BigArray")]
    coefficients: [BFieldElement; CYCLOTOMIC_RING_ELEMENT_SIZE_IN_BFES],
}
210
211impl From<[BFieldElement; CYCLOTOMIC_RING_ELEMENT_SIZE_IN_BFES]> for CyclotomicRingElement {
212 fn from(value: [BFieldElement; CYCLOTOMIC_RING_ELEMENT_SIZE_IN_BFES]) -> Self {
213 Self {
214 coefficients: value,
215 }
216 }
217}
218
219impl From<CyclotomicRingElement> for [BFieldElement; CYCLOTOMIC_RING_ELEMENT_SIZE_IN_BFES] {
220 fn from(value: CyclotomicRingElement) -> Self {
221 value.coefficients
222 }
223}
224
225impl CyclotomicRingElement {
226 pub fn sample_short(randomness: &[u8]) -> CyclotomicRingElement {
227 debug_assert!(randomness.len() >= 8 * 64);
228 CyclotomicRingElement {
229 coefficients: randomness
230 .chunks(8)
231 .map(|r| TryInto::<[u8; 8]>::try_into(r).unwrap())
232 .map(|r| sample_short_bfield_element(&r))
233 .collect_vec()
234 .try_into()
235 .unwrap(),
236 }
237 }
238
239 pub fn sample_uniform(randomness: &[u8]) -> CyclotomicRingElement {
240 debug_assert!(randomness.len() >= 9 * 64);
241 let mut coefficients = [BFieldElement::ZERO; 64];
242 for i in 0..64 {
243 let mut acc = 0u128;
244 for j in 0..9 {
245 acc = acc * 256 + randomness[i * 9 + j] as u128;
246 }
247 acc %= BFieldElement::P as u128;
248 coefficients[i] = BFieldElement::new(acc as u64);
249 }
250 CyclotomicRingElement { coefficients }
251 }
252
253 pub fn hadamard(a: CyclotomicRingElement, b: CyclotomicRingElement) -> CyclotomicRingElement {
254 let mut c = CyclotomicRingElement::zero();
255 for i in 0..64 {
256 c.coefficients[i] = a.coefficients[i] * b.coefficients[i];
257 }
258 c
259 }
260}
261
262impl Add for CyclotomicRingElement {
263 type Output = CyclotomicRingElement;
264
265 fn add(self, rhs: Self) -> Self::Output {
266 CyclotomicRingElement {
267 coefficients: (0..64)
268 .map(|i| self.coefficients[i] + rhs.coefficients[i])
269 .collect_vec()
270 .try_into()
271 .unwrap(),
272 }
273 }
274}
275
276impl AddAssign for CyclotomicRingElement {
277 fn add_assign(&mut self, rhs: Self) {
278 self.coefficients
279 .iter_mut()
280 .zip(rhs.coefficients.iter())
281 .for_each(|(l, r)| *l += *r);
282 }
283}
284
285impl Sub for CyclotomicRingElement {
286 type Output = CyclotomicRingElement;
287
288 fn sub(self, rhs: Self) -> Self::Output {
289 CyclotomicRingElement {
290 coefficients: (0..64)
291 .map(|i| self.coefficients[i] - rhs.coefficients[i])
292 .collect_vec()
293 .try_into()
294 .unwrap(),
295 }
296 }
297}
298
299impl Mul for CyclotomicRingElement {
300 type Output = CyclotomicRingElement;
301
302 fn mul(self, rhs: Self) -> Self::Output {
306 let mut lhs_coeffs = self.coefficients;
307 let mut rhs_coeffs = rhs.coefficients;
308 coset_ntt_noswap_64(&mut lhs_coeffs);
309 coset_ntt_noswap_64(&mut rhs_coeffs);
310 let mut out_coeffs = [BFieldElement::ZERO; 64];
311 for i in 0..64 {
312 out_coeffs[i] = lhs_coeffs[i] * rhs_coeffs[i];
313 }
314 coset_intt_noswap_64(&mut out_coeffs);
315 CyclotomicRingElement {
316 coefficients: out_coeffs,
317 }
318 }
319}
320
321impl Zero for CyclotomicRingElement {
322 fn zero() -> Self {
323 CyclotomicRingElement {
324 coefficients: [BFieldElement::ZERO; 64],
325 }
326 }
327
328 fn is_zero(&self) -> bool {
329 self.coefficients == [BFieldElement::ZERO; 64]
330 }
331}
332
333pub fn embed_msg(msg: [u8; 32]) -> CyclotomicRingElement {
334 let mut embedding: [BFieldElement; 64] = [BFieldElement::ZERO; 64];
335 for i in 0..msg.len() {
336 let mut integer = 0u64;
337 for j in 0..4 {
338 let bit = (msg[i] >> j) & 1;
339 integer += (bit as u64) << (15 + 16 * j);
340 }
341 embedding[2 * i] = BFieldElement::new(integer);
342
343 integer = 0;
344 for j in 0..4 {
345 let bit = (msg[i] >> (4 + j)) & 1;
346 integer += (bit as u64) << (15 + 16 * j);
347 }
348 embedding[2 * i + 1] = BFieldElement::new(integer);
349 }
350 CyclotomicRingElement {
351 coefficients: embedding,
352 }
353}
354
355pub fn extract_msg(embedding: CyclotomicRingElement) -> [u8; 32] {
356 let mut msg = [0u8; 32];
357 for (ctr, pair) in embedding.coefficients.chunks(2).enumerate() {
358 let mut byte = 0u8;
359 let mut value = pair[0].value();
360 for j in 0..4 {
361 let chunk = value & 0xffff;
362 value >>= 16;
363
364 let bit = if chunk < (1 << 14) || (1 << 16) - chunk < (1 << 14) {
365 0
366 } else {
367 1
368 };
369 byte |= bit << j;
370 }
371
372 value = pair[1].value();
373 for j in 0..4 {
374 let chunk = value & 0xffff;
375 value >>= 16;
376
377 let bit = if chunk < (1 << 14) || (1 << 16) - chunk < (1 << 14) {
378 0
379 } else {
380 1
381 };
382 byte |= bit << (4 + j);
383 }
384 msg[ctr] = byte;
385 }
386 msg
387}
388
/// Returns the number of bits set in `a` (its population count).
///
/// Delegates to the standard library's `u8::count_ones`, which is usable in
/// `const` contexts, instead of looping over the bits by hand.
const fn num_set_bits(a: u8) -> u8 {
    // The count is at most 8, so narrowing from `u32` cannot truncate.
    a.count_ones() as u8
}
399
400const fn num_set_bits_table() -> [u8; 256] {
401 let mut table: [u8; 256] = [0u8; 256];
402 let mut i = 1;
403 while i < 256 {
404 table[i] = num_set_bits(i as u8);
405 i += 1;
406 }
407 table
408}
409
410pub fn sample_short_bfield_element(randomness: &[u8; 8]) -> BFieldElement {
411 const NUM_SET_BITS: [u8; 256] = num_set_bits_table();
412 let left = ((NUM_SET_BITS[randomness[0] as usize] as u64) << (3 * 16))
413 + ((NUM_SET_BITS[randomness[1] as usize] as u64) << (2 * 16))
414 + ((NUM_SET_BITS[randomness[2] as usize] as u64) << 16)
415 + (NUM_SET_BITS[randomness[3] as usize] as u64);
416 let right = ((NUM_SET_BITS[randomness[4] as usize] as u64) << (3 * 16))
417 + ((NUM_SET_BITS[randomness[5] as usize] as u64) << (2 * 16))
418 + ((NUM_SET_BITS[randomness[6] as usize] as u64) << 16)
419 + (NUM_SET_BITS[randomness[7] as usize] as u64);
420 BFieldElement::new(left) - BFieldElement::new(right)
421}
422
/// A flat array of `N` [`CyclotomicRingElement`]s.
///
/// The matrix-style operations defined on this type interpret the array as a
/// row-major matrix whose dimensions are supplied as const parameters.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct ModuleElement<const N: usize> {
    // serde cannot (de)serialize arrays of arbitrary const length on its own,
    // hence the `BigArray` adapter.
    #[serde(with = "BigArray")]
    elements: [CyclotomicRingElement; N],
}
431
432impl<const N: usize> ModuleElement<N> {
433 pub fn sample_short(randomness: &[u8]) -> Self {
434 debug_assert!(randomness.len() >= 8 * 64 * N);
435 let mut elements = [CyclotomicRingElement::zero(); N];
436 for n in 0..N {
437 elements[n] =
438 CyclotomicRingElement::sample_short(&randomness[8 * 64 * n..8 * 64 * (n + 1)]);
439 }
440 Self { elements }
441 }
442
443 pub fn sample_uniform(randomness: &[u8]) -> Self {
444 debug_assert!(randomness.len() >= N * 9 * 64);
445 ModuleElement {
446 elements: (0..N)
447 .map(|i| {
448 CyclotomicRingElement::sample_uniform(&randomness[i * 9 * 64..(i + 1) * 9 * 64])
449 })
450 .collect_vec()
451 .try_into()
452 .unwrap(),
453 }
454 }
455
456 pub fn ntt(&self) -> Self {
457 let mut copy = *self;
458 for n in 0..N {
459 coset_ntt_noswap_64(&mut copy.elements[n].coefficients);
460 }
461 copy
462 }
463
464 pub fn intt(&self) -> Self {
465 let mut copy = *self;
466 for n in 0..N {
467 coset_intt_noswap_64(&mut copy.elements[n].coefficients);
468 }
469 copy
470 }
471
472 pub fn multiply_hadamard<
485 const LHS_H: usize,
486 const LHS_N: usize,
487 const RHS_W: usize,
488 const RHS_N: usize,
489 const INNER: usize,
490 const OUT_N: usize,
491 >(
492 lhs: ModuleElement<LHS_N>,
493 rhs: ModuleElement<RHS_N>,
494 ) -> ModuleElement<OUT_N> {
495 debug_assert_eq!(LHS_H * INNER, LHS_N);
496 debug_assert_eq!(INNER * RHS_W, RHS_N);
497 debug_assert_eq!(LHS_H * RHS_W, OUT_N);
498
499 let mut elements = [CyclotomicRingElement::zero(); OUT_N];
500 for h in 0..LHS_H {
501 for w in 0..RHS_W {
502 for i in 0..INNER {
503 elements[h * RHS_W + w] += CyclotomicRingElement::hadamard(
504 lhs.elements[h * INNER + i],
505 rhs.elements[i * RHS_W + w],
506 );
507 }
508 }
509 }
510
511 ModuleElement { elements }
512 }
513
514 pub fn multiply<
527 const LHS_H: usize,
528 const LHS_N: usize,
529 const RHS_W: usize,
530 const RHS_N: usize,
531 const INNER: usize,
532 const OUT_N: usize,
533 >(
534 lhs: ModuleElement<LHS_N>,
535 rhs: ModuleElement<RHS_N>,
536 ) -> ModuleElement<OUT_N> {
537 debug_assert_eq!(LHS_H * INNER, LHS_N);
538 debug_assert_eq!(INNER * RHS_W, RHS_N);
539 debug_assert_eq!(LHS_H * RHS_W, OUT_N);
540
541 let mut out = ModuleElement {
542 elements: [CyclotomicRingElement::zero(); OUT_N],
543 };
544 for h in 0..LHS_H {
545 for w in 0..RHS_W {
546 for i in 0..INNER {
547 out.elements[h * RHS_W + w] +=
548 lhs.elements[h * INNER + i] * rhs.elements[i * RHS_W + w];
549 }
550 }
551 }
552
553 out
554 }
555
556 pub fn fast_multiply<
568 const LHS_H: usize,
569 const LHS_N: usize,
570 const RHS_W: usize,
571 const RHS_N: usize,
572 const INNER: usize,
573 const OUT_N: usize,
574 >(
575 lhs: ModuleElement<LHS_N>,
576 rhs: ModuleElement<RHS_N>,
577 ) -> ModuleElement<OUT_N> {
578 debug_assert_eq!(LHS_H * INNER, LHS_N);
579 debug_assert_eq!(INNER * RHS_W, RHS_N);
580 debug_assert_eq!(LHS_H * RHS_W, OUT_N);
581
582 let lhs_ntt = lhs.ntt();
583 let rhs_ntt = rhs.ntt();
584
585 let out_ntt =
586 Self::multiply_hadamard::<LHS_H, LHS_N, RHS_W, RHS_N, INNER, OUT_N>(lhs_ntt, rhs_ntt);
587
588 out_ntt.intt()
589 }
590}
591
592impl<const N: usize> Add for ModuleElement<N> {
593 type Output = ModuleElement<N>;
594
595 fn add(self, rhs: Self) -> Self::Output {
596 let elements: [CyclotomicRingElement; N] = (0..N)
597 .into_par_iter()
598 .map(|i| self.elements[i] + rhs.elements[i])
599 .collect::<Vec<_>>()
600 .try_into()
601 .unwrap();
602 ModuleElement::<N> { elements }
603 }
604}
605
606impl<const N: usize> Sub for ModuleElement<N> {
607 type Output = ModuleElement<N>;
608
609 fn sub(self, rhs: Self) -> Self::Output {
610 let elements: [CyclotomicRingElement; N] = (0..N)
611 .into_par_iter()
612 .map(|i| self.elements[i] - rhs.elements[i])
613 .collect::<Vec<_>>()
614 .try_into()
615 .unwrap();
616 ModuleElement::<N> { elements }
617 }
618}
619
620impl<const N: usize> Zero for ModuleElement<N> {
621 fn zero() -> Self {
622 Self {
623 elements: [CyclotomicRingElement::zero(); N],
624 }
625 }
626
627 fn is_zero(&self) -> bool {
628 *self == Self::zero()
629 }
630}
631
632pub mod kem {
633 use itertools::Itertools;
634 use serde_derive::Deserialize;
635 use serde_derive::Serialize;
636 use sha3::Digest as Sha3Digest;
637 use sha3::Sha3_256;
638 use sha3::Shake256;
639 use sha3::digest::ExtendableOutput;
640 use sha3::digest::Update;
641 use zeroize::Zeroize;
642
643 use super::CYCLOTOMIC_RING_ELEMENT_SIZE_IN_BFES;
644 use super::CyclotomicRingElement;
645 use super::ModuleElement;
646 use super::embed_msg;
647 use super::extract_msg;
648 use crate::math::b_field_element::BFieldElement;
649
650 #[derive(PartialEq, Eq, Copy, Clone, Debug, Serialize, Deserialize, Zeroize)]
651 pub struct SecretKey {
652 key: [u8; 32],
653 seed: [u8; 32],
654 }
655
656 #[derive(PartialEq, Eq, Copy, Clone, Debug, Serialize, Deserialize)]
657 pub struct PublicKey {
658 seed: [u8; 32],
659 ga: ModuleElement<4>,
660 }
661
662 #[derive(PartialEq, Eq, Copy, Clone, Debug, Serialize, Deserialize)]
663 pub struct Ciphertext {
664 bg: ModuleElement<4>,
665 bga_m: ModuleElement<1>,
666 }
667
668 pub const CIPHERTEXT_SIZE_IN_BFES: usize = CYCLOTOMIC_RING_ELEMENT_SIZE_IN_BFES * 5;
669
670 impl From<[BFieldElement; CIPHERTEXT_SIZE_IN_BFES]> for Ciphertext {
671 fn from(value: [BFieldElement; CIPHERTEXT_SIZE_IN_BFES]) -> Self {
672 let (bg_slice, bga_m_slice) = value.split_at(4 * CYCLOTOMIC_RING_ELEMENT_SIZE_IN_BFES);
673
674 let bg_array: [BFieldElement; 4 * CYCLOTOMIC_RING_ELEMENT_SIZE_IN_BFES] =
675 bg_slice.try_into().unwrap();
676 let bga_m_array: [BFieldElement; CYCLOTOMIC_RING_ELEMENT_SIZE_IN_BFES] =
677 bga_m_slice.try_into().unwrap();
678
679 let bg_module = ModuleElement {
680 elements: bg_array
681 .chunks(CYCLOTOMIC_RING_ELEMENT_SIZE_IN_BFES)
682 .map(|sl| {
683 CyclotomicRingElement::from(
684 std::convert::TryInto::<
685 [BFieldElement; CYCLOTOMIC_RING_ELEMENT_SIZE_IN_BFES],
686 >::try_into(sl)
687 .unwrap(),
688 )
689 })
690 .collect_vec()
691 .try_into()
692 .unwrap(),
693 };
694 let bga_m_module = ModuleElement {
695 elements: [CyclotomicRingElement::from(bga_m_array); 1],
696 };
697
698 Self {
699 bg: bg_module,
700 bga_m: bga_m_module,
701 }
702 }
703 }
704
705 impl From<Ciphertext> for [BFieldElement; CIPHERTEXT_SIZE_IN_BFES] {
706 fn from(value: Ciphertext) -> Self {
707 let bg_slice = value
708 .bg
709 .elements
710 .iter()
711 .flat_map(|e| e.coefficients)
712 .collect_vec();
713 let bga_m_slice = value
714 .bga_m
715 .elements
716 .iter()
717 .flat_map(|e| e.coefficients)
718 .collect_vec();
719 [bg_slice, bga_m_slice].concat().try_into().unwrap()
720 }
721 }
722
723 pub(super) fn shake256<const NUM_OUT_BYTES: usize>(
725 randomness: impl AsRef<[u8]>,
726 ) -> [u8; NUM_OUT_BYTES] {
727 let mut hasher = Shake256::default();
728 hasher.update(randomness.as_ref());
729
730 let mut result = [0u8; NUM_OUT_BYTES];
731 hasher.finalize_xof_into(&mut result);
732 result
733 }
734
735 fn derive_public_matrix(seed: &[u8; 32]) -> ModuleElement<16> {
736 const NUM_BYTES: usize = 9 * 64 * 16;
737 let randomness = shake256::<NUM_BYTES>(seed);
738 ModuleElement::<16>::sample_uniform(&randomness)
739 }
740
741 fn derive_secret_vectors(seed: &[u8; 32]) -> (ModuleElement<4>, ModuleElement<4>) {
742 const NUM_BYTES: usize = 2 * 4 * 64 * 8;
743 let randomness = shake256::<NUM_BYTES>(seed);
744 let a = ModuleElement::<4>::sample_short(&randomness[0..(NUM_BYTES / 2)]);
745 let b = ModuleElement::<4>::sample_short(&randomness[(NUM_BYTES / 2)..]);
746 (a, b)
747 }
748
749 pub fn keygen(randomness: [u8; 32]) -> (SecretKey, PublicKey) {
751 const OUTPUT_LENGTH: usize = 32;
752 let seed: [u8; OUTPUT_LENGTH] = shake256([randomness.to_vec(), vec![0u8]].concat());
753 let key: [u8; OUTPUT_LENGTH] = shake256([randomness.to_vec(), vec![1u8]].concat());
754
755 let sk = SecretKey { key, seed };
756
757 let pk = derive_public_key(&key, &seed);
758 (sk, pk)
759 }
760
761 fn derive_public_key(key: &[u8; 32], seed: &[u8; 32]) -> PublicKey {
762 let (a, c) = derive_secret_vectors(key);
763 let g = derive_public_matrix(seed);
764 let ga = ModuleElement::<16>::multiply_hadamard::<4, 16, 1, 4, 4, 4>(g, a.ntt()) + c.ntt();
765
766 PublicKey { seed: *seed, ga }
767 }
768
769 fn generate_ciphertext_derandomized(pk: PublicKey, payload: [u8; 32]) -> Ciphertext {
772 let (b, d) = derive_secret_vectors(&payload);
773 let b_ntt = b.ntt();
774 let d_ntt = d.ntt();
775 let g = derive_public_matrix(&pk.seed);
776 let bg = ModuleElement::<9>::multiply_hadamard::<1, 4, 4, 16, 4, 4>(b_ntt, g) + d_ntt;
777
778 let m = embed_msg(payload);
779 let bga_m = ModuleElement::<3>::multiply_hadamard::<1, 4, 1, 4, 4, 1>(b_ntt, pk.ga)
780 + ModuleElement::<1> { elements: [m] }.ntt();
781
782 Ciphertext { bg, bga_m }
783 }
784
785 pub fn enc(pk: PublicKey, randomness: [u8; 32]) -> ([u8; 32], Ciphertext) {
788 const OUTPUT_LENGTH: usize = 32;
789 let payload: [u8; OUTPUT_LENGTH] = shake256(randomness);
790 let ciphertext = generate_ciphertext_derandomized(pk, payload);
791 let shared_key: [u8; 32] = Sha3_256::digest(payload).into();
792
793 (shared_key, ciphertext)
794 }
795
796 pub fn dec(sk: SecretKey, ctxt: Ciphertext) -> Option<[u8; 32]> {
799 let (a, _) = derive_secret_vectors(&sk.key);
800 let bga = ModuleElement::<3>::multiply_hadamard::<1, 4, 1, 4, 4, 1>(ctxt.bg, a.ntt());
801 let m = (ctxt.bga_m - bga).intt();
802 let payload = extract_msg(m.elements[0]);
803
804 let pk = derive_public_key(&sk.key, &sk.seed);
805 let regenerated_ciphertext = generate_ciphertext_derandomized(pk, payload);
806
807 if regenerated_ciphertext != ctxt {
808 return None;
809 }
810
811 let shared_key = Sha3_256::digest(payload).into();
812 Some(shared_key)
813 }
814
815 #[cfg(test)]
816 #[cfg_attr(coverage_nightly, coverage(off))]
817 mod tests {
818 use super::*;
819
820 #[test]
821 fn secret_key_zeroize_test() {
822 let mut secret_key = SecretKey {
823 key: rand::random(),
824 seed: rand::random(),
825 };
826
827 secret_key.zeroize();
828
829 assert_eq!([0; 32], secret_key.key);
830 assert_eq!([0; 32], secret_key.seed);
831 }
832 }
833}
834
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use itertools::Itertools;
    use num_traits::ConstOne;
    use num_traits::Zero;
    use rand::RngCore;
    use rand::random;
    use sha3::Digest as Sha3Digest;
    use sha3::Sha3_256;

    use super::kem::CIPHERTEXT_SIZE_IN_BFES;
    use super::kem::SecretKey;
    use super::kem::shake256;
    use crate::math::b_field_element::BFieldElement;
    use crate::math::lattice::kem::Ciphertext;
    use crate::math::lattice::kem::PublicKey;
    use crate::math::lattice::*;

    /// Bytes of randomness consumed per ring element by `sample_uniform`.
    const UNIFORM_BYTES_PER_RING_ELEMENT: usize = 9 * 64;

    /// Returns `count` random bytes.
    fn random_bytes(count: usize) -> Vec<u8> {
        let mut rng = rand::rng();
        (0..count)
            .map(|_| (rng.next_u32() % 256) as u8)
            .collect_vec()
    }

    /// Returns 32 random bytes, as consumed by `keygen` and `enc`.
    fn random_seed() -> [u8; 32] {
        let mut seed = [0u8; 32];
        rand::rng().fill_bytes(&mut seed);
        seed
    }

    /// Known-answer tests for the two hash functions the KEM relies on.
    #[test]
    fn test_kats() {
        let input = b"\x21\xF1\x34\xAC\x57";
        let expected_output_shake256 = b"\xBB\x8A\x84\x47\x51\x7B\xA9\xCA\x7F\xA3\x4E\xC9\x9A\x80\
        \x00\x4F\x22\x8A\xB2\x82\x47\x28\x41\xEB\x3D\x3A\x76\x22\x5C\x9D\xBE\x77\xF7\xE4\x0A\x06\
        \x67\x76\xD3\x2C\x74\x94\x12\x02\xF9\xF4\xAA\x43\xD1\x2C\x62\x64\xAF\xA5\x96\x39\xC4\x4E\
        \x11\xF5\xE1\x4F\x1E\x56";
        let expected_output_sha3_256 = b"\x55\xBD\x92\x24\xAF\x4E\xED\x0D\x12\x11\x49\xE3\x7F\xF4\
        \xD7\xDD\x5B\xE2\x4B\xD9\xFB\xE5\x6E\x01\x71\xE8\x7D\xB7\xA6\xF4\xE0\x6D";

        assert_eq!(*expected_output_shake256, shake256(input));
        assert_eq!(
            expected_output_sha3_256.to_vec(),
            Sha3_256::digest(input).to_vec()
        );
    }

    /// Transform-based multiplication must agree with schoolbook
    /// multiplication in which terms wrapping past degree 63 are subtracted.
    #[test]
    fn test_fast_mul() {
        let a: [BFieldElement; 64] = random();
        let b: [BFieldElement; 64] = random();

        let mut c_schoolbook = [BFieldElement::ZERO; 64];
        for (i, &a_i) in a.iter().enumerate() {
            for (j, &b_j) in b.iter().enumerate() {
                let term = a_i * b_j;
                match (i + j).checked_sub(64) {
                    Some(wrapped) => c_schoolbook[wrapped] -= term,
                    None => c_schoolbook[i + j] += term,
                }
            }
        }

        let c_fast = CyclotomicRingElement::from(a) * CyclotomicRingElement::from(b);

        assert_eq!(c_fast.coefficients, c_schoolbook);
    }

    /// Embedding followed by extraction must return the original message.
    #[test]
    fn test_embedding() {
        let msg: [u8; 32] = random_bytes(32).try_into().unwrap();

        let extracted = extract_msg(embed_msg(msg));

        assert_eq!(msg, extracted);
    }

    /// `(a + b)·c` must equal `a·c + b·c` for 2×3 matrices `a`, `b` and a
    /// 3×1 vector `c`.
    #[test]
    fn test_module_distributivity() {
        let randomness = random_bytes((2 * 3 + 2 * 3 + 3) * 64 * 9);
        let (a_bytes, rest) = randomness.split_at(2 * 3 * UNIFORM_BYTES_PER_RING_ELEMENT);
        let (b_bytes, c_bytes) = rest.split_at(2 * 3 * UNIFORM_BYTES_PER_RING_ELEMENT);

        let a = ModuleElement::<{ 2 * 3 }>::sample_uniform(a_bytes);
        let b = ModuleElement::<{ 2 * 3 }>::sample_uniform(b_bytes);
        let c = ModuleElement::<3>::sample_uniform(c_bytes);

        let sumprod = ModuleElement::<1>::multiply::<2, 6, 1, 3, 3, 2>(a + b, c);
        let prod_a = ModuleElement::<1>::multiply::<2, 6, 1, 3, 3, 2>(a, c);
        let prod_b = ModuleElement::<1>::multiply::<2, 6, 1, 3, 3, 2>(b, c);

        assert_eq!(sumprod, prod_a + prod_b);
    }

    /// `fast_multiply` and `multiply` must agree on a 2×3 by 3×2 product.
    #[test]
    fn test_module_multiply() {
        let randomness = random_bytes((2 * 3 + 2 * 3 + 3) * 64 * 9);
        let (a_bytes, rest) = randomness.split_at(2 * 3 * UNIFORM_BYTES_PER_RING_ELEMENT);
        let (b_bytes, _unused) = rest.split_at(3 * 2 * UNIFORM_BYTES_PER_RING_ELEMENT);

        let a = ModuleElement::<{ 2 * 3 }>::sample_uniform(a_bytes);
        let b = ModuleElement::<{ 2 * 3 }>::sample_uniform(b_bytes);

        let fast = ModuleElement::<1>::fast_multiply::<2, 6, 2, 6, 3, 4>(a, b);
        let slow = ModuleElement::<1>::multiply::<2, 6, 2, 6, 3, 4>(a, b);

        assert_eq!(fast, slow);
    }

    /// Both parties must agree on the shared key, and an unrelated secret key
    /// must fail to decapsulate.
    #[test]
    fn test_kem() {
        let (sk, pk) = kem::keygen(random_seed());
        let (alice_key, ctxt) = kem::enc(pk, random_seed());

        let Some(bob_key) = kem::dec(sk, ctxt) else {
            panic!()
        };
        assert_eq!(alice_key, bob_key);

        let (other_sk, _) = kem::keygen(random_seed());
        assert!(kem::dec(other_sk, ctxt).is_none());
    }

    /// Converting field elements to a ciphertext and back must be lossless.
    #[test]
    fn test_ciphertext_conversion() {
        let bfes: [BFieldElement; CIPHERTEXT_SIZE_IN_BFES] = random();

        let ciphertext = Ciphertext::from(bfes);
        let bfes_again = <[BFieldElement; CIPHERTEXT_SIZE_IN_BFES]>::from(ciphertext);
        let ciphertext_again = Ciphertext::from(bfes_again);

        assert_eq!(bfes, bfes_again);
        assert_eq!(ciphertext, ciphertext_again);
    }

    #[test]
    fn zero_test() {
        assert!(ModuleElement::<4>::zero().is_zero(), "zero must be zero");

        let all_ones = CyclotomicRingElement {
            coefficients: [BFieldElement::ONE; 64],
        };
        let not_zero = ModuleElement {
            elements: [all_ones; 4],
        };
        assert!(!not_zero.is_zero(), "not-zero must be not be zero");
    }

    /// Keys, ciphertexts and shared keys must survive a JSON round trip.
    #[test]
    fn serialization_deserialization_test() {
        let (sk, pk) = kem::keygen(random_seed());
        let (alice_key, ctxt) = kem::enc(pk, random_seed());

        let sk_json = serde_json::to_string(&sk).unwrap();
        let sk_again: SecretKey = serde_json::from_str(&sk_json).unwrap();
        assert_eq!(sk, sk_again);

        let pk_json = serde_json::to_string(&pk).unwrap();
        let pk_again: PublicKey = serde_json::from_str(&pk_json).unwrap();
        assert_eq!(pk, pk_again);

        let ctxt_json = serde_json::to_string(&ctxt).unwrap();
        let ctxt_again: Ciphertext = serde_json::from_str(&ctxt_json).unwrap();
        assert_eq!(ctxt, ctxt_again);

        let alice_key_json = serde_json::to_string(&alice_key).unwrap();
        let alice_key_again: [u8; 32] = serde_json::from_str(&alice_key_json).unwrap();
        assert_eq!(alice_key, alice_key_again);
    }
}