snarkvm_algorithms/crypto_hash/
poseidon.rs1use crate::{AlgebraicSponge, DuplexSpongeMode, nonnative_params::*};
17use snarkvm_fields::{FieldParameters, PoseidonParameters, PrimeField, ToConstraintField};
18use snarkvm_utilities::{BigInteger, FromBits, ToBits};
19
20use smallvec::SmallVec;
21use std::{
22 iter::Peekable,
23 ops::{Index, IndexMut},
24 sync::Arc,
25};
26
/// The internal state of a Poseidon-style sponge, split into the `CAPACITY`
/// section (indices `0..CAPACITY` via the `Index` impls) and the `RATE`
/// section, where `absorb_internal`/`squeeze_internal` add input and read
/// output.
#[derive(Copy, Clone, Debug)]
pub struct State<F: PrimeField, const RATE: usize, const CAPACITY: usize> {
    // Capacity portion of the state.
    capacity_state: [F; CAPACITY],
    // Rate portion of the state; input is added here and output is copied from here.
    rate_state: [F; RATE],
}
32
33impl<F: PrimeField, const RATE: usize, const CAPACITY: usize> Default for State<F, RATE, CAPACITY> {
34 fn default() -> Self {
35 Self { capacity_state: [F::zero(); CAPACITY], rate_state: [F::zero(); RATE] }
36 }
37}
38
39impl<F: PrimeField, const RATE: usize, const CAPACITY: usize> State<F, RATE, CAPACITY> {
40 pub fn iter(&self) -> impl Iterator<Item = &F> + Clone {
42 self.capacity_state.iter().chain(self.rate_state.iter())
43 }
44
45 pub fn iter_mut(&mut self) -> impl Iterator<Item = &mut F> {
47 self.capacity_state.iter_mut().chain(self.rate_state.iter_mut())
48 }
49}
50
51impl<F: PrimeField, const RATE: usize, const CAPACITY: usize> Index<usize> for State<F, RATE, CAPACITY> {
52 type Output = F;
53
54 fn index(&self, index: usize) -> &Self::Output {
55 assert!(index < RATE + CAPACITY, "Index out of bounds: index is {} but length is {}", index, RATE + CAPACITY);
56 if index < CAPACITY { &self.capacity_state[index] } else { &self.rate_state[index - CAPACITY] }
57 }
58}
59
60impl<F: PrimeField, const RATE: usize, const CAPACITY: usize> IndexMut<usize> for State<F, RATE, CAPACITY> {
61 fn index_mut(&mut self, index: usize) -> &mut Self::Output {
62 assert!(index < RATE + CAPACITY, "Index out of bounds: index is {} but length is {}", index, RATE + CAPACITY);
63 if index < CAPACITY { &mut self.capacity_state[index] } else { &mut self.rate_state[index - CAPACITY] }
64 }
65}
66
/// A Poseidon hash over the prime field `F`, absorbing `RATE` field elements
/// per permutation. The sponge capacity is fixed to 1.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Poseidon<F: PrimeField, const RATE: usize> {
    // Shared round parameters (round constants, MDS matrix, round counts,
    // S-box exponent). Arc makes cloning the hasher cheap.
    parameters: Arc<PoseidonParameters<F, RATE, 1>>,
}
71
72impl<F: PrimeField, const RATE: usize> Poseidon<F, RATE> {
73 pub fn setup() -> Self {
75 Self { parameters: Arc::new(F::default_poseidon_parameters::<RATE>().unwrap()) }
76 }
77
78 pub fn evaluate(&self, input: &[F]) -> F {
81 self.evaluate_many(input, 1)[0]
82 }
83
84 pub fn evaluate_many(&self, input: &[F], num_outputs: usize) -> Vec<F> {
88 let mut sponge = PoseidonSponge::<F, RATE, 1>::new_with_parameters(&self.parameters);
89 sponge.absorb_native_field_elements(input);
90 sponge.squeeze_native_field_elements(num_outputs).to_vec()
91 }
92
93 pub fn evaluate_with_len(&self, input: &[F]) -> F {
96 self.evaluate(&[vec![F::from(input.len() as u128)], input.to_vec()].concat())
97 }
98
99 pub fn parameters(&self) -> &Arc<PoseidonParameters<F, RATE, 1>> {
100 &self.parameters
101 }
102}
103
/// A duplex sponge over the Poseidon permutation: absorbs and squeezes field
/// elements in blocks of `RATE`, tracking its position via `mode`.
#[derive(Clone, Debug)]
pub struct PoseidonSponge<F: PrimeField, const RATE: usize, const CAPACITY: usize> {
    // Shared Poseidon round parameters.
    parameters: Arc<PoseidonParameters<F, RATE, CAPACITY>>,
    // Current sponge state (capacity + rate sections).
    state: State<F, RATE, CAPACITY>,
    // Whether the sponge is absorbing or squeezing, plus the next rate index
    // to use within the current block.
    pub mode: DuplexSpongeMode,
    // Precomputed powers of two over F: entry `i` holds 2^i, for
    // i in 0..(F::size_in_bits() - 1). Used by `compress_elements` to shift
    // one limb past another when packing two limbs into one field element.
    adjustment_factor_lookup_table: Arc<[F]>,
}
121
impl<F: PrimeField, const RATE: usize> AlgebraicSponge<F, RATE> for PoseidonSponge<F, RATE, 1> {
    type Parameters = Arc<PoseidonParameters<F, RATE, 1>>;

