sp1_recursion_compiler/ir/
poseidon.rs

1use p3_field::AbstractField;
2use sp1_recursion_core::runtime::{DIGEST_SIZE, HASH_RATE, PERMUTATION_WIDTH};
3
4use super::{Array, Builder, Config, DslIr, Ext, Felt, Usize, Var};
5
6impl<C: Config> Builder<C> {
7    /// Applies the Poseidon2 permutation to the given array.
8    ///
9    /// Reference: [p3_poseidon2::Poseidon2]
10    pub fn poseidon2_permute(&mut self, array: &Array<C, Felt<C::F>>) -> Array<C, Felt<C::F>> {
11        let output = match array {
12            Array::Fixed(values) => {
13                assert_eq!(values.len(), PERMUTATION_WIDTH);
14                self.array::<Felt<C::F>>(Usize::Const(PERMUTATION_WIDTH))
15            }
16            Array::Dyn(_, len) => self.array::<Felt<C::F>>(*len),
17        };
18        self.push_op(DslIr::Poseidon2PermuteBabyBear(Box::new((output.clone(), array.clone()))));
19        output
20    }
21
22    /// Applies the Poseidon2 permutation to the given array.
23    ///
24    /// Reference: [p3_poseidon2::Poseidon2]
25    pub fn poseidon2_permute_mut(&mut self, array: &Array<C, Felt<C::F>>) {
26        self.push_op(DslIr::Poseidon2PermuteBabyBear(Box::new((array.clone(), array.clone()))));
27    }
28
29    /// Applies the Poseidon2 absorb function to the given array.
30    ///
31    /// Reference: [p3_symmetric::PaddingFreeSponge]
32    pub fn poseidon2_absorb(
33        &mut self,
34        p2_hash_and_absorb_num: Var<C::N>,
35        input: &Array<C, Felt<C::F>>,
36    ) {
37        self.push_op(DslIr::Poseidon2AbsorbBabyBear(p2_hash_and_absorb_num, input.clone()));
38    }
39
40    /// Applies the Poseidon2 finalize to the given hash number.
41    ///
42    /// Reference: [p3_symmetric::PaddingFreeSponge]
43    pub fn poseidon2_finalize_mut(
44        &mut self,
45        p2_hash_num: Var<C::N>,
46        output: &Array<C, Felt<C::F>>,
47    ) {
48        self.push_op(DslIr::Poseidon2FinalizeBabyBear(p2_hash_num, output.clone()));
49    }
50
51    /// Applies the Poseidon2 compression function to the given array.
52    ///
53    /// Reference: [p3_symmetric::TruncatedPermutation]
54    pub fn poseidon2_compress(
55        &mut self,
56        left: &Array<C, Felt<C::F>>,
57        right: &Array<C, Felt<C::F>>,
58    ) -> Array<C, Felt<C::F>> {
59        let mut input = self.dyn_array(PERMUTATION_WIDTH);
60        for i in 0..DIGEST_SIZE {
61            let a = self.get(left, i);
62            let b = self.get(right, i);
63            self.set(&mut input, i, a);
64            self.set(&mut input, i + DIGEST_SIZE, b);
65        }
66        self.poseidon2_permute_mut(&input);
67        input
68    }
69
70    /// Applies the Poseidon2 compression to the given array.
71    ///
72    /// Reference: [p3_symmetric::TruncatedPermutation]
73    pub fn poseidon2_compress_x(
74        &mut self,
75        result: &mut Array<C, Felt<C::F>>,
76        left: &Array<C, Felt<C::F>>,
77        right: &Array<C, Felt<C::F>>,
78    ) {
79        self.push_op(DslIr::Poseidon2CompressBabyBear(Box::new((
80            result.clone(),
81            left.clone(),
82            right.clone(),
83        ))));
84    }
85
86    /// Applies the Poseidon2 permutation to the given array.
87    ///
88    /// Reference: [p3_symmetric::PaddingFreeSponge]
89    pub fn poseidon2_hash(&mut self, array: &Array<C, Felt<C::F>>) -> Array<C, Felt<C::F>> {
90        let mut state: Array<C, Felt<C::F>> = self.dyn_array(PERMUTATION_WIDTH);
91
92        let break_flag: Var<_> = self.eval(C::N::zero());
93        let last_index: Usize<_> = self.eval(array.len() - 1);
94        self.range(0, array.len()).step_by(HASH_RATE).for_each(|i, builder| {
95            builder.if_eq(break_flag, C::N::one()).then(|builder| {
96                builder.break_loop();
97            });
98            // Insert elements of the chunk.
99            builder.range(0, HASH_RATE).for_each(|j, builder| {
100                let index: Var<_> = builder.eval(i + j);
101                let element = builder.get(array, index);
102                builder.set_value(&mut state, j, element);
103                builder.if_eq(index, last_index).then(|builder| {
104                    builder.assign(break_flag, C::N::one());
105                    builder.break_loop();
106                });
107            });
108
109            builder.poseidon2_permute_mut(&state);
110        });
111
112        state.truncate(self, Usize::Const(DIGEST_SIZE));
113        state
114    }
115
116    pub fn poseidon2_hash_x(
117        &mut self,
118        array: &Array<C, Array<C, Felt<C::F>>>,
119    ) -> Array<C, Felt<C::F>> {
120        self.cycle_tracker("poseidon2-hash");
121
122        let p2_hash_num = self.p2_hash_num;
123        let two_power_12: Var<_> = self.eval(C::N::from_canonical_u32(1 << 12));
124
125        self.range(0, array.len()).for_each(|i, builder| {
126            let subarray = builder.get(array, i);
127            let p2_hash_and_absorb_num: Var<_> = builder.eval(p2_hash_num * two_power_12 + i);
128
129            builder.poseidon2_absorb(p2_hash_and_absorb_num, &subarray);
130        });
131
132        let output: Array<C, Felt<C::F>> = self.dyn_array(DIGEST_SIZE);
133        self.poseidon2_finalize_mut(self.p2_hash_num, &output);
134
135        self.assign(self.p2_hash_num, self.p2_hash_num + C::N::one());
136
137        self.cycle_tracker("poseidon2-hash");
138        output
139    }
140
141    pub fn poseidon2_hash_ext(
142        &mut self,
143        array: &Array<C, Array<C, Ext<C::F, C::EF>>>,
144    ) -> Array<C, Felt<C::F>> {
145        self.cycle_tracker("poseidon2-hash-ext");
146        let mut state: Array<C, Felt<C::F>> = self.dyn_array(PERMUTATION_WIDTH);
147
148        let idx: Var<_> = self.eval(C::N::zero());
149        self.range(0, array.len()).for_each(|i, builder| {
150            let subarray = builder.get(array, i);
151            builder.range(0, subarray.len()).for_each(|j, builder| {
152                let element = builder.get(&subarray, j);
153                let felts = builder.ext2felt(element);
154                for i in 0..4 {
155                    let felt = builder.get(&felts, i);
156                    builder.set_value(&mut state, idx, felt);
157                    builder.assign(idx, idx + C::N::one());
158                    builder.if_eq(idx, C::N::from_canonical_usize(HASH_RATE)).then(|builder| {
159                        builder.poseidon2_permute_mut(&state);
160                        builder.assign(idx, C::N::zero());
161                    });
162                }
163            });
164        });
165
166        self.if_ne(idx, C::N::zero()).then(|builder| {
167            builder.poseidon2_permute_mut(&state);
168        });
169
170        state.truncate(self, Usize::Const(DIGEST_SIZE));
171        self.cycle_tracker("poseidon2-hash-ext");
172        state
173    }
174}