1use super::*;
17
18impl<E: Environment, const TYPE: u8, const VARIANT: usize> Hash for Keccak<E, TYPE, VARIANT> {
19    type Input = Boolean<E>;
20    type Output = Vec<Boolean<E>>;
21
22    #[inline]
24    fn hash(&self, input: &[Self::Input]) -> Self::Output {
25        let bitrate = PERMUTATION_WIDTH - 2 * VARIANT;
29        debug_assert!(bitrate < PERMUTATION_WIDTH, "The bitrate must be less than the permutation width");
30        debug_assert!(bitrate % 8 == 0, "The bitrate must be a multiple of 8");
31
32        if input.is_empty() {
34            E::halt("The input to the hash function must not be empty")
35        }
36
37        let mut s = vec![Boolean::constant(false); PERMUTATION_WIDTH];
39
40        let padded_blocks = match TYPE {
42            0 => Self::pad_keccak(input, bitrate),
43            1 => Self::pad_sha3(input, bitrate),
44            2.. => unreachable!("Invalid Keccak type"),
45        };
46
47        for block in padded_blocks {
55            for (j, bit) in block.into_iter().enumerate() {
57                s[j] = &s[j] ^ &bit;
58            }
59            s = Self::permutation_f::<PERMUTATION_WIDTH, NUM_ROUNDS>(s, &self.round_constants, &self.rotl);
61        }
62
63        let mut z = s[..bitrate].to_vec();
74        while z.len() < VARIANT {
76            s = Self::permutation_f::<PERMUTATION_WIDTH, NUM_ROUNDS>(s, &self.round_constants, &self.rotl);
78            z.extend(s.iter().take(bitrate).cloned());
80        }
81        z.truncate(VARIANT);
83        z
84    }
85}
86
87impl<E: Environment, const TYPE: u8, const VARIANT: usize> Keccak<E, TYPE, VARIANT> {
88    fn pad_keccak(input: &[Boolean<E>], bitrate: usize) -> Vec<Vec<Boolean<E>>> {
93        debug_assert!(bitrate > 0, "The bitrate must be positive");
94
95        let mut padded_input = input.to_vec();
97        padded_input.resize(input.len().div_ceil(8) * 8, Boolean::constant(false));
98
99        padded_input.push(Boolean::constant(true));
101
102        while (padded_input.len() % bitrate) != (bitrate - 1) {
104            padded_input.push(Boolean::constant(false));
105        }
106
107        padded_input.push(Boolean::constant(true));
109
110        let mut result = Vec::new();
112        for block in padded_input.chunks(bitrate) {
113            result.push(block.to_vec());
114        }
115        result
116    }
117
118    fn pad_sha3(input: &[Boolean<E>], bitrate: usize) -> Vec<Vec<Boolean<E>>> {
123        debug_assert!(bitrate > 1, "The bitrate must be greater than 1");
124
125        let mut padded_input = input.to_vec();
127        padded_input.resize(input.len().div_ceil(8) * 8, Boolean::constant(false));
128
129        padded_input.push(Boolean::constant(false));
131        padded_input.push(Boolean::constant(true));
132        padded_input.push(Boolean::constant(true));
133        padded_input.push(Boolean::constant(false));
134
135        while (padded_input.len() % bitrate) != (bitrate - 1) {
137            padded_input.push(Boolean::constant(false));
138        }
139
140        padded_input.push(Boolean::constant(true));
142
143        let mut result = Vec::new();
145        for block in padded_input.chunks(bitrate) {
146            result.push(block.to_vec());
147        }
148        result
149    }
150
151    fn permutation_f<const WIDTH: usize, const NUM_ROUNDS: usize>(
157        input: Vec<Boolean<E>>,
158        round_constants: &[U64<E>],
159        rotl: &[usize],
160    ) -> Vec<Boolean<E>> {
161        debug_assert_eq!(input.len(), WIDTH, "The input vector must have {WIDTH} bits");
162        debug_assert_eq!(
163            round_constants.len(),
164            NUM_ROUNDS,
165            "The round constants vector must have {NUM_ROUNDS} elements"
166        );
167
168        let mut a = input.chunks(64).map(U64::from_bits_le).collect::<Vec<_>>();
170        for round_constant in round_constants.iter().take(NUM_ROUNDS) {
172            a = Self::round(a, round_constant, rotl);
173        }
174        let mut bits = Vec::with_capacity(input.len());
176        a.iter().for_each(|e| e.write_bits_le(&mut bits));
177        bits
178    }
179
180    fn round(a: Vec<U64<E>>, round_constant: &U64<E>, rotl: &[usize]) -> Vec<U64<E>> {
186        debug_assert_eq!(a.len(), MODULO * MODULO, "The input vector 'a' must have {} elements", MODULO * MODULO);
187
188        let mut c = Vec::with_capacity(MODULO);
198        for x in 0..MODULO {
199            c.push(&a[x] ^ &a[x + MODULO] ^ &a[x + (2 * MODULO)] ^ &a[x + (3 * MODULO)] ^ &a[x + (4 * MODULO)]);
200        }
201
202        let mut d = Vec::with_capacity(MODULO);
212        for x in 0..MODULO {
213            d.push(&c[(x + 4) % MODULO] ^ Self::rotate_left(&c[(x + 1) % MODULO], 63));
214        }
215        let mut a_1 = Vec::with_capacity(MODULO * MODULO);
216        for y in 0..MODULO {
217            for x in 0..MODULO {
218                a_1.push(&a[x + (y * MODULO)] ^ &d[x]);
219            }
220        }
221
222        let mut a_2 = a_1.clone();
241        for y in 0..MODULO {
242            for x in 0..MODULO {
243                a_2[y + ((((2 * x) + (3 * y)) % MODULO) * MODULO)] =
245                    Self::rotate_left(&a_1[x + (y * MODULO)], rotl[x + (y * MODULO)]);
246            }
247        }
248
249        let mut a_3 = Vec::with_capacity(MODULO * MODULO);
258        for y in 0..MODULO {
259            for x in 0..MODULO {
260                let a = &a_2[x + (y * MODULO)];
261                let b = &a_2[((x + 1) % MODULO) + (y * MODULO)];
262                let c = &a_2[((x + 2) % MODULO) + (y * MODULO)];
263                a_3.push(a ^ ((!b) & c));
264            }
265        }
266
267        a_3[0] = &a_3[0] ^ round_constant;
272        a_3
273    }
274
275    fn rotate_left(value: &U64<E>, n: usize) -> U64<E> {
277        let mut bits_le = value.to_bits_le();
279        bits_le.rotate_left(n);
280        U64::from_bits_le(&bits_le)
282    }
283}
284
#[cfg(test)]
mod tests {
    use super::*;
    use console::Rng;
    use snarkvm_circuit_types::environment::Circuit;