    /// Returns the field's default Poseidon parameters for this `RATE`.
    ///
    /// # Panics
    /// Panics if the field provides no default parameters for `RATE`.
    fn sample_parameters() -> Self::Parameters {
        Arc::new(F::default_poseidon_parameters::<RATE>().unwrap())
    }

    /// Creates a fresh sponge in absorbing mode with an all-zero state.
    fn new_with_parameters(parameters: &Self::Parameters) -> Self {
        Self {
            parameters: parameters.clone(),
            state: State::default(),
            mode: DuplexSpongeMode::Absorbing { next_absorb_index: 0 },
            adjustment_factor_lookup_table: {
                // Precompute table[i] == 2^i for i in 0..(F::size_in_bits() - 1);
                // `compress_elements` uses these as shift factors when packing
                // two limbs into a single field element.
                let capacity = F::size_in_bits() - 1;
                let mut table = Vec::<F>::with_capacity(capacity);

                let mut cur = F::one();
                for _ in 0..capacity {
                    table.push(cur);
                    cur.double_in_place();
                }

                table.into()
            },
        }
    }

    /// Absorbs the field-element representation of `elements` into the sponge.
    ///
    /// Each element is converted via `ToConstraintField`; an empty input is a
    /// no-op and leaves the mode unchanged.
    fn absorb_native_field_elements<T: ToConstraintField<F>>(&mut self, elements: &[T]) {
        let input = elements.iter().flat_map(|e| e.to_field_elements().unwrap()).collect::<Vec<_>>();
        if !input.is_empty() {
            match self.mode {
                DuplexSpongeMode::Absorbing { mut next_absorb_index } => {
                    // The current rate block is already full — permute before
                    // absorbing into a fresh block.
                    if next_absorb_index == RATE {
                        self.permute();
                        next_absorb_index = 0;
                    }
                    self.absorb_internal(next_absorb_index, &input);
                }
                DuplexSpongeMode::Squeezing { next_squeeze_index: _ } => {
                    // Switching from squeezing back to absorbing: permute
                    // first, then absorb from the start of the rate.
                    self.permute();
                    self.absorb_internal(0, &input);
                }
            }
        }
    }

    /// Absorbs elements of the non-native field `Target` by decomposing each
    /// into limbs over `F` (weight-optimized) and absorbing those limbs.
    fn absorb_nonnative_field_elements<Target: PrimeField>(&mut self, elements: impl IntoIterator<Item = Target>) {
        Self::push_elements_to_sponge(self, elements, OptimizationType::Weight);
    }

    /// Squeezes `num` full-width elements of the non-native field `Target`.
    fn squeeze_nonnative_field_elements<Target: PrimeField>(&mut self, num: usize) -> SmallVec<[Target; 10]> {
        self.get_fe(num, false)
    }

    /// Squeezes `num_elements` native field elements from the sponge.
    fn squeeze_native_field_elements(&mut self, num_elements: usize) -> SmallVec<[F; 10]> {
        if num_elements == 0 {
            return SmallVec::<[F; 10]>::new();
        }
        // Use the SmallVec's inline capacity (10) when the request fits;
        // any surplus zeroes are truncated below.
        let mut output = if num_elements <= 10 {
            smallvec::smallvec_inline![F::zero(); 10]
        } else {
            smallvec::smallvec![F::zero(); num_elements]
        };

        match self.mode {
            DuplexSpongeMode::Absorbing { next_absorb_index: _ } => {
                // Switching from absorbing to squeezing: permute first, then
                // read from the start of the rate.
                self.permute();
                self.squeeze_internal(0, &mut output[..num_elements]);
            }
            DuplexSpongeMode::Squeezing { mut next_squeeze_index } => {
                // The current rate block is exhausted — permute to refresh it.
                if next_squeeze_index == RATE {
                    self.permute();
                    next_squeeze_index = 0;
                }
                self.squeeze_internal(next_squeeze_index, &mut output[..num_elements]);
            }
        }

        output.truncate(num_elements);
        output
    }

