// snarkvm_algorithms/crypto_hash/poseidon.rs

// Copyright (c) 2019-2025 Provable Inc.
// This file is part of the snarkVM library.

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at:

// http://www.apache.org/licenses/LICENSE-2.0

// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
15
16use crate::{AlgebraicSponge, DuplexSpongeMode, nonnative_params::*};
17use snarkvm_fields::{FieldParameters, PoseidonParameters, PrimeField, ToConstraintField};
18use snarkvm_utilities::{BigInteger, FromBits, ToBits};
19
20use smallvec::SmallVec;
21use std::{
22    iter::Peekable,
23    ops::{Index, IndexMut},
24    sync::Arc,
25};
26
/// The internal state of a Poseidon sponge, split into its capacity and rate portions.
#[derive(Copy, Clone, Debug)]
pub struct State<F: PrimeField, const RATE: usize, const CAPACITY: usize> {
    /// The capacity portion of the state; never directly exposed to absorbed input or squeezed output.
    capacity_state: [F; CAPACITY],
    /// The rate portion of the state; inputs are added here and outputs are read from here.
    rate_state: [F; RATE],
}
32
33impl<F: PrimeField, const RATE: usize, const CAPACITY: usize> Default for State<F, RATE, CAPACITY> {
34    fn default() -> Self {
35        Self { capacity_state: [F::zero(); CAPACITY], rate_state: [F::zero(); RATE] }
36    }
37}
38
39impl<F: PrimeField, const RATE: usize, const CAPACITY: usize> State<F, RATE, CAPACITY> {
40    /// Returns an immutable iterator over the state.
41    pub fn iter(&self) -> impl Iterator<Item = &F> + Clone {
42        self.capacity_state.iter().chain(self.rate_state.iter())
43    }
44
45    /// Returns a mutable iterator over the state.
46    pub fn iter_mut(&mut self) -> impl Iterator<Item = &mut F> {
47        self.capacity_state.iter_mut().chain(self.rate_state.iter_mut())
48    }
49}
50
51impl<F: PrimeField, const RATE: usize, const CAPACITY: usize> Index<usize> for State<F, RATE, CAPACITY> {
52    type Output = F;
53
54    fn index(&self, index: usize) -> &Self::Output {
55        assert!(index < RATE + CAPACITY, "Index out of bounds: index is {} but length is {}", index, RATE + CAPACITY);
56        if index < CAPACITY { &self.capacity_state[index] } else { &self.rate_state[index - CAPACITY] }
57    }
58}
59
60impl<F: PrimeField, const RATE: usize, const CAPACITY: usize> IndexMut<usize> for State<F, RATE, CAPACITY> {
61    fn index_mut(&mut self, index: usize) -> &mut Self::Output {
62        assert!(index < RATE + CAPACITY, "Index out of bounds: index is {} but length is {}", index, RATE + CAPACITY);
63        if index < CAPACITY { &mut self.capacity_state[index] } else { &mut self.rate_state[index - CAPACITY] }
64    }
65}
66
/// The Poseidon cryptographic hash function over a prime field `F`, with capacity 1.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Poseidon<F: PrimeField, const RATE: usize> {
    // Shared via `Arc` so cloning the hasher does not copy the parameter tables.
    parameters: Arc<PoseidonParameters<F, RATE, 1>>,
}
71
72impl<F: PrimeField, const RATE: usize> Poseidon<F, RATE> {
73    /// Initializes a new instance of the cryptographic hash function.
74    pub fn setup() -> Self {
75        Self { parameters: Arc::new(F::default_poseidon_parameters::<RATE>().unwrap()) }
76    }
77
78    /// Evaluate the cryptographic hash function over a list of field elements
79    /// as input.
80    pub fn evaluate(&self, input: &[F]) -> F {
81        self.evaluate_many(input, 1)[0]
82    }
83
84    /// Evaluate the cryptographic hash function over a list of field elements
85    /// as input, and returns the specified number of field elements as
86    /// output.
87    pub fn evaluate_many(&self, input: &[F], num_outputs: usize) -> Vec<F> {
88        let mut sponge = PoseidonSponge::<F, RATE, 1>::new_with_parameters(&self.parameters);
89        sponge.absorb_native_field_elements(input);
90        sponge.squeeze_native_field_elements(num_outputs).to_vec()
91    }
92
93    /// Evaluate the cryptographic hash function over a non-fixed-length vector,
94    /// in which the length also needs to be hashed.
95    pub fn evaluate_with_len(&self, input: &[F]) -> F {
96        self.evaluate(&[vec![F::from(input.len() as u128)], input.to_vec()].concat())
97    }
98
99    pub fn parameters(&self) -> &Arc<PoseidonParameters<F, RATE, 1>> {
100        &self.parameters
101    }
102}
103
/// A duplex sponge based on the Poseidon permutation.
///
/// This implementation of Poseidon is entirely from Fractal's implementation in
/// [COS20][cos] with small syntax changes.
///
/// [cos]: https://eprint.iacr.org/2019/1076
#[derive(Clone, Debug)]
pub struct PoseidonSponge<F: PrimeField, const RATE: usize, const CAPACITY: usize> {
    /// Sponge Parameters
    parameters: Arc<PoseidonParameters<F, RATE, CAPACITY>>,
    /// Current sponge's state (current elements in the permutation block)
    state: State<F, RATE, CAPACITY>,
    /// Current mode (whether it's absorbing or squeezing)
    pub mode: DuplexSpongeMode,
    /// A persistent lookup table used when compressing elements.
    /// Entry `i` holds `2^i` as a field element; shared via `Arc` so clones reuse it.
    adjustment_factor_lookup_table: Arc<[F]>,
}
121
impl<F: PrimeField, const RATE: usize> AlgebraicSponge<F, RATE> for PoseidonSponge<F, RATE, 1> {
    type Parameters = Arc<PoseidonParameters<F, RATE, 1>>;

