twenty_first/util_types/sponge.rs

1use std::fmt::Debug;
2
3use num_traits::ConstOne;
4use num_traits::ConstZero;
5
6use crate::math::b_field_element::BFieldElement;
7
/// Number of [`BFieldElement`]s taken by a single [`Sponge::absorb`] call and returned by a
/// single [`Sponge::squeeze`] call; also the chunk size used by
/// [`Sponge::pad_and_absorb_all`].
pub const RATE: usize = 10;
9
/// The hasher [Domain] differentiates between the modes of hashing.
///
/// The main purpose of declaring the domain is to prevent collisions between different types of
/// hashing by initializing the hash function's internal state (e.g., a sponge state's capacity)
/// differently for each domain.
///
/// The enum is fieldless, so it is cheap to copy and can be used as a key in hashed collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Domain {
    /// The `VariableLength` domain is used for hashing objects that potentially serialize to more
    /// than [`RATE`] number of field elements.
    VariableLength,

    /// The `FixedLength` domain is used for hashing objects that always fit
    /// within [`RATE`] number of field elements, e.g., a pair of
    /// [`Digest`](crate::prelude::Digest)s.
    FixedLength,
}
26
27/// A [cryptographic sponge][sponge]. Should only be based on a cryptographic permutation, e.g.,
28/// [`Tip5`][tip5].
29///
30/// [sponge]: https://keccak.team/files/CSF-0.1.pdf
31/// [tip5]: crate::prelude::Tip5
32pub trait Sponge: Clone + Debug + Default + Send + Sync {
33    const RATE: usize;
34
35    fn init() -> Self;
36
37    fn absorb(&mut self, input: [BFieldElement; RATE]);
38
39    fn squeeze(&mut self) -> [BFieldElement; RATE];
40
41    fn pad_and_absorb_all(&mut self, input: &[BFieldElement]) {
42        let mut chunks = input.chunks_exact(RATE);
43        for chunk in chunks.by_ref() {
44            // `chunks_exact` yields only chunks of length RATE; unwrap is fine
45            self.absorb(chunk.try_into().unwrap());
46        }
47
48        // Pad input with [1, 0, 0, …] – padding is at least one element.
49        // Since remainder's len is at most `RATE - 1`, the indexing is safe.
50        let remainder = chunks.remainder();
51        let mut last_chunk = const { [BFieldElement::ZERO; RATE] };
52        last_chunk[..remainder.len()].copy_from_slice(remainder);
53        last_chunk[remainder.len()] = BFieldElement::ONE;
54        self.absorb(last_chunk);
55    }
56}
57
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use proptest::prelude::*;
    use rand::Rng;
    use rand::distr::Distribution;
    use rand::distr::StandardUniform;

    use super::*;
    use crate::math::x_field_element::EXTENSION_DEGREE;
    use crate::prelude::BFieldCodec;
    use crate::prelude::XFieldElement;
    use crate::tests::proptest;
    use crate::tests::test;
    use crate::tip5::Digest;
    use crate::tip5::Tip5;

    /// Sanity-checks the [`BFieldCodec`] encoding of `T`.
    ///
    /// Asserts that the two given extreme values encode to different sequences of equal length,
    /// and that two randomly drawn values encode identically exactly when they compare equal.
    fn encode_prop<T>(smallest: T, largest: T)
    where
        T: Eq + BFieldCodec,
        StandardUniform: Distribution<T>,
    {
        let encoded_min = smallest.encode();
        let encoded_max = largest.encode();
        assert_ne!(encoded_min, encoded_max);
        assert_eq!(encoded_min.len(), encoded_max.len());

        let mut rng = rand::rng();
        let lhs: T = rng.random();
        let rhs: T = rng.random();

        if lhs == rhs {
            assert_eq!(lhs.encode(), rhs.encode());
        } else {
            assert_ne!(lhs.encode(), rhs.encode());
        }
    }

    #[macro_rules_attr::apply(test)]
    fn to_sequence_test() {
        encode_prop(false, true);
        encode_prop(0_u32, u32::MAX);
        encode_prop(0_u64, u64::MAX);

        let max_bfe = BFieldElement::new(BFieldElement::MAX);
        encode_prop(BFieldElement::ZERO, max_bfe);

        let max_xfe = XFieldElement::new([max_bfe; EXTENSION_DEGREE]);
        encode_prop(XFieldElement::ZERO, max_xfe);

        let max_digest = Digest::new([max_bfe; Digest::LEN]);
        encode_prop(Digest::ALL_ZERO, max_digest);

        encode_prop(0_u128, u128::MAX);
    }

    #[macro_rules_attr::apply(proptest)]
    fn sample_indices(mut tip5: Tip5) {
        // Each case is (exclusive upper bound, number of indices to sample).
        let cases = [
            (2, 0),
            (4, 1),
            (8, 9),
            (16, 10),
            (32, 11),
            (64, 19),
            (128, 20),
            (256, 21),
            (512, 65),
        ];

        for (bound, count) in cases {
            let sampled = tip5.sample_indices(bound, count);
            prop_assert_eq!(count, sampled.len());
            prop_assert!(sampled.into_iter().all(|index| index < bound));
        }
    }
}