sp1_recursion_compiler/circuit/
builder.rs

//! Circuit builder extensions (`CircuitV2Builder`) used by the v2 recursion compiler over BabyBear.
2
3use std::{borrow::Cow, iter::repeat};
4
5use crate::prelude::*;
6use itertools::Itertools;
7use p3_baby_bear::BabyBear;
8use p3_field::{AbstractExtensionField, AbstractField};
9use sp1_recursion_core::air::RecursionPublicValues;
10use sp1_recursion_core::{chips::poseidon2_skinny::WIDTH, D, DIGEST_SIZE, HASH_RATE};
11use sp1_stark::septic_curve::SepticCurve;
12use sp1_stark::septic_digest::SepticDigest;
13use sp1_stark::septic_extension::SepticExtension;
14
15pub trait CircuitV2Builder<C: Config> {
16    fn bits2num_v2_f(
17        &mut self,
18        bits: impl IntoIterator<Item = Felt<<C as Config>::F>>,
19    ) -> Felt<C::F>;
20    fn num2bits_v2_f(&mut self, num: Felt<C::F>, num_bits: usize) -> Vec<Felt<C::F>>;
21    fn exp_reverse_bits_v2(&mut self, input: Felt<C::F>, power_bits: Vec<Felt<C::F>>)
22        -> Felt<C::F>;
23    fn batch_fri_v2(
24        &mut self,
25        alphas: Vec<Ext<C::F, C::EF>>,
26        p_at_zs: Vec<Ext<C::F, C::EF>>,
27        p_at_xs: Vec<Felt<C::F>>,
28    ) -> Ext<C::F, C::EF>;
29    fn poseidon2_permute_v2(&mut self, state: [Felt<C::F>; WIDTH]) -> [Felt<C::F>; WIDTH];
30    fn poseidon2_hash_v2(&mut self, array: &[Felt<C::F>]) -> [Felt<C::F>; DIGEST_SIZE];
31    fn poseidon2_compress_v2(
32        &mut self,
33        input: impl IntoIterator<Item = Felt<C::F>>,
34    ) -> [Felt<C::F>; DIGEST_SIZE];
35    fn fri_fold_v2(&mut self, input: CircuitV2FriFoldInput<C>) -> CircuitV2FriFoldOutput<C>;
36    fn ext2felt_v2(&mut self, ext: Ext<C::F, C::EF>) -> [Felt<C::F>; D];
37    fn add_curve_v2(
38        &mut self,
39        point1: SepticCurve<Felt<C::F>>,
40        point2: SepticCurve<Felt<C::F>>,
41    ) -> SepticCurve<Felt<C::F>>;
42    fn assert_digest_zero_v2(&mut self, is_real: Felt<C::F>, digest: SepticDigest<Felt<C::F>>);
43    fn sum_digest_v2(&mut self, digests: Vec<SepticDigest<Felt<C::F>>>)
44        -> SepticDigest<Felt<C::F>>;
45    fn select_global_cumulative_sum(
46        &mut self,
47        is_first_shard: Felt<C::F>,
48        vk_digest: SepticDigest<Felt<C::F>>,
49    ) -> SepticDigest<Felt<C::F>>;
50    fn commit_public_values_v2(&mut self, public_values: RecursionPublicValues<Felt<C::F>>);
51    fn cycle_tracker_v2_enter(&mut self, name: impl Into<Cow<'static, str>>);
52    fn cycle_tracker_v2_exit(&mut self);
53    fn hint_ext_v2(&mut self) -> Ext<C::F, C::EF>;
54    fn hint_felt_v2(&mut self) -> Felt<C::F>;
55    fn hint_exts_v2(&mut self, len: usize) -> Vec<Ext<C::F, C::EF>>;
56    fn hint_felts_v2(&mut self, len: usize) -> Vec<Felt<C::F>>;
57}
58
59impl<C: Config<F = BabyBear>> CircuitV2Builder<C> for Builder<C> {
60    fn bits2num_v2_f(
61        &mut self,
62        bits: impl IntoIterator<Item = Felt<<C as Config>::F>>,
63    ) -> Felt<<C as Config>::F> {
64        let mut num: Felt<_> = self.eval(C::F::zero());
65        for (i, bit) in bits.into_iter().enumerate() {
66            // Add `bit * 2^i` to the sum.
67            num = self.eval(num + bit * C::F::from_wrapped_u32(1 << i));
68        }
69        num
70    }
71
72    /// Converts a felt to bits inside a circuit.
73    fn num2bits_v2_f(&mut self, num: Felt<C::F>, num_bits: usize) -> Vec<Felt<C::F>> {
74        let output = std::iter::from_fn(|| Some(self.uninit())).take(num_bits).collect::<Vec<_>>();
75        self.push_op(DslIr::CircuitV2HintBitsF(output.clone(), num));
76
77        let x: SymbolicFelt<_> = output
78            .iter()
79            .enumerate()
80            .map(|(i, &bit)| {
81                self.assert_felt_eq(bit * (bit - C::F::one()), C::F::zero());
82                bit * C::F::from_wrapped_u32(1 << i)
83            })
84            .sum();
85
86        // Range check the bits to be less than the BabyBear modulus.
87
88        assert!(num_bits <= 31, "num_bits must be less than or equal to 31");
89
90        // If there are less than 31 bits, there is nothing to check.
91        if num_bits > 30 {
92            // Since BabyBear modulus is 2^31 - 2^27 + 1, if any of the top `4` bits are zero, the
93            // number is less than 2^27, and we can stop the iteration. Othwriwse, if all the top
94            // `4` bits are '1`, we need to check that all the bottom `27` are '0`
95
96            // Get a flag that is zero if any of the top `4` bits are zero, and one otherwise. We
97            // can do this by simply taking their product (which is bitwise AND).
98            let are_all_top_bits_one: Felt<_> = self.eval(
99                output
100                    .iter()
101                    .rev()
102                    .take(4)
103                    .copied()
104                    .map(SymbolicFelt::from)
105                    .product::<SymbolicFelt<_>>(),
106            );
107
108            // Assert that if all the top `4` bits are one, then all the bottom `27` bits are zero.
109            for bit in output.iter().take(27).copied() {
110                self.assert_felt_eq(bit * are_all_top_bits_one, C::F::zero());
111            }
112        }
113
114        // Check that the original number matches the bit decomposition.
115        self.assert_felt_eq(x, num);
116
117        output
118    }
119
120    /// A version of `exp_reverse_bits_len` that uses the ExpReverseBitsLen precompile.
121    fn exp_reverse_bits_v2(
122        &mut self,
123        input: Felt<C::F>,
124        power_bits: Vec<Felt<C::F>>,
125    ) -> Felt<C::F> {
126        let output: Felt<_> = self.uninit();
127        self.push_op(DslIr::CircuitV2ExpReverseBits(output, input, power_bits));
128        output
129    }
130
131    /// A version of the `batch_fri` that uses the BatchFRI precompile.
132    fn batch_fri_v2(
133        &mut self,
134        alpha_pows: Vec<Ext<C::F, C::EF>>,
135        p_at_zs: Vec<Ext<C::F, C::EF>>,
136        p_at_xs: Vec<Felt<C::F>>,
137    ) -> Ext<C::F, C::EF> {
138        let output: Ext<_, _> = self.uninit();
139        self.push_op(DslIr::CircuitV2BatchFRI(Box::new((output, alpha_pows, p_at_zs, p_at_xs))));
140        output
141    }
142
143    /// Applies the Poseidon2 permutation to the given array.
144    fn poseidon2_permute_v2(&mut self, array: [Felt<C::F>; WIDTH]) -> [Felt<C::F>; WIDTH] {
145        let output: [Felt<C::F>; WIDTH] = core::array::from_fn(|_| self.uninit());
146        self.push_op(DslIr::CircuitV2Poseidon2PermuteBabyBear(Box::new((output, array))));
147        output
148    }
149
150    /// Applies the Poseidon2 hash function to the given array.
151    ///
152    /// Reference: [p3_symmetric::PaddingFreeSponge]
153    fn poseidon2_hash_v2(&mut self, input: &[Felt<C::F>]) -> [Felt<C::F>; DIGEST_SIZE] {
154        // static_assert(RATE < WIDTH)
155        let mut state = core::array::from_fn(|_| self.eval(C::F::zero()));
156        for input_chunk in input.chunks(HASH_RATE) {
157            state[..input_chunk.len()].copy_from_slice(input_chunk);
158            state = self.poseidon2_permute_v2(state);
159        }
160        let state: [Felt<C::F>; DIGEST_SIZE] = state[..DIGEST_SIZE].try_into().unwrap();
161        state
162    }
163
164    /// Applies the Poseidon2 compression function to the given array.
165    ///
166    /// Reference: [p3_symmetric::TruncatedPermutation]
167    fn poseidon2_compress_v2(
168        &mut self,
169        input: impl IntoIterator<Item = Felt<C::F>>,
170    ) -> [Felt<C::F>; DIGEST_SIZE] {
171        // debug_assert!(DIGEST_SIZE * N <= WIDTH);
172        let mut pre_iter = input.into_iter().chain(repeat(self.eval(C::F::default())));
173        let pre = core::array::from_fn(move |_| pre_iter.next().unwrap());
174        let post = self.poseidon2_permute_v2(pre);
175        let post: [Felt<C::F>; DIGEST_SIZE] = post[..DIGEST_SIZE].try_into().unwrap();
176        post
177    }
178
179    /// Runs FRI fold.
180    fn fri_fold_v2(&mut self, input: CircuitV2FriFoldInput<C>) -> CircuitV2FriFoldOutput<C> {
181        let mut uninit_vec = |len| std::iter::from_fn(|| Some(self.uninit())).take(len).collect();
182        let output = CircuitV2FriFoldOutput {
183            alpha_pow_output: uninit_vec(input.alpha_pow_input.len()),
184            ro_output: uninit_vec(input.ro_input.len()),
185        };
186        self.push_op(DslIr::CircuitV2FriFold(Box::new((output.clone(), input))));
187        output
188    }
189
190    /// Decomposes an ext into its felt coordinates.
191    fn ext2felt_v2(&mut self, ext: Ext<C::F, C::EF>) -> [Felt<C::F>; D] {
192        let felts = core::array::from_fn(|_| self.uninit());
193        self.push_op(DslIr::CircuitExt2Felt(felts, ext));
194        // Verify that the decomposed extension element is correct.
195        let mut reconstructed_ext: Ext<C::F, C::EF> = self.constant(C::EF::zero());
196        for i in 0..4 {
197            let felt = felts[i];
198            let monomial: Ext<C::F, C::EF> = self.constant(C::EF::monomial(i));
199            reconstructed_ext = self.eval(reconstructed_ext + monomial * felt);
200        }
201
202        self.assert_ext_eq(reconstructed_ext, ext);
203
204        felts
205    }
206
207    /// Adds two septic elliptic curve points.
208    fn add_curve_v2(
209        &mut self,
210        point1: SepticCurve<Felt<C::F>>,
211        point2: SepticCurve<Felt<C::F>>,
212    ) -> SepticCurve<Felt<C::F>> {
213        // Hint the curve addition result.
214        let point_sum_x: [Felt<C::F>; 7] = core::array::from_fn(|_| self.uninit());
215        let point_sum_y: [Felt<C::F>; 7] = core::array::from_fn(|_| self.uninit());
216        let point =
217            SepticCurve { x: SepticExtension(point_sum_x), y: SepticExtension(point_sum_y) };
218        self.push_op(DslIr::CircuitV2HintAddCurve(Box::new((point, point1, point2))));
219
220        // Convert each point into a point over SymbolicFelt.
221        let point1_symbolic = SepticCurve::convert(point1, |x| x.into());
222        let point2_symbolic = SepticCurve::convert(point2, |x| x.into());
223        let point_symbolic = SepticCurve::convert(point, |x| x.into());
224
225        // Evaluate `sum_checker_x` and `sum_checker_y`.
226        let sum_checker_x = SepticCurve::<SymbolicFelt<C::F>>::sum_checker_x(
227            point1_symbolic,
228            point2_symbolic,
229            point_symbolic,
230        );
231
232        let sum_checker_y = SepticCurve::<SymbolicFelt<C::F>>::sum_checker_y(
233            point1_symbolic,
234            point2_symbolic,
235            point_symbolic,
236        );
237
238        // Constrain `sum_checker_x` and `sum_checker_y` to be all zero.
239        for limb in sum_checker_x.0 {
240            self.assert_felt_eq(limb, C::F::zero());
241        }
242
243        for limb in sum_checker_y.0 {
244            self.assert_felt_eq(limb, C::F::zero());
245        }
246
247        point
248    }
249
250    /// Asserts that the `digest` is the zero digest when `is_real` is non-zero.
251    fn assert_digest_zero_v2(&mut self, is_real: Felt<C::F>, digest: SepticDigest<Felt<C::F>>) {
252        let zero = SepticDigest::<SymbolicFelt<C::F>>::zero();
253        for (digest_limb_x, zero_limb_x) in digest.0.x.0.into_iter().zip_eq(zero.0.x.0.into_iter())
254        {
255            self.assert_felt_eq(is_real * digest_limb_x, is_real * zero_limb_x);
256        }
257        for (digest_limb_y, zero_limb_y) in digest.0.y.0.into_iter().zip_eq(zero.0.y.0.into_iter())
258        {
259            self.assert_felt_eq(is_real * digest_limb_y, is_real * zero_limb_y);
260        }
261    }
262
263    /// Returns the zero digest when `is_first_shard` is zero, and returns the `vk_digest` when `is_first_shard` is one.
264    /// It is assumed that `is_first_shard` is already checked to be a boolean.
265    fn select_global_cumulative_sum(
266        &mut self,
267        is_first_shard: Felt<C::F>,
268        vk_digest: SepticDigest<Felt<C::F>>,
269    ) -> SepticDigest<Felt<C::F>> {
270        let zero = SepticDigest::<SymbolicFelt<C::F>>::zero();
271        let one: Felt<C::F> = self.constant(C::F::one());
272        let x = SepticExtension(core::array::from_fn(|i| {
273            self.eval(is_first_shard * vk_digest.0.x.0[i] + (one - is_first_shard) * zero.0.x.0[i])
274        }));
275        let y = SepticExtension(core::array::from_fn(|i| {
276            self.eval(is_first_shard * vk_digest.0.y.0[i] + (one - is_first_shard) * zero.0.y.0[i])
277        }));
278        SepticDigest(SepticCurve { x, y })
279    }
280
281    // Sums the digests into one.
282    fn sum_digest_v2(
283        &mut self,
284        digests: Vec<SepticDigest<Felt<C::F>>>,
285    ) -> SepticDigest<Felt<C::F>> {
286        let mut convert_to_felt =
287            |point: SepticCurve<C::F>| SepticCurve::convert(point, |value| self.eval(value));
288
289        let start = convert_to_felt(SepticDigest::starting_digest().0);
290        let zero_digest = convert_to_felt(SepticDigest::zero().0);
291
292        if digests.is_empty() {
293            return SepticDigest(zero_digest);
294        }
295
296        let neg_start = convert_to_felt(SepticDigest::starting_digest().0.neg());
297        let neg_zero_digest = convert_to_felt(SepticDigest::zero().0.neg());
298
299        let mut ret = start;
300        for (i, digest) in digests.clone().into_iter().enumerate() {
301            ret = self.add_curve_v2(ret, digest.0);
302            if i != digests.len() - 1 {
303                ret = self.add_curve_v2(ret, neg_zero_digest)
304            }
305        }
306        SepticDigest(self.add_curve_v2(ret, neg_start))
307    }
308
309    // Commits public values.
310    fn commit_public_values_v2(&mut self, public_values: RecursionPublicValues<Felt<C::F>>) {
311        self.push_op(DslIr::CircuitV2CommitPublicValues(Box::new(public_values)));
312    }
313
314    fn cycle_tracker_v2_enter(&mut self, name: impl Into<Cow<'static, str>>) {
315        self.push_op(DslIr::CycleTrackerV2Enter(name.into()));
316    }
317
318    fn cycle_tracker_v2_exit(&mut self) {
319        self.push_op(DslIr::CycleTrackerV2Exit);
320    }
321
322    /// Hint a single felt.
323    fn hint_felt_v2(&mut self) -> Felt<C::F> {
324        self.hint_felts_v2(1)[0]
325    }
326
327    /// Hint a single ext.
328    fn hint_ext_v2(&mut self) -> Ext<C::F, C::EF> {
329        self.hint_exts_v2(1)[0]
330    }
331
332    /// Hint a vector of felts.
333    fn hint_felts_v2(&mut self, len: usize) -> Vec<Felt<C::F>> {
334        let arr = std::iter::from_fn(|| Some(self.uninit())).take(len).collect::<Vec<_>>();
335        self.push_op(DslIr::CircuitV2HintFelts(arr[0], len));
336        arr
337    }
338
339    /// Hint a vector of exts.
340    fn hint_exts_v2(&mut self, len: usize) -> Vec<Ext<C::F, C::EF>> {
341        let arr = std::iter::from_fn(|| Some(self.uninit())).take(len).collect::<Vec<_>>();
342        self.push_op(DslIr::CircuitV2HintExts(arr[0], len));
343        arr
344    }
345}