    /// Squeezes `num` short (168-bit) elements of the non-native field
    /// `Target`; see `get_fe` for the bit-width choice.
    fn squeeze_short_nonnative_field_elements<Target: PrimeField>(&mut self, num: usize) -> SmallVec<[Target; 10]> {
        self.get_fe(num, true)
    }
}
211
impl<F: PrimeField, const RATE: usize> PoseidonSponge<F, RATE, 1> {
    /// Adds the round constants for `round_number` element-wise to the state.
    #[inline]
    fn apply_ark(&mut self, round_number: usize) {
        for (state_elem, ark_elem) in self.state.iter_mut().zip(&self.parameters.ark[round_number]) {
            *state_elem += ark_elem;
        }
    }

    /// Applies the S-box `x -> x^alpha`: to every state element in a full
    /// round, or to the first state element only in a partial round.
    #[inline]
    fn apply_s_box(&mut self, is_full_round: bool) {
        if is_full_round {
            for elem in self.state.iter_mut() {
                *elem = elem.pow([self.parameters.alpha]);
            }
        } else {
            // Partial rounds exponentiate only state element 0.
            self.state[0] = self.state[0].pow([self.parameters.alpha]);
        }
    }

    /// Multiplies the state vector by the MDS matrix.
    #[inline]
    fn apply_mds(&mut self) {
        let mut new_state = State::default();
        new_state.iter_mut().zip(&self.parameters.mds).for_each(|(new_elem, mds_row)| {
            *new_elem = F::sum_of_products(self.state.iter(), mds_row.iter());
        });
        self.state = new_state;
    }

    /// Runs the full Poseidon permutation: half the full rounds, then the
    /// partial rounds, then the remaining full rounds. Every round applies
    /// ARK, then the S-box, then the MDS multiplication.
    #[inline]
    fn permute(&mut self) {
        let partial_rounds = self.parameters.partial_rounds;
        let full_rounds = self.parameters.full_rounds;
        let full_rounds_over_2 = full_rounds / 2;
        // Partial rounds occupy the middle of the round schedule.
        let partial_round_range = full_rounds_over_2..(full_rounds_over_2 + partial_rounds);

        for i in 0..(partial_rounds + full_rounds) {
            let is_full_round = !partial_round_range.contains(&i);
            self.apply_ark(i);
            self.apply_s_box(is_full_round);
            self.apply_mds();
        }
    }

    /// Adds `input` into the rate section starting at `rate_start`, permuting
    /// between full rate blocks, and records the next absorb position in
    /// `mode`. The permutation after the final chunk is deferred until the
    /// sponge actually needs it.
    #[inline]
    fn absorb_internal(&mut self, mut rate_start: usize, input: &[F]) {
        if !input.is_empty() {
            // The first chunk fills the remainder of the current rate block;
            // each subsequent chunk starts at rate index 0.
            let first_chunk_size = std::cmp::min(RATE - rate_start, input.len());
            let num_elements_remaining = input.len() - first_chunk_size;
            let (first_chunk, rest_chunk) = input.split_at(first_chunk_size);
            let rest_chunks = rest_chunk.chunks(RATE);
            // 1 for the first chunk, plus full and partial trailing chunks.
            let total_num_chunks = 1 + (num_elements_remaining / RATE) +
                usize::from((num_elements_remaining % RATE) != 0);

            for (i, chunk) in std::iter::once(first_chunk).chain(rest_chunks).enumerate() {
                for (element, state_elem) in chunk.iter().zip(&mut self.state.rate_state[rate_start..]) {
                    *state_elem += element;
                }
                if i == total_num_chunks - 1 {
                    // Last chunk: record where the next absorb resumes and
                    // skip the (deferred) permutation.
                    self.mode = DuplexSpongeMode::Absorbing { next_absorb_index: rate_start + chunk.len() };
                    return;
                } else {
                    self.permute();
                }
                rate_start = 0;
            }
        }
    }

