Skip to main content

snarkvm_algorithms/crypto_hash/
poseidon.rs

1// Copyright (c) 2019-2026 Provable Inc.
2// This file is part of the snarkVM library.
3
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at:
7
8// http://www.apache.org/licenses/LICENSE-2.0
9
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16use crate::{AlgebraicSponge, DuplexSpongeMode, nonnative_params::*};
17use snarkvm_fields::{FieldParameters, PoseidonParameters, PrimeField, ToConstraintField};
18use snarkvm_utilities::{BigInteger, FromBits, ToBits};
19
20use smallvec::SmallVec;
21use std::{
22    iter::Peekable,
23    ops::{Index, IndexMut},
24    sync::Arc,
25};
26
27#[derive(Copy, Clone, Debug)]
28pub struct State<F: PrimeField, const RATE: usize, const CAPACITY: usize> {
29    capacity_state: [F; CAPACITY],
30    rate_state: [F; RATE],
31}
32
33impl<F: PrimeField, const RATE: usize, const CAPACITY: usize> Default for State<F, RATE, CAPACITY> {
34    fn default() -> Self {
35        Self { capacity_state: [F::zero(); CAPACITY], rate_state: [F::zero(); RATE] }
36    }
37}
38
39impl<F: PrimeField, const RATE: usize, const CAPACITY: usize> State<F, RATE, CAPACITY> {
40    /// Returns an immutable iterator over the state.
41    pub fn iter(&self) -> impl Iterator<Item = &F> + Clone {
42        self.capacity_state.iter().chain(self.rate_state.iter())
43    }
44
45    /// Returns a mutable iterator over the state.
46    pub fn iter_mut(&mut self) -> impl Iterator<Item = &mut F> {
47        self.capacity_state.iter_mut().chain(self.rate_state.iter_mut())
48    }
49}
50
51impl<F: PrimeField, const RATE: usize, const CAPACITY: usize> Index<usize> for State<F, RATE, CAPACITY> {
52    type Output = F;
53
54    fn index(&self, index: usize) -> &Self::Output {
55        assert!(index < RATE + CAPACITY, "Index out of bounds: index is {} but length is {}", index, RATE + CAPACITY);
56        if index < CAPACITY { &self.capacity_state[index] } else { &self.rate_state[index - CAPACITY] }
57    }
58}
59
60impl<F: PrimeField, const RATE: usize, const CAPACITY: usize> IndexMut<usize> for State<F, RATE, CAPACITY> {
61    fn index_mut(&mut self, index: usize) -> &mut Self::Output {
62        assert!(index < RATE + CAPACITY, "Index out of bounds: index is {} but length is {}", index, RATE + CAPACITY);
63        if index < CAPACITY { &mut self.capacity_state[index] } else { &mut self.rate_state[index - CAPACITY] }
64    }
65}
66
67#[derive(Clone, Debug, PartialEq, Eq)]
68pub struct Poseidon<F: PrimeField, const RATE: usize> {
69    parameters: Arc<PoseidonParameters<F, RATE, 1>>,
70}
71
72impl<F: PrimeField, const RATE: usize> Poseidon<F, RATE> {
73    /// Initializes a new instance of the cryptographic hash function.
74    pub fn setup() -> Self {
75        Self { parameters: Arc::new(F::default_poseidon_parameters::<RATE>().unwrap()) }
76    }
77
78    /// Evaluate the cryptographic hash function over a list of field elements
79    /// as input.
80    pub fn evaluate(&self, input: &[F]) -> F {
81        self.evaluate_many(input, 1)[0]
82    }
83
84    /// Evaluate the cryptographic hash function over a list of field elements
85    /// as input, and returns the specified number of field elements as
86    /// output.
87    pub fn evaluate_many(&self, input: &[F], num_outputs: usize) -> Vec<F> {
88        let mut sponge = PoseidonSponge::<F, RATE, 1>::new_with_parameters(&self.parameters);
89        sponge.absorb_native_field_elements(input);
90        sponge.squeeze_native_field_elements(num_outputs).to_vec()
91    }
92
93    /// Evaluate the cryptographic hash function over a non-fixed-length vector,
94    /// in which the length also needs to be hashed.
95    pub fn evaluate_with_len(&self, input: &[F]) -> F {
96        self.evaluate(&[vec![F::from(input.len() as u128)], input.to_vec()].concat())
97    }
98
99    pub fn parameters(&self) -> &Arc<PoseidonParameters<F, RATE, 1>> {
100        &self.parameters
101    }
102}
103
104/// A duplex sponge based using the Poseidon permutation.
105///
106/// This implementation of Poseidon is entirely from Fractal's implementation in
107/// [COS20][cos] with small syntax changes.
108///
109/// [cos]: https://eprint.iacr.org/2019/1076
110#[derive(Clone, Debug)]
111pub struct PoseidonSponge<F: PrimeField, const RATE: usize, const CAPACITY: usize> {
112    /// Sponge Parameters
113    parameters: Arc<PoseidonParameters<F, RATE, CAPACITY>>,
114    /// Current sponge's state (current elements in the permutation block)
115    state: State<F, RATE, CAPACITY>,
116    /// Current mode (whether its absorbing or squeezing)
117    pub mode: DuplexSpongeMode,
118    /// A persistent lookup table used when compressing elements.
119    adjustment_factor_lookup_table: Arc<[F]>,
120}
121
122impl<F: PrimeField, const RATE: usize> AlgebraicSponge<F, RATE> for PoseidonSponge<F, RATE, 1> {
123    type Parameters = Arc<PoseidonParameters<F, RATE, 1>>;
124
125    fn sample_parameters() -> Self::Parameters {
126        Arc::new(F::default_poseidon_parameters::<RATE>().unwrap())
127    }
128
129    fn new_with_parameters(parameters: &Self::Parameters) -> Self {
130        Self {
131            parameters: parameters.clone(),
132            state: State::default(),
133            mode: DuplexSpongeMode::Absorbing { next_absorb_index: 0 },
134            adjustment_factor_lookup_table: {
135                let capacity = F::size_in_bits() - 1;
136                let mut table = Vec::<F>::with_capacity(capacity);
137
138                let mut cur = F::one();
139                for _ in 0..capacity {
140                    table.push(cur);
141                    cur.double_in_place();
142                }
143
144                table.into()
145            },
146        }
147    }
148
149    /// Takes in field elements.
150    fn absorb_native_field_elements<T: ToConstraintField<F>>(&mut self, elements: &[T]) {
151        let input = elements.iter().flat_map(|e| e.to_field_elements().unwrap()).collect::<Vec<_>>();
152        if !input.is_empty() {
153            match self.mode {
154                DuplexSpongeMode::Absorbing { mut next_absorb_index } => {
155                    if next_absorb_index == RATE {
156                        self.permute();
157                        next_absorb_index = 0;
158                    }
159                    self.absorb_internal(next_absorb_index, &input);
160                }
161                DuplexSpongeMode::Squeezing { next_squeeze_index: _ } => {
162                    self.permute();
163                    self.absorb_internal(0, &input);
164                }
165            }
166        }
167    }
168
169    /// Takes in field elements.
170    fn absorb_nonnative_field_elements<Target: PrimeField>(&mut self, elements: impl IntoIterator<Item = Target>) {
171        Self::push_elements_to_sponge(self, elements, OptimizationType::Weight);
172    }
173
174    fn squeeze_nonnative_field_elements<Target: PrimeField>(&mut self, num: usize) -> SmallVec<[Target; 10]> {
175        self.get_fe(num, false)
176    }
177
178    fn squeeze_native_field_elements(&mut self, num_elements: usize) -> SmallVec<[F; 10]> {
179        if num_elements == 0 {
180            return SmallVec::<[F; 10]>::new();
181        }
182        let mut output = if num_elements <= 10 {
183            smallvec::smallvec_inline![F::zero(); 10]
184        } else {
185            smallvec::smallvec![F::zero(); num_elements]
186        };
187
188        match self.mode {
189            DuplexSpongeMode::Absorbing { next_absorb_index: _ } => {
190                self.permute();
191                self.squeeze_internal(0, &mut output[..num_elements]);
192            }
193            DuplexSpongeMode::Squeezing { mut next_squeeze_index } => {
194                if next_squeeze_index == RATE {
195                    self.permute();
196                    next_squeeze_index = 0;
197                }
198                self.squeeze_internal(next_squeeze_index, &mut output[..num_elements]);
199            }
200        }
201
202        output.truncate(num_elements);
203        output
204    }
205
206    /// Takes out field elements of 168 bits.
207    fn squeeze_short_nonnative_field_elements<Target: PrimeField>(&mut self, num: usize) -> SmallVec<[Target; 10]> {
208        self.get_fe(num, true)
209    }
210}
211
212impl<F: PrimeField, const RATE: usize> PoseidonSponge<F, RATE, 1> {
213    #[inline]
214    fn apply_ark(&mut self, round_number: usize) {
215        for (state_elem, ark_elem) in self.state.iter_mut().zip(&self.parameters.ark[round_number]) {
216            *state_elem += ark_elem;
217        }
218    }
219
220    #[inline]
221    fn apply_s_box(&mut self, is_full_round: bool) {
222        if is_full_round {
223            // Full rounds apply the S Box (x^alpha) to every element of state
224            for elem in self.state.iter_mut() {
225                *elem = elem.pow([self.parameters.alpha]);
226            }
227        } else {
228            // Partial rounds apply the S Box (x^alpha) to just the first element of state
229            self.state[0] = self.state[0].pow([self.parameters.alpha]);
230        }
231    }
232
233    #[inline]
234    fn apply_mds(&mut self) {
235        let mut new_state = State::default();
236        let curr_state: Vec<F> = self.state.iter().copied().collect::<Vec<_>>();
237        new_state.iter_mut().zip(&self.parameters.mds).for_each(|(new_elem, mds_row)| {
238            *new_elem = F::sum_of_products(&curr_state, mds_row);
239        });
240        self.state = new_state;
241    }
242
243    #[inline]
244    fn permute(&mut self) {
245        // Determine the partial rounds range bound.
246        let partial_rounds = self.parameters.partial_rounds;
247        let full_rounds = self.parameters.full_rounds;
248        let full_rounds_over_2 = full_rounds / 2;
249        let partial_round_range = full_rounds_over_2..(full_rounds_over_2 + partial_rounds);
250
251        // Iterate through all rounds to permute.
252        for i in 0..(partial_rounds + full_rounds) {
253            let is_full_round = !partial_round_range.contains(&i);
254            self.apply_ark(i);
255            self.apply_s_box(is_full_round);
256            self.apply_mds();
257        }
258    }
259
260    /// Absorbs everything in elements, this does not end in an absorption.
261    #[inline]
262    fn absorb_internal(&mut self, mut rate_start: usize, input: &[F]) {
263        if !input.is_empty() {
264            let first_chunk_size = std::cmp::min(RATE - rate_start, input.len());
265            let num_elements_remaining = input.len() - first_chunk_size;
266            let (first_chunk, rest_chunk) = input.split_at(first_chunk_size);
267            let rest_chunks = rest_chunk.chunks(RATE);
268            // The total number of chunks is `elements[num_elements_remaining..].len() /
269            // RATE`, plus 1 for the remainder.
270            let total_num_chunks = 1 + // 1 for the first chunk
271                // We add all the chunks that are perfectly divisible by `RATE`
272                (num_elements_remaining / RATE) +
273                // And also add 1 if the last chunk is non-empty
274                // (i.e. if `num_elements_remaining` is not a multiple of `RATE`)
275                usize::from((num_elements_remaining % RATE) != 0);
276
277            // Absorb the input elements, `RATE` elements at a time, except for the first
278            // chunk, which is of size `RATE - rate_start`.
279            for (i, chunk) in std::iter::once(first_chunk).chain(rest_chunks).enumerate() {
280                for (element, state_elem) in chunk.iter().zip(&mut self.state.rate_state[rate_start..]) {
281                    *state_elem += element;
282                }
283                // Are we in the last chunk?
284                // If so, let's wrap up.
285                if i == total_num_chunks - 1 {
286                    self.mode = DuplexSpongeMode::Absorbing { next_absorb_index: rate_start + chunk.len() };
287                    return;
288                } else {
289                    self.permute();
290                }
291                rate_start = 0;
292            }
293        }
294    }
295
296    /// Squeeze |output| many elements. This does not end in a squeeze
297    #[inline]
298    fn squeeze_internal(&mut self, mut rate_start: usize, output: &mut [F]) {
299        let output_size = output.len();
300        if output_size != 0 {
301            let first_chunk_size = std::cmp::min(RATE - rate_start, output.len());
302            let num_output_remaining = output.len() - first_chunk_size;
303            let (first_chunk, rest_chunk) = output.split_at_mut(first_chunk_size);
304            assert_eq!(rest_chunk.len(), num_output_remaining);
305            let rest_chunks = rest_chunk.chunks_mut(RATE);
306            // The total number of chunks is `output[num_output_remaining..].len() / RATE`,
307            // plus 1 for the remainder.
308            let total_num_chunks = 1 + // 1 for the first chunk
309                // We add all the chunks that are perfectly divisible by `RATE`
310                (num_output_remaining / RATE) +
311                // And also add 1 if the last chunk is non-empty
312                // (i.e. if `num_output_remaining` is not a multiple of `RATE`)
313                usize::from((num_output_remaining % RATE) != 0);
314
315            // Absorb the input output, `RATE` output at a time, except for the first chunk,
316            // which is of size `RATE - rate_start`.
317            for (i, chunk) in std::iter::once(first_chunk).chain(rest_chunks).enumerate() {
318                let range = rate_start..(rate_start + chunk.len());
319                debug_assert_eq!(
320                    chunk.len(),
321                    self.state.rate_state[range.clone()].len(),
322                    "failed with squeeze {output_size} at rate {RATE} and rate_start {rate_start}"
323                );
324                chunk.copy_from_slice(&self.state.rate_state[range]);
325                // Are we in the last chunk?
326                // If so, let's wrap up.
327                if i == total_num_chunks - 1 {
328                    self.mode = DuplexSpongeMode::Squeezing { next_squeeze_index: (rate_start + chunk.len()) };
329                    return;
330                } else {
331                    self.permute();
332                }
333                rate_start = 0;
334            }
335        }
336    }
337
338    /// Compress every two elements if possible.
339    /// Provides a vector of (limb, num_of_additions), both of which are F.
340    pub fn compress_elements<TargetField: PrimeField, I: Iterator<Item = (F, F)>>(
341        &self,
342        mut src_limbs: Peekable<I>,
343        ty: OptimizationType,
344    ) -> Vec<F> {
345        let capacity = F::size_in_bits() - 1;
346        let mut dest_limbs = Vec::<F>::new();
347
348        let params = get_params(TargetField::size_in_bits(), F::size_in_bits(), ty);
349
350        // Prepare a reusable vector to be used in overhead calculation.
351        let mut num_bits = Vec::new();
352
353        while let Some(first) = src_limbs.next() {
354            let second = src_limbs.peek();
355
356            let first_max_bits_per_limb = params.bits_per_limb + crate::overhead!(first.1 + F::one(), &mut num_bits);
357            let second_max_bits_per_limb = if let Some(second) = second {
358                params.bits_per_limb + crate::overhead!(second.1 + F::one(), &mut num_bits)
359            } else {
360                0
361            };
362
363            if let Some(second) = second {
364                if first_max_bits_per_limb + second_max_bits_per_limb <= capacity {
365                    let adjustment_factor = &self.adjustment_factor_lookup_table[second_max_bits_per_limb];
366
367                    dest_limbs.push(first.0 * adjustment_factor + second.0);
368                    src_limbs.next();
369                } else {
370                    dest_limbs.push(first.0);
371                }
372            } else {
373                dest_limbs.push(first.0);
374            }
375        }
376
377        dest_limbs
378    }
379
380    /// Convert a `TargetField` element into limbs (not constraints)
381    /// This is an internal function that would be reused by a number of other
382    /// functions
383    pub fn get_limbs_representations<TargetField: PrimeField>(
384        elem: &TargetField,
385        optimization_type: OptimizationType,
386    ) -> SmallVec<[F; 10]> {
387        Self::get_limbs_representations_from_big_integer::<TargetField>(&elem.to_bigint(), optimization_type)
388    }
389
390    /// Obtain the limbs directly from a big int
391    pub fn get_limbs_representations_from_big_integer<TargetField: PrimeField>(
392        elem: &<TargetField as PrimeField>::BigInteger,
393        optimization_type: OptimizationType,
394    ) -> SmallVec<[F; 10]> {
395        let params = get_params(TargetField::size_in_bits(), F::size_in_bits(), optimization_type);
396
397        // Prepare a reusable vector for the BE bits.
398        let mut cur_bits = Vec::new();
399        // Push the lower limbs first
400        let mut limbs: SmallVec<[F; 10]> = SmallVec::new();
401        let mut cur = *elem;
402        for _ in 0..params.num_limbs {
403            cur.write_bits_be(&mut cur_bits); // `write_bits_be` is big endian
404            let cur_mod_r =
405                <F as PrimeField>::BigInteger::from_bits_be(&cur_bits[cur_bits.len() - params.bits_per_limb..])
406                    .unwrap(); // therefore, the lowest `bits_per_non_top_limb` bits is what we want.
407            limbs.push(F::from_bigint(cur_mod_r).unwrap());
408            cur.divn(params.bits_per_limb as u32);
409            // Clear the vector after every iteration so its allocation can be reused.
410            cur_bits.clear();
411        }
412
413        // then we reverse, so that the limbs are ``big limb first''
414        limbs.reverse();
415
416        limbs
417    }
418
419    /// Push elements to sponge, treated in the non-native field
420    /// representations.
421    pub fn push_elements_to_sponge<TargetField: PrimeField>(
422        &mut self,
423        src: impl IntoIterator<Item = TargetField>,
424        ty: OptimizationType,
425    ) {
426        let src_limbs = src
427            .into_iter()
428            .flat_map(|elem| {
429                let limbs = Self::get_limbs_representations(&elem, ty);
430                limbs.into_iter().map(|limb| (limb, F::one()))
431                // specifically set to one, since most gadgets in the constraint
432                // world would not have zero noise (due to the relatively weak
433                // normal form testing in `alloc`)
434            })
435            .peekable();
436
437        let dest_limbs = self.compress_elements::<TargetField, _>(src_limbs, ty);
438        self.absorb_native_field_elements(&dest_limbs);
439    }
440
441    /// obtain random bits from hashchain.
442    /// not guaranteed to be uniformly distributed, should only be used in
443    /// certain situations.
444    pub fn get_bits(&mut self, num_bits: usize) -> Vec<bool> {
445        let bits_per_element = F::size_in_bits() - 1;
446        let num_elements = num_bits.div_ceil(bits_per_element);
447
448        let src_elements = self.squeeze_native_field_elements(num_elements);
449        let mut dest_bits = Vec::<bool>::with_capacity(num_elements * bits_per_element);
450
451        let skip = (F::Parameters::REPR_SHAVE_BITS + 1) as usize;
452        for elem in src_elements.iter() {
453            // discard the highest bit
454            let elem_bits = elem.to_bigint().to_bits_be();
455            dest_bits.extend_from_slice(&elem_bits[skip..]);
456        }
457        dest_bits.truncate(num_bits);
458
459        dest_bits
460    }
461
462    /// obtain random field elements from hashchain.
463    /// not guaranteed to be uniformly distributed, should only be used in
464    /// certain situations.
465    pub fn get_fe<TargetField: PrimeField>(
466        &mut self,
467        num_elements: usize,
468        outputs_short_elements: bool,
469    ) -> SmallVec<[TargetField; 10]> {
470        let num_bits_per_nonnative = if outputs_short_elements {
471            168
472        } else {
473            TargetField::size_in_bits() - 1 // also omit the highest bit
474        };
475        let bits = self.get_bits(num_bits_per_nonnative * num_elements);
476
477        let mut lookup_table = Vec::<TargetField>::with_capacity(num_bits_per_nonnative);
478        let mut cur = TargetField::one();
479        for _ in 0..num_bits_per_nonnative {
480            lookup_table.push(cur);
481            cur.double_in_place();
482        }
483
484        let dest_elements = bits
485            .chunks_exact(num_bits_per_nonnative)
486            .map(|per_nonnative_bits| {
487                // technically, this can be done via BigInteger::from_bits; here, we use this
488                // method for consistency with the gadget counterpart
489                let mut res = TargetField::zero();
490
491                for (i, bit) in per_nonnative_bits.iter().rev().enumerate() {
492                    if *bit {
493                        res += &lookup_table[i];
494                    }
495                }
496                res
497            })
498            .collect::<SmallVec<_>>();
499        debug_assert_eq!(dest_elements.len(), num_elements);
500
501        dest_elements
502    }
503}