    /// Samples the default Poseidon parameters for this field and rate.
    fn sample_parameters() -> Self::Parameters {
        Arc::new(F::default_poseidon_parameters::<RATE>().unwrap())
    }

    /// Initializes a sponge in absorbing mode with a zeroed state.
    fn new_with_parameters(parameters: &Self::Parameters) -> Self {
        Self {
            parameters: parameters.clone(),
            state: State::default(),
            mode: DuplexSpongeMode::Absorbing { next_absorb_index: 0 },
            adjustment_factor_lookup_table: {
                // Precompute powers of two: entry `i` holds 2^i,
                // for `i` in `0..(F::size_in_bits() - 1)`.
                let capacity = F::size_in_bits() - 1;
                let mut table = Vec::<F>::with_capacity(capacity);

                let mut cur = F::one();
                for _ in 0..capacity {
                    table.push(cur);
                    cur.double_in_place();
                }

                table.into()
            },
        }
    }

    /// Takes in field elements.
    fn absorb_native_field_elements<T: ToConstraintField<F>>(&mut self, elements: &[T]) {
        // Flatten every input into its native field-element representation.
        let input = elements.iter().flat_map(|e| e.to_field_elements().unwrap()).collect::<Vec<_>>();
        if !input.is_empty() {
            match self.mode {
                DuplexSpongeMode::Absorbing { mut next_absorb_index } => {
                    // If the rate portion is already full, permute before absorbing more.
                    if next_absorb_index == RATE {
                        self.permute();
                        next_absorb_index = 0;
                    }
                    self.absorb_internal(next_absorb_index, &input);
                }
                DuplexSpongeMode::Squeezing { next_squeeze_index: _ } => {
                    // Switching from squeezing to absorbing requires a permutation first.
                    self.permute();
                    self.absorb_internal(0, &input);
                }
            }
        }
    }

