1use super::*;
17
18impl<E: Environment, const RATE: usize> HashMany for Poseidon<E, RATE> {
19    type Input = Field<E>;
20    type Output = Field<E>;
21
22    #[inline]
23    fn hash_many(&self, input: &[Self::Input], num_outputs: u16) -> Vec<Self::Output> {
24        let mut preimage = Vec::with_capacity(RATE + input.len());
26        preimage.push(self.domain.clone());
27        preimage.push(Field::constant(console::Field::from_u128(input.len() as u128)));
28        preimage.resize(RATE, Field::zero()); preimage.extend_from_slice(input);
30
31        let mut state = vec![Field::zero(); RATE + CAPACITY];
33        let mut mode = DuplexSpongeMode::Absorbing { next_absorb_index: 0 };
34
35        self.absorb(&mut state, &mut mode, &preimage);
37        self.squeeze(&mut state, &mut mode, num_outputs)
38    }
39}
40
41#[allow(clippy::needless_borrow)]
42impl<E: Environment, const RATE: usize> Poseidon<E, RATE> {
43    #[inline]
45    fn absorb(&self, state: &mut [Field<E>], mode: &mut DuplexSpongeMode, input: &[Field<E>]) {
46        if !input.is_empty() {
47            let (mut absorb_index, should_permute) = match *mode {
49                DuplexSpongeMode::Absorbing { next_absorb_index } => match next_absorb_index == RATE {
50                    true => (0, true),
51                    false => (next_absorb_index, false),
52                },
53                DuplexSpongeMode::Squeezing { .. } => (0, true),
54            };
55
56            if should_permute {
58                self.permute(state);
59            }
60
61            let mut remaining = input;
62            loop {
63                let start = CAPACITY + absorb_index;
65
66                if absorb_index + remaining.len() <= RATE {
68                    remaining.iter().enumerate().for_each(|(i, element)| state[start + i] += element);
70                    *mode = DuplexSpongeMode::Absorbing { next_absorb_index: absorb_index + remaining.len() };
72                    return;
73                }
74
75                let num_absorbed = RATE - absorb_index;
77                remaining.iter().enumerate().take(num_absorbed).for_each(|(i, element)| state[start + i] += element);
78
79                self.permute(state);
81
82                remaining = &remaining[num_absorbed..];
84                absorb_index = 0;
85            }
86        }
87    }
88
89    #[inline]
91    fn squeeze(&self, state: &mut [Field<E>], mode: &mut DuplexSpongeMode, num_outputs: u16) -> Vec<Field<E>> {
92        let mut output = vec![Field::zero(); num_outputs as usize];
93        if num_outputs != 0 {
94            self.squeeze_internal(state, mode, &mut output);
95        }
96        output
97    }
98
99    #[inline]
101    fn squeeze_internal(&self, state: &mut [Field<E>], mode: &mut DuplexSpongeMode, output: &mut [Field<E>]) {
102        let (mut squeeze_index, should_permute) = match *mode {
104            DuplexSpongeMode::Absorbing { .. } => (0, true),
105            DuplexSpongeMode::Squeezing { next_squeeze_index } => match next_squeeze_index == RATE {
106                true => (0, true),
107                false => (next_squeeze_index, false),
108            },
109        };
110
111        if should_permute {
113            self.permute(state);
114        }
115
116        let mut remaining = output;
117        loop {
118            let start = CAPACITY + squeeze_index;
120
121            if squeeze_index + remaining.len() <= RATE {
123                remaining.clone_from_slice(&state[start..(start + remaining.len())]);
125                *mode = DuplexSpongeMode::Squeezing { next_squeeze_index: squeeze_index + remaining.len() };
127                return;
128            }
129
130            let num_squeezed = RATE - squeeze_index;
132            remaining[..num_squeezed].clone_from_slice(&state[start..(start + num_squeezed)]);
133
134            self.permute(state);
136
137            remaining = &mut remaining[num_squeezed..];
139            squeeze_index = 0;
140        }
141    }
142
143    #[inline]
145    fn apply_ark(&self, state: &mut [Field<E>], round: usize) {
146        for (i, element) in state.iter_mut().enumerate() {
147            *element += &self.ark[round][i];
148        }
149    }
150
151    #[inline]
153    fn apply_s_box(&self, state: &mut [Field<E>], is_full_round: bool) {
154        if is_full_round {
155            for element in state.iter_mut() {
157                *element = (&*element).pow(&self.alpha);
158            }
159        } else {
160            state[0] = (&state[0]).pow(&self.alpha);
162        }
163    }
164
165    #[inline]
167    fn apply_mds(&self, state: &mut [Field<E>], new_state: &mut Vec<Field<E>>) {
168        new_state.clear();
169        for i in 0..state.len() {
170            let mut accumulator = Field::zero();
171            for (j, element) in state.iter().enumerate() {
172                accumulator += element * &self.mds[i][j];
173            }
174            new_state.push(accumulator);
175        }
176        state.swap_with_slice(new_state);
177    }
178
179    #[inline]
181    fn permute(&self, state: &mut [Field<E>]) {
182        let full_rounds_over_2 = self.full_rounds / 2;
184        let partial_round_range = full_rounds_over_2..(full_rounds_over_2 + self.partial_rounds);
185
186        let mut new_state = Vec::with_capacity(state.len());
188        for i in 0..(self.partial_rounds + self.full_rounds) {
189            let is_full_round = !partial_round_range.contains(&i);
190            self.apply_ark(state, i);
191            self.apply_s_box(state, is_full_round);
192            self.apply_mds(state, &mut new_state);
193        }
194    }
195}
196
#[cfg(test)]
mod tests {
    use super::*;
    use snarkvm_circuit_types::environment::Circuit;

