snarkvm_circuit_algorithms/poseidon/
hash_many.rs

// Copyright (c) 2019-2025 Provable Inc.
// This file is part of the snarkVM library.

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at:

// http://www.apache.org/licenses/LICENSE-2.0

// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

16use super::*;
17
18impl<E: Environment, const RATE: usize> HashMany for Poseidon<E, RATE> {
19    type Input = Field<E>;
20    type Output = Field<E>;
21
22    #[inline]
23    fn hash_many(&self, input: &[Self::Input], num_outputs: u16) -> Vec<Self::Output> {
24        // Construct the preimage: [ DOMAIN || LENGTH(INPUT) || [0; RATE-2] || INPUT ].
25        let mut preimage = Vec::with_capacity(RATE + input.len());
26        preimage.push(self.domain.clone());
27        preimage.push(Field::constant(console::Field::from_u128(input.len() as u128)));
28        preimage.resize(RATE, Field::zero()); // Pad up to RATE.
29        preimage.extend_from_slice(input);
30
31        // Initialize a new sponge.
32        let mut state = vec![Field::zero(); RATE + CAPACITY];
33        let mut mode = DuplexSpongeMode::Absorbing { next_absorb_index: 0 };
34
35        // Absorb the input and squeeze the output.
36        self.absorb(&mut state, &mut mode, &preimage);
37        self.squeeze(&mut state, &mut mode, num_outputs)
38    }
39}
40
41#[allow(clippy::needless_borrow)]
42impl<E: Environment, const RATE: usize> Poseidon<E, RATE> {
43    /// Absorbs the input elements into state.
44    #[inline]
45    fn absorb(&self, state: &mut [Field<E>], mode: &mut DuplexSpongeMode, input: &[Field<E>]) {
46        if !input.is_empty() {
47            // Determine the absorb index.
48            let (mut absorb_index, should_permute) = match *mode {
49                DuplexSpongeMode::Absorbing { next_absorb_index } => match next_absorb_index == RATE {
50                    true => (0, true),
51                    false => (next_absorb_index, false),
52                },
53                DuplexSpongeMode::Squeezing { .. } => (0, true),
54            };
55
56            // Proceed to permute the state, if necessary.
57            if should_permute {
58                self.permute(state);
59            }
60
61            let mut remaining = input;
62            loop {
63                // Compute the starting index.
64                let start = CAPACITY + absorb_index;
65
66                // Check if we can exit the loop.
67                if absorb_index + remaining.len() <= RATE {
68                    // Absorb the state elements into the input.
69                    remaining.iter().enumerate().for_each(|(i, element)| state[start + i] += element);
70                    // Update the sponge mode.
71                    *mode = DuplexSpongeMode::Absorbing { next_absorb_index: absorb_index + remaining.len() };
72                    return;
73                }
74
75                // Otherwise, proceed to absorb `(rate - absorb_index)` elements.
76                let num_absorbed = RATE - absorb_index;
77                remaining.iter().enumerate().take(num_absorbed).for_each(|(i, element)| state[start + i] += element);
78
79                // Permute the state.
80                self.permute(state);
81
82                // Repeat with the updated input slice and absorb index.
83                remaining = &remaining[num_absorbed..];
84                absorb_index = 0;
85            }
86        }
87    }
88
89    /// Squeeze the specified number of state elements into the output.
90    #[inline]
91    fn squeeze(&self, state: &mut [Field<E>], mode: &mut DuplexSpongeMode, num_outputs: u16) -> Vec<Field<E>> {
92        let mut output = vec![Field::zero(); num_outputs as usize];
93        if num_outputs != 0 {
94            self.squeeze_internal(state, mode, &mut output);
95        }
96        output
97    }
98
99    /// Squeeze the state elements into the output.
100    #[inline]
101    fn squeeze_internal(&self, state: &mut [Field<E>], mode: &mut DuplexSpongeMode, output: &mut [Field<E>]) {
102        // Determine the squeeze index.
103        let (mut squeeze_index, should_permute) = match *mode {
104            DuplexSpongeMode::Absorbing { .. } => (0, true),
105            DuplexSpongeMode::Squeezing { next_squeeze_index } => match next_squeeze_index == RATE {
106                true => (0, true),
107                false => (next_squeeze_index, false),
108            },
109        };
110
111        // Proceed to permute the state, if necessary.
112        if should_permute {
113            self.permute(state);
114        }
115
116        let mut remaining = output;
117        loop {
118            // Compute the starting index.
119            let start = CAPACITY + squeeze_index;
120
121            // Check if we can exit the loop.
122            if squeeze_index + remaining.len() <= RATE {
123                // Store the state elements into the output.
124                remaining.clone_from_slice(&state[start..(start + remaining.len())]);
125                // Update the sponge mode.
126                *mode = DuplexSpongeMode::Squeezing { next_squeeze_index: squeeze_index + remaining.len() };
127                return;
128            }
129
130            // Otherwise, proceed to squeeze `(rate - squeeze_index)` elements.
131            let num_squeezed = RATE - squeeze_index;
132            remaining[..num_squeezed].clone_from_slice(&state[start..(start + num_squeezed)]);
133
134            // Permute.
135            self.permute(state);
136
137            // Repeat with the updated output slice and squeeze index.
138            remaining = &mut remaining[num_squeezed..];
139            squeeze_index = 0;
140        }
141    }
142
143    /// Apply the additive round keys in-place.
144    #[inline]
145    fn apply_ark(&self, state: &mut [Field<E>], round: usize) {
146        for (i, element) in state.iter_mut().enumerate() {
147            *element += &self.ark[round][i];
148        }
149    }
150
151    /// Apply the S-Box based on whether it is a full round or partial round.
152    #[inline]
153    fn apply_s_box(&self, state: &mut [Field<E>], is_full_round: bool) {
154        if is_full_round {
155            // Full rounds apply the S Box (x^alpha) to every element of state
156            for element in state.iter_mut() {
157                *element = (&*element).pow(&self.alpha);
158            }
159        } else {
160            // Partial rounds apply the S Box (x^alpha) to just the first element of state
161            state[0] = (&state[0]).pow(&self.alpha);
162        }
163    }
164
165    /// Apply the Maximally Distance Separating (MDS) matrix in-place.
166    #[inline]
167    fn apply_mds(&self, state: &mut [Field<E>], new_state: &mut Vec<Field<E>>) {
168        new_state.clear();
169        for i in 0..state.len() {
170            let mut accumulator = Field::zero();
171            for (j, element) in state.iter().enumerate() {
172                accumulator += element * &self.mds[i][j];
173            }
174            new_state.push(accumulator);
175        }
176        state.swap_with_slice(new_state);
177    }
178
179    /// Apply the permutation for all rounds in-place.
180    #[inline]
181    fn permute(&self, state: &mut [Field<E>]) {
182        // Determine the partial rounds range bound.
183        let full_rounds_over_2 = self.full_rounds / 2;
184        let partial_round_range = full_rounds_over_2..(full_rounds_over_2 + self.partial_rounds);
185
186        // Iterate through all rounds to permute.
187        let mut new_state = Vec::with_capacity(state.len());
188        for i in 0..(self.partial_rounds + self.full_rounds) {
189            let is_full_round = !partial_round_range.contains(&i);
190            self.apply_ark(state, i);
191            self.apply_s_box(state, is_full_round);
192            self.apply_mds(state, &mut new_state);
193        }
194    }
195}
196
#[cfg(test)]
mod tests {
    use super::*;
    use snarkvm_circuit_types::environment::Circuit;