    /// Takes in field elements.
    fn absorb_nonnative_field_elements<Target: PrimeField>(&mut self, elements: impl IntoIterator<Item = Target>) {
        Self::push_elements_to_sponge(self, elements, OptimizationType::Weight);
    }

    /// Squeezes out `num` elements of the nonnative field `Target`.
    fn squeeze_nonnative_field_elements<Target: PrimeField>(&mut self, num: usize) -> SmallVec<[Target; 10]> {
        self.get_fe(num, false)
    }

    /// Squeezes out `num_elements` native field elements.
    fn squeeze_native_field_elements(&mut self, num_elements: usize) -> SmallVec<[F; 10]> {
        if num_elements == 0 {
            return SmallVec::<[F; 10]>::new();
        }
        // Use the inline (stack-allocated) representation whenever the output fits;
        // the extra zeroes are truncated away below.
        let mut output = if num_elements <= 10 {
            smallvec::smallvec_inline![F::zero(); 10]
        } else {
            smallvec::smallvec![F::zero(); num_elements]
        };

        match self.mode {
            DuplexSpongeMode::Absorbing { next_absorb_index: _ } => {
                // Switching from absorbing to squeezing requires a permutation first.
                self.permute();
                self.squeeze_internal(0, &mut output[..num_elements]);
            }
            DuplexSpongeMode::Squeezing { mut next_squeeze_index } => {
                // If the rate portion has been fully squeezed, permute to refresh it.
                if next_squeeze_index == RATE {
                    self.permute();
                    next_squeeze_index = 0;
                }
                self.squeeze_internal(next_squeeze_index, &mut output[..num_elements]);
            }
        }

        // Drop any padding zeroes introduced by the inline representation.
        output.truncate(num_elements);
        output
    }

    /// Takes out field elements of 168 bits.
    fn squeeze_short_nonnative_field_elements<Target: PrimeField>(&mut self, num: usize) -> SmallVec<[Target; 10]> {
        self.get_fe(num, true)
    }
}
211
impl<F: PrimeField, const RATE: usize> PoseidonSponge<F, RATE, 1> {
    /// Adds the round constants (ARK) for `round_number` to every state element.
    #[inline]
    fn apply_ark(&mut self, round_number: usize) {
        for (state_elem, ark_elem) in self.state.iter_mut().zip(&self.parameters.ark[round_number]) {
            *state_elem += ark_elem;
        }
    }

    /// Applies the S-box (x^alpha) to the state.
    #[inline]
    fn apply_s_box(&mut self, is_full_round: bool) {
        if is_full_round {
            // Full rounds apply the S Box (x^alpha) to every element of state
            for elem in self.state.iter_mut() {
                *elem = elem.pow([self.parameters.alpha]);
            }
        } else {
            // Partial rounds apply the S Box (x^alpha) to just the first element of state
            self.state[0] = self.state[0].pow([self.parameters.alpha]);
        }
    }

    /// Multiplies the state by the MDS matrix: `new_state[i] = <mds[i], state>`.
    #[inline]
    fn apply_mds(&mut self) {
        let mut new_state = State::default();
        new_state.iter_mut().zip(&self.parameters.mds).for_each(|(new_elem, mds_row)| {
            *new_elem = F::sum_of_products(self.state.iter(), mds_row.iter());
        });
        self.state = new_state;
    }

    /// Runs the full Poseidon permutation over the state:
    /// the first half of the full rounds, then all partial rounds,
    /// then the second half of the full rounds.
    #[inline]
    fn permute(&mut self) {
        // Determine the partial rounds range bound.
        let partial_rounds = self.parameters.partial_rounds;
        let full_rounds = self.parameters.full_rounds;
        let full_rounds_over_2 = full_rounds / 2;
        let partial_round_range = full_rounds_over_2..(full_rounds_over_2 + partial_rounds);

        // Iterate through all rounds to permute.
        for i in 0..(partial_rounds + full_rounds) {
            let is_full_round = !partial_round_range.contains(&i);
            self.apply_ark(i);
            self.apply_s_box(is_full_round);
            self.apply_mds();
        }
    }

