// twenty_first/util_types/sponge.rs

1use std::fmt::Debug;
2
3use num_traits::ConstOne;
4use num_traits::ConstZero;
5
6use crate::math::b_field_element::BFieldElement;
7
/// The number of [`BFieldElement`]s that [`Sponge::absorb`] takes and [`Sponge::squeeze`]
/// returns per call; also the chunk size used by [`Sponge::pad_and_absorb_all`].
pub const RATE: usize = 10;

/// The hasher [`Domain`] differentiates between the modes of hashing.
///
/// Declaring the domain prevents collisions between different modes of hashing: each domain
/// initializes the hash function's internal state (e.g., a sponge state's capacity)
/// differently.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Domain {
    /// The `VariableLength` domain is used for hashing objects that potentially serialize to more
    /// than [`RATE`] field elements.
    VariableLength,

    /// The `FixedLength` domain is used for hashing objects that always fit within [`RATE`] field
    /// elements, e.g., a pair of [`Digest`](crate::prelude::Digest)s.
    FixedLength,
}

27/// A [cryptographic sponge][sponge]. Should only be based on a cryptographic permutation, e.g.,
28/// [`Tip5`][tip5].
29///
30/// [sponge]: https://keccak.team/files/CSF-0.1.pdf
31/// [tip5]: crate::prelude::Tip5
32pub trait Sponge: Clone + Debug + Default + Send + Sync {
33    const RATE: usize;
34
35    fn init() -> Self;
36
37    fn absorb(&mut self, input: [BFieldElement; RATE]);
38
39    fn squeeze(&mut self) -> [BFieldElement; RATE];
40
41    fn pad_and_absorb_all(&mut self, input: &[BFieldElement]) {
42        let mut chunks = input.chunks_exact(RATE);
43        for chunk in chunks.by_ref() {
44            // `chunks_exact` yields only chunks of length RATE; unwrap is fine
45            self.absorb(chunk.try_into().unwrap());
46        }
47
48        // Pad input with [1, 0, 0, …] – padding is at least one element.
49        // Since remainder's len is at most `RATE - 1`, the indexing is safe.
50        let remainder = chunks.remainder();
51        let mut last_chunk = const { [BFieldElement::ZERO; RATE] };
52        last_chunk[..remainder.len()].copy_from_slice(remainder);
53        last_chunk[remainder.len()] = BFieldElement::ONE;
54        self.absorb(last_chunk);
55    }
56}
57
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use proptest::prelude::*;
    use rand::Rng;
    use rand::distr::Distribution;
    use rand::distr::StandardUniform;
    use test_strategy::proptest;

    use super::*;
    use crate::math::x_field_element::EXTENSION_DEGREE;
    use crate::prelude::BFieldCodec;
    use crate::prelude::XFieldElement;
    use crate::tip5::Digest;
    use crate::tip5::Tip5;

    /// Spot-checks the `BFieldCodec` encoding of `T`.
    ///
    /// The two given extremes must encode to sequences that differ but have equal length. Two
    /// randomly sampled values must have equal encodings if they are equal, and different
    /// encodings otherwise.
    fn encode_prop<T>(smallest: T, largest: T)
    where
        T: Eq + BFieldCodec,
        StandardUniform: Distribution<T>,
    {
        let low_encoding = smallest.encode();
        let high_encoding = largest.encode();
        assert_ne!(low_encoding, high_encoding);
        assert_eq!(low_encoding.len(), high_encoding.len());

        let mut rng = rand::rng();
        let (random_a, random_b): (T, T) = (rng.random(), rng.random());

        let (encoding_a, encoding_b) = (random_a.encode(), random_b.encode());
        if random_a == random_b {
            assert_eq!(encoding_a, encoding_b);
        } else {
            assert_ne!(encoding_a, encoding_b);
        }
    }

    #[test]
    fn to_sequence_test() {
        let bfe_max = BFieldElement::new(BFieldElement::MAX);
        let xfe_max = XFieldElement::new([bfe_max; EXTENSION_DEGREE]);
        let digest_max = Digest::new([bfe_max; Digest::LEN]);

        encode_prop(false, true);
        encode_prop(0u32, u32::MAX);
        encode_prop(0u64, u64::MAX);
        encode_prop(BFieldElement::ZERO, bfe_max);
        encode_prop(XFieldElement::ZERO, xfe_max);
        encode_prop(Digest::ALL_ZERO, digest_max);
        encode_prop(0u128, u128::MAX);
    }

    /// For each `(upper_bound, num_indices)` pair, sampling yields exactly `num_indices`
    /// indices, each strictly below `upper_bound`.
    #[proptest]
    fn sample_indices(mut tip5: Tip5) {
        let cases = [
            (2, 0),
            (4, 1),
            (8, 9),
            (16, 10),
            (32, 11),
            (64, 19),
            (128, 20),
            (256, 21),
            (512, 65),
        ];

        for &(upper_bound, num_indices) in &cases {
            let sampled = tip5.sample_indices(upper_bound, num_indices);
            prop_assert_eq!(num_indices, sampled.len());

            let all_in_range = sampled.iter().all(|&index| index < upper_bound);
            prop_assert!(all_in_range);
        }
    }
}