    use anyhow::Result;

    const DOMAIN: &str = "PoseidonCircuit0";
    const ITERATIONS: usize = 10;
    const RATE: u16 = 4;

    /// Expected `(num_inputs, num_private, num_constraints)` for public or private inputs when
    /// `1..=RATE` outputs are squeezed.
    const COSTS_UP_TO_RATE_OUTPUTS: [(usize, u64, u64); 6] =
        [(1, 335, 335), (2, 340, 340), (3, 345, 345), (4, 350, 350), (5, 705, 705), (6, 705, 705)];

    /// Expected `(num_inputs, num_private, num_constraints)` for public or private inputs when
    /// `RATE + 1..=2 * RATE` outputs are squeezed.
    const COSTS_UP_TO_TWICE_RATE_OUTPUTS: [(usize, u64, u64); 6] =
        [(1, 690, 690), (2, 695, 695), (3, 700, 700), (4, 705, 705), (5, 1060, 1060), (6, 1060, 1060)];

    /// Hashes `ITERATIONS` random inputs of length `num_inputs` (injected in `mode`) into `num_outputs` elements.
    ///
    /// Each output is checked against the native console Poseidon.
    /// The circuit scope's constant, public, private and constraint counts are asserted.
    fn check_hash_many(
        mode: Mode,
        num_inputs: usize,
        num_outputs: u16,
        num_constants: u64,
        num_public: u64,
        num_private: u64,
        num_constraints: u64,
        rng: &mut TestRng,
    ) -> Result<()> {
        use console::HashMany as H;

        // The native Poseidon is the reference; the circuit Poseidon is built from it as constants.
        let native = console::Poseidon::<<Circuit as Environment>::Network, { RATE as usize }>::setup(DOMAIN)?;
        let poseidon = Poseidon::<Circuit, { RATE as usize }>::constant(native.clone());

        for i in 0..ITERATIONS {
            // Sample random native inputs and inject them into the circuit in the requested mode.
            let native_input: Vec<_> = (0..num_inputs)
                .map(|_| console::Field::<<Circuit as Environment>::Network>::rand(rng))
                .collect();
            let input: Vec<_> = native_input.iter().map(|value| Field::<Circuit>::new(mode, *value)).collect();

            // Reference result.
            let expected = native.hash_many(&native_input, num_outputs);

            Circuit::scope(format!("Poseidon {mode} {i} {num_outputs}"), || {
                let candidate = poseidon.hash_many(&input, num_outputs);
                expected.iter().zip_eq(&candidate).for_each(|(expected_element, candidate_element)| {
                    assert_eq!(*expected_element, candidate_element.eject_value());
                });
                let case = format!("(mode = {mode}, num_inputs = {num_inputs}, num_outputs = {num_outputs})");
                assert_scope!(case, num_constants, num_public, num_private, num_constraints);
            });
            Circuit::reset();
        }
        Ok(())
    }