    /// Absorbs everything in elements, this does not end in an absorption.
    /// `rate_start` is the rate position at which the first element is added;
    /// on return, `self.mode` records the next free absorb index.
    #[inline]
    fn absorb_internal(&mut self, mut rate_start: usize, input: &[F]) {
        if !input.is_empty() {
            // The first chunk fills only the remaining `RATE - rate_start` slots.
            let first_chunk_size = std::cmp::min(RATE - rate_start, input.len());
            let num_elements_remaining = input.len() - first_chunk_size;
            let (first_chunk, rest_chunk) = input.split_at(first_chunk_size);
            let rest_chunks = rest_chunk.chunks(RATE);
            // The total number of chunks is `elements[num_elements_remaining..].len() /
            // RATE`, plus 1 for the remainder.
            let total_num_chunks = 1 + // 1 for the first chunk
                // We add all the chunks that are perfectly divisible by `RATE`
                (num_elements_remaining / RATE) +
                // And also add 1 if the last chunk is non-empty
                // (i.e. if `num_elements_remaining` is not a multiple of `RATE`)
                usize::from((num_elements_remaining % RATE) != 0);

            // Absorb the input elements, `RATE` elements at a time, except for the first
            // chunk, which is of size `RATE - rate_start`.
            for (i, chunk) in std::iter::once(first_chunk).chain(rest_chunks).enumerate() {
                // Add (not overwrite) the chunk into the rate portion of the state.
                for (element, state_elem) in chunk.iter().zip(&mut self.state.rate_state[rate_start..]) {
                    *state_elem += element;
                }
                // Are we in the last chunk?
                // If so, let's wrap up.
                if i == total_num_chunks - 1 {
                    // Record where the next absorption should continue; no trailing permute.
                    self.mode = DuplexSpongeMode::Absorbing { next_absorb_index: rate_start + chunk.len() };
                    return;
                } else {
                    self.permute();
                }
                // All chunks after the first start at the beginning of the rate state.
                rate_start = 0;
            }
        }
    }

    /// Squeeze |output| many elements. This does not end in a squeeze
    /// `rate_start` is the rate position from which the first element is read;
    /// on return, `self.mode` records the next squeeze index.
    #[inline]
    fn squeeze_internal(&mut self, mut rate_start: usize, output: &mut [F]) {
        let output_size = output.len();
        if output_size != 0 {
            // The first chunk reads only the remaining `RATE - rate_start` slots.
            let first_chunk_size = std::cmp::min(RATE - rate_start, output.len());
            let num_output_remaining = output.len() - first_chunk_size;
            let (first_chunk, rest_chunk) = output.split_at_mut(first_chunk_size);
            assert_eq!(rest_chunk.len(), num_output_remaining);
            let rest_chunks = rest_chunk.chunks_mut(RATE);
            // The total number of chunks is `output[num_output_remaining..].len() / RATE`,
            // plus 1 for the remainder.
            let total_num_chunks = 1 + // 1 for the first chunk
                // We add all the chunks that are perfectly divisible by `RATE`
                (num_output_remaining / RATE) +
                // And also add 1 if the last chunk is non-empty
                // (i.e. if `num_output_remaining` is not a multiple of `RATE`)
                usize::from((num_output_remaining % RATE) != 0);

            // Squeeze into the output, `RATE` elements at a time, except for the first
            // chunk, which is of size `RATE - rate_start`.
            for (i, chunk) in std::iter::once(first_chunk).chain(rest_chunks).enumerate() {
                let range = rate_start..(rate_start + chunk.len());
                debug_assert_eq!(
                    chunk.len(),
                    self.state.rate_state[range.clone()].len(),
                    "failed with squeeze {output_size} at rate {RATE} and rate_start {rate_start}"
                );
                chunk.copy_from_slice(&self.state.rate_state[range]);
                // Are we in the last chunk?
                // If so, let's wrap up.
                if i == total_num_chunks - 1 {
                    // Record where the next squeeze should continue; no trailing permute.
                    self.mode = DuplexSpongeMode::Squeezing { next_squeeze_index: (rate_start + chunk.len()) };
                    return;
                } else {
                    self.permute();
                }
                // All chunks after the first start at the beginning of the rate state.
                rate_start = 0;
            }
        }
    }