    /// How many random inputs `check_hash` hashes for every input size.
    const ITERATIONS: usize = 3;

    /// Input lengths (in bits) checked in `Mode::Constant`. For these, Keccak-256 must synthesize
    /// no constants, public variables, private variables or constraints inside the scope.
    const CONSTANT_MODE_SIZES: [usize; 19] =
        [1, 2, 3, 4, 5, 6, 7, 8, 16, 32, 64, 128, 256, 511, 512, 513, 1023, 1024, 1025];

    /// Pairs of `(input length in bits, expected private variables)` for Keccak-256 over public
    /// or private inputs. In both modes the expected constraint count equals the private count.
    const NON_CONSTANT_MODE_COSTS: [(usize, u64); 15] = [
        (1, 138157),
        (2, 139108),
        (3, 139741),
        (4, 140318),
        (5, 140879),
        (6, 141350),
        (7, 141787),
        (8, 142132),
        (16, 144173),
        (32, 145394),
        (64, 146650),
        (128, 149248),
        (256, 150848),
        (512, 151424),
        (1024, 152448),
    ];

    /// Hashes random private inputs of many sizes with both the console hasher and the circuit
    /// hasher, and asserts that the digests match.
    ///
    /// Both hasher expressions are evaluated anew for every input size, so the circuit hasher is
    /// rebuilt after each `Circuit::reset()`.
    macro_rules! check_equivalence {
        ($console:expr, $circuit:expr) => {{
            use console::Hash as H;

            let rng = &mut TestRng::default();

            // A fixed list of sizes, followed by five random sizes drawn from `1..1024`.
            let mut input_sizes = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 16, 32, 64, 128, 256, 512, 1024];
            for _ in 0..5 {
                input_sizes.push(rng.gen_range(1..1024));
            }

            for num_inputs in input_sizes {
                println!("Checking equivalence for {num_inputs} inputs");

                let native_input: Vec<bool> = (0..num_inputs).map(|_| Uniform::rand(rng)).collect();
                let input: Vec<_> =
                    native_input.iter().map(|&bit| Boolean::<Circuit>::new(Mode::Private, bit)).collect();

                let expected = $console.hash(&native_input).expect("Failed to hash console input");
                let candidate = $circuit.hash(&input);
                assert_eq!(expected, candidate.eject_value());

                Circuit::reset();
            }
        }};
    }

    /// Hashes `ITERATIONS` random `num_inputs`-bit inputs in `mode`.
    ///
    /// Each digest is compared against the console Keccak-256. The synthesized circuit is checked
    /// against the expected counts of constants, public and private variables, and constraints.
    fn check_hash(
        mode: Mode,
        num_inputs: usize,
        num_constants: u64,
        num_public: u64,
        num_private: u64,
        num_constraints: u64,
        rng: &mut TestRng,
    ) {
        use console::Hash as H;

        let native = console::Keccak256::default();
        let keccak = Keccak256::<Circuit>::new();

        for i in 0..ITERATIONS {
            let native_input: Vec<bool> = (0..num_inputs).map(|_| Uniform::rand(rng)).collect();
            let input: Vec<_> = native_input.iter().map(|&bit| Boolean::<Circuit>::new(mode, bit)).collect();

            let expected = native.hash(&native_input).expect("Failed to hash native input");

            // Only the gadget itself runs inside the scope, so the counts exclude input allocation.
            Circuit::scope(format!("Keccak {mode} {i}"), || {
                let candidate = keccak.hash(&input);
                assert_eq!(expected, candidate.eject_value());
                let case = format!("(mode = {mode}, num_inputs = {num_inputs})");
                assert_scope!(case, num_constants, num_public, num_private, num_constraints);
            });
            Circuit::reset();
        }
    }

    #[test]
    fn test_keccak_256_hash_constant() {
        let mut rng = TestRng::default();

        for num_inputs in CONSTANT_MODE_SIZES {
            check_hash(Mode::Constant, num_inputs, 0, 0, 0, 0, &mut rng);
        }
    }

    #[test]
    fn test_keccak_256_hash_public() {
        let mut rng = TestRng::default();

        for (num_inputs, num_private) in NON_CONSTANT_MODE_COSTS {
            check_hash(Mode::Public, num_inputs, 0, 0, num_private, num_private, &mut rng);
        }
    }

    #[test]
    fn test_keccak_256_hash_private() {
        let mut rng = TestRng::default();

        for (num_inputs, num_private) in NON_CONSTANT_MODE_COSTS {
            check_hash(Mode::Private, num_inputs, 0, 0, num_private, num_private, &mut rng);
        }
    }

    #[test]
    fn test_keccak_224_equivalence() {
        check_equivalence!(console::Keccak224::default(), Keccak224::<Circuit>::new());
    }

    #[test]
    fn test_keccak_256_equivalence() {
        check_equivalence!(console::Keccak256::default(), Keccak256::<Circuit>::new());
    }

    #[test]
    fn test_keccak_384_equivalence() {
        check_equivalence!(console::Keccak384::default(), Keccak384::<Circuit>::new());
    }

    #[test]
    fn test_keccak_512_equivalence() {
        check_equivalence!(console::Keccak512::default(), Keccak512::<Circuit>::new());
    }

    #[test]
    fn test_sha3_224_equivalence() {
        check_equivalence!(console::Sha3_224::default(), Sha3_224::<Circuit>::new());
    }

    #[test]
    fn test_sha3_256_equivalence() {
        check_equivalence!(console::Sha3_256::default(), Sha3_256::<Circuit>::new());
    }

    #[test]
    fn test_sha3_384_equivalence() {
        check_equivalence!(console::Sha3_384::default(), Sha3_384::<Circuit>::new());
    }

    #[test]
    fn test_sha3_512_equivalence() {
        check_equivalence!(console::Sha3_512::default(), Sha3_512::<Circuit>::new());
    }
}