    use anyhow::Result;

    const DOMAIN: &str = "PoseidonCircuit0";
    const ITERATIONS: usize = 10;
    const RATE: u16 = 4;

    /// Expected private-variable (and constraint) counts for 1 through 6 non-constant inputs
    /// when between 1 and `RATE` outputs are squeezed.
    const COSTS_UP_TO_RATE_OUTPUTS: [u64; 6] = [335, 340, 345, 350, 705, 705];
    /// Expected private-variable (and constraint) counts for 1 through 6 non-constant inputs
    /// when between `RATE + 1` and `2 * RATE` outputs are squeezed.
    const COSTS_UP_TO_TWICE_RATE_OUTPUTS: [u64; 6] = [690, 695, 700, 705, 1060, 1060];

    /// Runs `ITERATIONS` random trials of `hash_many` on `num_inputs` elements in `mode`.
    ///
    /// Each trial compares every output against the console implementation and checks the
    /// circuit scope against the expected constant/public/private/constraint counts.
    fn check_hash_many(
        mode: Mode,
        num_inputs: usize,
        num_outputs: u16,
        num_constants: u64,
        num_public: u64,
        num_private: u64,
        num_constraints: u64,
        rng: &mut TestRng,
    ) -> Result<()> {
        use console::HashMany as H;

        let native = console::Poseidon::<<Circuit as Environment>::Network, { RATE as usize }>::setup(DOMAIN)?;
        let poseidon = Poseidon::<Circuit, { RATE as usize }>::constant(native.clone());

        for i in 0..ITERATIONS {
            // Random native preimage, plus the same values injected into the circuit in `mode`.
            let mut native_input = Vec::with_capacity(num_inputs);
            for _ in 0..num_inputs {
                native_input.push(console::Field::<<Circuit as Environment>::Network>::rand(rng));
            }
            let input: Vec<_> = native_input.iter().map(|value| Field::<Circuit>::new(mode, *value)).collect();

            // Reference output computed outside the circuit.
            let expected = native.hash_many(&native_input, num_outputs);

            Circuit::scope(format!("Poseidon {mode} {i} {num_outputs}"), || {
                let candidate = poseidon.hash_many(&input, num_outputs);
                for (expected_element, candidate_element) in expected.iter().zip_eq(&candidate) {
                    assert_eq!(*expected_element, candidate_element.eject_value());
                }
                let case = format!("(mode = {mode}, num_inputs = {num_inputs}, num_outputs = {num_outputs})");
                assert_scope!(case, num_constants, num_public, num_private, num_constraints);
            });
            Circuit::reset();
        }
        Ok(())
    }

    /// Shared body of the public and private tests, which expect identical counts.
    fn check_hash_many_non_constant(mode: Mode, rng: &mut TestRng) -> Result<()> {
        // Zero inputs: no private variables and no constraints are expected.
        for num_outputs in 0..=RATE {
            check_hash_many(mode, 0, num_outputs, 1, 0, 0, 0, rng)?;
        }

        let schedules = [
            (1..=RATE, COSTS_UP_TO_RATE_OUTPUTS),
            ((RATE + 1)..=(RATE * 2), COSTS_UP_TO_TWICE_RATE_OUTPUTS),
        ];
        for (output_counts, costs) in schedules {
            for num_outputs in output_counts {
                for (offset, &cost) in costs.iter().enumerate() {
                    check_hash_many(mode, offset + 1, num_outputs, 1, 0, cost, cost, rng)?;
                }
            }
        }
        Ok(())
    }

    #[test]
    fn test_hash_many_constant() -> Result<()> {
        let mut rng = TestRng::default();

        // Constant inputs are expected to allocate no private variables and no constraints.
        for num_inputs in 0..=usize::from(RATE) {
            for num_outputs in 0..=RATE {
                check_hash_many(Mode::Constant, num_inputs, num_outputs, 1, 0, 0, 0, &mut rng)?;
            }
        }
        Ok(())
    }

    #[test]
    fn test_hash_many_public() -> Result<()> {
        check_hash_many_non_constant(Mode::Public, &mut TestRng::default())
    }

    #[test]
    fn test_hash_many_private() -> Result<()> {
        check_hash_many_non_constant(Mode::Private, &mut TestRng::default())
    }
}