    /// Copies rate-section elements into `output` starting at `rate_start`,
    /// permuting between full rate blocks, and records the next squeeze
    /// position in `mode`. The permutation after the final chunk is deferred.
    #[inline]
    fn squeeze_internal(&mut self, mut rate_start: usize, output: &mut [F]) {
        let output_size = output.len();
        if output_size != 0 {
            // The first chunk drains the remainder of the current rate block;
            // each subsequent chunk starts at rate index 0.
            let first_chunk_size = std::cmp::min(RATE - rate_start, output.len());
            let num_output_remaining = output.len() - first_chunk_size;
            let (first_chunk, rest_chunk) = output.split_at_mut(first_chunk_size);
            assert_eq!(rest_chunk.len(), num_output_remaining);
            let rest_chunks = rest_chunk.chunks_mut(RATE);
            // 1 for the first chunk, plus full and partial trailing chunks.
            let total_num_chunks = 1 + (num_output_remaining / RATE) +
                usize::from((num_output_remaining % RATE) != 0);

            for (i, chunk) in std::iter::once(first_chunk).chain(rest_chunks).enumerate() {
                let range = rate_start..(rate_start + chunk.len());
                debug_assert_eq!(
                    chunk.len(),
                    self.state.rate_state[range.clone()].len(),
                    "failed with squeeze {output_size} at rate {RATE} and rate_start {rate_start}"
                );
                chunk.copy_from_slice(&self.state.rate_state[range]);
                if i == total_num_chunks - 1 {
                    // Last chunk: record where the next squeeze resumes and
                    // skip the (deferred) permutation.
                    self.mode = DuplexSpongeMode::Squeezing { next_squeeze_index: (rate_start + chunk.len()) };
                    return;
                } else {
                    self.permute();
                }
                rate_start = 0;
            }
        }
    }

    /// Packs adjacent limb pairs into single field elements wherever their
    /// combined bit usage fits within the field capacity
    /// (`F::size_in_bits() - 1`), reducing the number of elements absorbed.
    /// Each item's `.1` component feeds the `overhead!` bound computation.
    pub fn compress_elements<TargetField: PrimeField, I: Iterator<Item = (F, F)>>(
        &self,
        mut src_limbs: Peekable<I>,
        ty: OptimizationType,
    ) -> Vec<F> {
        let capacity = F::size_in_bits() - 1;
        let mut dest_limbs = Vec::<F>::new();

        let params = get_params(TargetField::size_in_bits(), F::size_in_bits(), ty);

        // Scratch buffer reused by the `overhead!` macro across iterations.
        let mut num_bits = Vec::new();

        while let Some(first) = src_limbs.next() {
            let second = src_limbs.peek();

            // Effective width of a limb: the base limb width plus the
            // overhead implied by its `.1` bound.
            let first_max_bits_per_limb = params.bits_per_limb + crate::overhead!(first.1 + F::one(), &mut num_bits);
            let second_max_bits_per_limb = if let Some(second) = second {
                params.bits_per_limb + crate::overhead!(second.1 + F::one(), &mut num_bits)
            } else {
                0
            };

            if let Some(second) = second {
                if first_max_bits_per_limb + second_max_bits_per_limb <= capacity {
                    // Both limbs fit in one element: shift the first limb up
                    // past the second (2^second_max_bits_per_limb from the
                    // precomputed table) and add, then consume the second.
                    let adjustment_factor = &self.adjustment_factor_lookup_table[second_max_bits_per_limb];

                    dest_limbs.push(first.0 * adjustment_factor + second.0);
                    src_limbs.next();
                } else {
                    dest_limbs.push(first.0);
                }
            } else {
                dest_limbs.push(first.0);
            }
        }

        dest_limbs
    }

    /// Decomposes a non-native field element into its limb representation
    /// over `F`; see `get_limbs_representations_from_big_integer`.
    pub fn get_limbs_representations<TargetField: PrimeField>(
        elem: &TargetField,
        optimization_type: OptimizationType,
    ) -> SmallVec<[F; 10]> {
        Self::get_limbs_representations_from_big_integer::<TargetField>(&elem.to_bigint(), optimization_type)
    }