    /// Compress every two elements if possible.
    /// Provides a vector of (limb, num_of_additions), both of which are F.
    /// Adjacent limbs are packed into one native element whenever their combined
    /// bit usage fits below the field capacity.
    pub fn compress_elements<TargetField: PrimeField, I: Iterator<Item = (F, F)>>(
        &self,
        mut src_limbs: Peekable<I>,
        ty: OptimizationType,
    ) -> Vec<F> {
        // The usable capacity of the native field, in bits (one bit below the modulus size).
        let capacity = F::size_in_bits() - 1;
        let mut dest_limbs = Vec::<F>::new();

        let params = get_params(TargetField::size_in_bits(), F::size_in_bits(), ty);

        // Prepare a reusable vector to be used in overhead calculation.
        let mut num_bits = Vec::new();

        while let Some(first) = src_limbs.next() {
            let second = src_limbs.peek();

            // Worst-case bit width of each limb: the base limb size plus the
            // overhead contributed by its accumulated additions (the `.1` component).
            let first_max_bits_per_limb = params.bits_per_limb + crate::overhead!(first.1 + F::one(), &mut num_bits);
            let second_max_bits_per_limb = if let Some(second) = second {
                params.bits_per_limb + crate::overhead!(second.1 + F::one(), &mut num_bits)
            } else {
                0
            };

            if let Some(second) = second {
                if first_max_bits_per_limb + second_max_bits_per_limb <= capacity {
                    // Both limbs fit: pack as `first * 2^second_max_bits + second`,
                    // using the precomputed power-of-two lookup table.
                    let adjustment_factor = &self.adjustment_factor_lookup_table[second_max_bits_per_limb];

                    dest_limbs.push(first.0 * adjustment_factor + second.0);
                    // Consume the second limb, since it was folded into this element.
                    src_limbs.next();
                } else {
                    dest_limbs.push(first.0);
                }
            } else {
                dest_limbs.push(first.0);
            }
        }

        dest_limbs
    }

    /// Convert a `TargetField` element into limbs (not constraints)
    /// This is an internal function that would be reused by a number of other
    /// functions
    pub fn get_limbs_representations<TargetField: PrimeField>(
        elem: &TargetField,
        optimization_type: OptimizationType,
    ) -> SmallVec<[F; 10]> {
        Self::get_limbs_representations_from_big_integer::<TargetField>(&elem.to_bigint(), optimization_type)
    }

    /// Obtain the limbs directly from a big int
    pub fn get_limbs_representations_from_big_integer<TargetField: PrimeField>(
        elem: &<TargetField as PrimeField>::BigInteger,
        optimization_type: OptimizationType,
    ) -> SmallVec<[F; 10]> {
        let params = get_params(TargetField::size_in_bits(), F::size_in_bits(), optimization_type);

        // Prepare a reusable vector for the BE bits.
        let mut cur_bits = Vec::new();
        // Push the lower limbs first
        let mut limbs: SmallVec<[F; 10]> = SmallVec::new();
        let mut cur = *elem;
        for _ in 0..params.num_limbs {
            cur.write_bits_be(&mut cur_bits); // `write_bits_be` is big endian
            // Take the lowest `bits_per_limb` bits (the tail of the big-endian encoding)
            // and interpret them as a native-field big integer.
            let cur_mod_r =
                <F as PrimeField>::BigInteger::from_bits_be(&cur_bits[cur_bits.len() - params.bits_per_limb..])
                    .unwrap(); // therefore, the lowest `bits_per_non_top_limb` bits is what we want.
            limbs.push(F::from_bigint(cur_mod_r).unwrap());
            // Shift right to drop the bits just consumed.
            cur.divn(params.bits_per_limb as u32);
            // Clear the vector after every iteration so its allocation can be reused.
            cur_bits.clear();
        }

        // then we reverse, so that the limbs are ``big limb first''
        limbs.reverse();

        limbs
    }