    /// Runs the shared test matrix for a non-constant input `mode` (public or private).
    fn check_hash_many_non_constant(mode: Mode, rng: &mut TestRng) -> Result<()> {
        // Empty input: the expected counts are the same for every output count.
        for num_outputs in 0..=RATE {
            check_hash_many(mode, 0, num_outputs, 1, 0, 0, 0, rng)?;
        }
        // At most one rate block of outputs.
        for num_outputs in 1..=RATE {
            for &(num_inputs, num_private, num_constraints) in &COSTS_UP_TO_RATE_OUTPUTS {
                check_hash_many(mode, num_inputs, num_outputs, 1, 0, num_private, num_constraints, rng)?;
            }
        }
        // Outputs spilling into a second rate block.
        for num_outputs in (RATE + 1)..=(RATE * 2) {
            for &(num_inputs, num_private, num_constraints) in &COSTS_UP_TO_TWICE_RATE_OUTPUTS {
                check_hash_many(mode, num_inputs, num_outputs, 1, 0, num_private, num_constraints, rng)?;
            }
        }
        Ok(())
    }

    #[test]
    fn test_hash_many_constant() -> Result<()> {
        let mut rng = TestRng::default();

        for num_inputs in 0..=usize::from(RATE) {
            for num_outputs in 0..=RATE {
                check_hash_many(Mode::Constant, num_inputs, num_outputs, 1, 0, 0, 0, &mut rng)?;
            }
        }
        Ok(())
    }

    #[test]
    fn test_hash_many_public() -> Result<()> {
        let mut rng = TestRng::default();
        check_hash_many_non_constant(Mode::Public, &mut rng)
    }

    #[test]
    fn test_hash_many_private() -> Result<()> {
        let mut rng = TestRng::default();
        check_hash_many_non_constant(Mode::Private, &mut rng)
    }
}