    /// Decomposes a big integer into `params.num_limbs` limbs of
    /// `params.bits_per_limb` bits each, most significant limb first.
    pub fn get_limbs_representations_from_big_integer<TargetField: PrimeField>(
        elem: &<TargetField as PrimeField>::BigInteger,
        optimization_type: OptimizationType,
    ) -> SmallVec<[F; 10]> {
        let params = get_params(TargetField::size_in_bits(), F::size_in_bits(), optimization_type);

        // Limbs are extracted least-significant first, then reversed at the
        // end so the most significant limb comes first.
        let mut cur_bits = Vec::new();
        let mut limbs: SmallVec<[F; 10]> = SmallVec::new();
        let mut cur = *elem;
        for _ in 0..params.num_limbs {
            cur.write_bits_be(&mut cur_bits);
            // The lowest `bits_per_limb` bits are the tail of the big-endian
            // bit string.
            let cur_mod_r =
                <F as PrimeField>::BigInteger::from_bits_be(&cur_bits[cur_bits.len() - params.bits_per_limb..])
                    .unwrap();
            limbs.push(F::from_bigint(cur_mod_r).unwrap());
            // Shift the consumed limb out and reset the scratch bits.
            cur.divn(params.bits_per_limb as u32);
            cur_bits.clear();
        }

        limbs.reverse();

        limbs
    }

    /// Decomposes each source element into limbs, compresses adjacent limbs,
    /// and absorbs the result into the sponge.
    pub fn push_elements_to_sponge<TargetField: PrimeField>(
        &mut self,
        src: impl IntoIterator<Item = TargetField>,
        ty: OptimizationType,
    ) {
        let src_limbs = src
            .into_iter()
            .flat_map(|elem| {
                let limbs = Self::get_limbs_representations(&elem, ty);
                // Pair each limb with F::one() as the bound passed to
                // `compress_elements`.
                limbs.into_iter().map(|limb| (limb, F::one()))
            })
            .peekable();

        let dest_limbs = self.compress_elements::<TargetField, _>(src_limbs, ty);
        self.absorb_native_field_elements(&dest_limbs);
    }

    /// Squeezes `num_bits` bits from the sponge, taking
    /// `F::size_in_bits() - 1` usable bits from each squeezed field element.
    pub fn get_bits(&mut self, num_bits: usize) -> Vec<bool> {
        let bits_per_element = F::size_in_bits() - 1;
        let num_elements = num_bits.div_ceil(bits_per_element);

        let src_elements = self.squeeze_native_field_elements(num_elements);
        let mut dest_bits = Vec::<bool>::with_capacity(num_elements * bits_per_element);

        // Skip the top REPR_SHAVE_BITS + 1 bits of each element's big-endian
        // representation so each contributes exactly `bits_per_element` bits.
        let skip = (F::Parameters::REPR_SHAVE_BITS + 1) as usize;
        for elem in src_elements.iter() {
            let elem_bits = elem.to_bigint().to_bits_be();
            dest_bits.extend_from_slice(&elem_bits[skip..]);
        }
        dest_bits.truncate(num_bits);

        dest_bits
    }

    /// Squeezes `num_elements` elements of the non-native field `TargetField`
    /// by squeezing bits and reassembling them into field elements.
    ///
    /// When `outputs_short_elements` is true, each element is built from a
    /// fixed 168 bits; otherwise from `TargetField::size_in_bits() - 1` bits.
    pub fn get_fe<TargetField: PrimeField>(
        &mut self,
        num_elements: usize,
        outputs_short_elements: bool,
    ) -> SmallVec<[TargetField; 10]> {
        // NOTE(review): 168 appears to be a deliberate "short element" width —
        // confirm its rationale against callers of squeeze_short_nonnative_field_elements.
        let num_bits_per_nonnative = if outputs_short_elements {
            168
        } else {
            TargetField::size_in_bits() - 1
        };
        let bits = self.get_bits(num_bits_per_nonnative * num_elements);

        // lookup_table[i] == 2^i in the target field.
        let mut lookup_table = Vec::<TargetField>::with_capacity(num_bits_per_nonnative);
        let mut cur = TargetField::one();
        for _ in 0..num_bits_per_nonnative {
            lookup_table.push(cur);
            cur.double_in_place();
        }

        let dest_elements = bits
            .chunks_exact(num_bits_per_nonnative)
            .map(|per_nonnative_bits| {
                // The chunk is big-endian: its last bit is least significant,
                // so iterate in reverse while indexing the power-of-two table.
                let mut res = TargetField::zero();

                for (i, bit) in per_nonnative_bits.iter().rev().enumerate() {
                    if *bit {
                        res += &lookup_table[i];
                    }
                }
                res
            })
            .collect::<SmallVec<_>>();
        debug_assert_eq!(dest_elements.len(), num_elements);

        dest_elements
    }
}