    /// Push elements to sponge, treated in the non-native field
    /// representations.
    pub fn push_elements_to_sponge<TargetField: PrimeField>(
        &mut self,
        src: impl IntoIterator<Item = TargetField>,
        ty: OptimizationType,
    ) {
        // Decompose each nonnative element into native limbs, each paired with
        // an additions counter of one.
        let src_limbs = src
            .into_iter()
            .flat_map(|elem| {
                let limbs = Self::get_limbs_representations(&elem, ty);
                limbs.into_iter().map(|limb| (limb, F::one()))
                // specifically set to one, since most gadgets in the constraint
                // world would not have zero noise (due to the relatively weak
                // normal form testing in `alloc`)
            })
            .peekable();

        // Pack adjacent limbs where possible, then absorb the result natively.
        let dest_limbs = self.compress_elements::<TargetField, _>(src_limbs, ty);
        self.absorb_native_field_elements(&dest_limbs);
    }

    /// obtain random bits from hashchain.
    /// not guaranteed to be uniformly distributed, should only be used in
    /// certain situations.
    pub fn get_bits(&mut self, num_bits: usize) -> Vec<bool> {
        // Each squeezed element safely contributes one bit less than the field size.
        let bits_per_element = F::size_in_bits() - 1;
        let num_elements = num_bits.div_ceil(bits_per_element);

        let src_elements = self.squeeze_native_field_elements(num_elements);
        let mut dest_bits = Vec::<bool>::with_capacity(num_elements * bits_per_element);

        // Skip the representation's unused (shaved) bits plus the highest data bit.
        let skip = (F::Parameters::REPR_SHAVE_BITS + 1) as usize;
        for elem in src_elements.iter() {
            // discard the highest bit
            let elem_bits = elem.to_bigint().to_bits_be();
            dest_bits.extend_from_slice(&elem_bits[skip..]);
        }
        // Trim any excess bits from the final element.
        dest_bits.truncate(num_bits);

        dest_bits
    }

    /// obtain random field elements from hashchain.
    /// not guaranteed to be uniformly distributed, should only be used in
    /// certain situations.
    pub fn get_fe<TargetField: PrimeField>(
        &mut self,
        num_elements: usize,
        outputs_short_elements: bool,
    ) -> SmallVec<[TargetField; 10]> {
        // Short elements use a fixed 168 bits; otherwise use the full target
        // field width minus one bit.
        let num_bits_per_nonnative = if outputs_short_elements {
            168
        } else {
            TargetField::size_in_bits() - 1 // also omit the highest bit
        };
        let bits = self.get_bits(num_bits_per_nonnative * num_elements);

        // Precompute powers of two in the target field: lookup_table[i] = 2^i.
        let mut lookup_table = Vec::<TargetField>::with_capacity(num_bits_per_nonnative);
        let mut cur = TargetField::one();
        for _ in 0..num_bits_per_nonnative {
            lookup_table.push(cur);
            cur.double_in_place();
        }

        let dest_elements = bits
            .chunks_exact(num_bits_per_nonnative)
            .map(|per_nonnative_bits| {
                // technically, this can be done via BigInteger::from_bits; here, we use this
                // method for consistency with the gadget counterpart
                let mut res = TargetField::zero();

                // Bits are big-endian, so reverse to accumulate from the low end.
                for (i, bit) in per_nonnative_bits.iter().rev().enumerate() {
                    if *bit {
                        res += &lookup_table[i];
                    }
                }
                res
            })
            .collect::<SmallVec<_>>();
        debug_assert_eq!(dest_elements.len(), num_elements);

        dest_elements
    }
}