sp1_recursion_compiler/circuit/
builder.rs

//! Circuit-level (`v2`) extension methods for the recursion DSL [`Builder`] over BabyBear.
2
3use std::{borrow::Cow, iter::repeat};
4
5use crate::prelude::*;
6use itertools::Itertools;
7use p3_baby_bear::BabyBear;
8use p3_field::{AbstractExtensionField, AbstractField};
9use sp1_recursion_core::{
10    air::RecursionPublicValues, chips::poseidon2_skinny::WIDTH, D, DIGEST_SIZE, HASH_RATE,
11};
12use sp1_stark::{
13    septic_curve::SepticCurve, septic_digest::SepticDigest, septic_extension::SepticExtension,
14};
15
16pub trait CircuitV2Builder<C: Config> {
17    fn bits2num_v2_f(
18        &mut self,
19        bits: impl IntoIterator<Item = Felt<<C as Config>::F>>,
20    ) -> Felt<C::F>;
21    fn num2bits_v2_f(&mut self, num: Felt<C::F>, num_bits: usize) -> Vec<Felt<C::F>>;
22    fn exp_reverse_bits_v2(&mut self, input: Felt<C::F>, power_bits: Vec<Felt<C::F>>)
23        -> Felt<C::F>;
24    fn batch_fri_v2(
25        &mut self,
26        alphas: Vec<Ext<C::F, C::EF>>,
27        p_at_zs: Vec<Ext<C::F, C::EF>>,
28        p_at_xs: Vec<Felt<C::F>>,
29    ) -> Ext<C::F, C::EF>;
30    fn poseidon2_permute_v2(&mut self, state: [Felt<C::F>; WIDTH]) -> [Felt<C::F>; WIDTH];
31    fn poseidon2_hash_v2(&mut self, array: &[Felt<C::F>]) -> [Felt<C::F>; DIGEST_SIZE];
32    fn poseidon2_compress_v2(
33        &mut self,
34        input: impl IntoIterator<Item = Felt<C::F>>,
35    ) -> [Felt<C::F>; DIGEST_SIZE];
36    fn fri_fold_v2(&mut self, input: CircuitV2FriFoldInput<C>) -> CircuitV2FriFoldOutput<C>;
37    fn ext2felt_v2(&mut self, ext: Ext<C::F, C::EF>) -> [Felt<C::F>; D];
38    fn add_curve_v2(
39        &mut self,
40        point1: SepticCurve<Felt<C::F>>,
41        point2: SepticCurve<Felt<C::F>>,
42    ) -> SepticCurve<Felt<C::F>>;
43    fn assert_digest_zero_v2(&mut self, is_real: Felt<C::F>, digest: SepticDigest<Felt<C::F>>);
44    fn sum_digest_v2(&mut self, digests: Vec<SepticDigest<Felt<C::F>>>)
45        -> SepticDigest<Felt<C::F>>;
46    fn select_global_cumulative_sum(
47        &mut self,
48        is_first_shard: Felt<C::F>,
49        vk_digest: SepticDigest<Felt<C::F>>,
50    ) -> SepticDigest<Felt<C::F>>;
51    fn commit_public_values_v2(&mut self, public_values: RecursionPublicValues<Felt<C::F>>);
52    fn cycle_tracker_v2_enter(&mut self, name: impl Into<Cow<'static, str>>);
53    fn cycle_tracker_v2_exit(&mut self);
54    fn hint_ext_v2(&mut self) -> Ext<C::F, C::EF>;
55    fn hint_felt_v2(&mut self) -> Felt<C::F>;
56    fn hint_exts_v2(&mut self, len: usize) -> Vec<Ext<C::F, C::EF>>;
57    fn hint_felts_v2(&mut self, len: usize) -> Vec<Felt<C::F>>;
58}
59
60impl<C: Config<F = BabyBear>> CircuitV2Builder<C> for Builder<C> {
61    fn bits2num_v2_f(
62        &mut self,
63        bits: impl IntoIterator<Item = Felt<<C as Config>::F>>,
64    ) -> Felt<<C as Config>::F> {
65        let mut num: Felt<_> = self.eval(C::F::zero());
66        for (i, bit) in bits.into_iter().enumerate() {
67            // Add `bit * 2^i` to the sum.
68            num = self.eval(num + bit * C::F::from_wrapped_u32(1 << i));
69        }
70        num
71    }
72
73    /// Converts a felt to bits inside a circuit.
74    fn num2bits_v2_f(&mut self, num: Felt<C::F>, num_bits: usize) -> Vec<Felt<C::F>> {
75        let output = std::iter::from_fn(|| Some(self.uninit())).take(num_bits).collect::<Vec<_>>();
76        self.push_op(DslIr::CircuitV2HintBitsF(output.clone(), num));
77
78        let x: SymbolicFelt<_> = output
79            .iter()
80            .enumerate()
81            .map(|(i, &bit)| {
82                self.assert_felt_eq(bit * (bit - C::F::one()), C::F::zero());
83                bit * C::F::from_wrapped_u32(1 << i)
84            })
85            .sum();
86
87        // Range check the bits to be less than the BabyBear modulus.
88
89        assert!(num_bits <= 31, "num_bits must be less than or equal to 31");
90
91        // If there are less than 31 bits, there is nothing to check.
92        if num_bits > 30 {
93            // Since BabyBear modulus is 2^31 - 2^27 + 1, if any of the top `4` bits are zero, the
94            // number is less than 2^27, and we can stop the iteration. Othwriwse, if all the top
95            // `4` bits are '1`, we need to check that all the bottom `27` are '0`
96
97            // Get a flag that is zero if any of the top `4` bits are zero, and one otherwise. We
98            // can do this by simply taking their product (which is bitwise AND).
99            let are_all_top_bits_one: Felt<_> = self.eval(
100                output
101                    .iter()
102                    .rev()
103                    .take(4)
104                    .copied()
105                    .map(SymbolicFelt::from)
106                    .product::<SymbolicFelt<_>>(),
107            );
108
109            // Assert that if all the top `4` bits are one, then all the bottom `27` bits are zero.
110            for bit in output.iter().take(27).copied() {
111                self.assert_felt_eq(bit * are_all_top_bits_one, C::F::zero());
112            }
113        }
114
115        // Check that the original number matches the bit decomposition.
116        self.assert_felt_eq(x, num);
117
118        output
119    }
120
121    /// A version of `exp_reverse_bits_len` that uses the ExpReverseBitsLen precompile.
122    fn exp_reverse_bits_v2(
123        &mut self,
124        input: Felt<C::F>,
125        power_bits: Vec<Felt<C::F>>,
126    ) -> Felt<C::F> {
127        let output: Felt<_> = self.uninit();
128        self.push_op(DslIr::CircuitV2ExpReverseBits(output, input, power_bits));
129        output
130    }
131
132    /// A version of the `batch_fri` that uses the BatchFRI precompile.
133    fn batch_fri_v2(
134        &mut self,
135        alpha_pows: Vec<Ext<C::F, C::EF>>,
136        p_at_zs: Vec<Ext<C::F, C::EF>>,
137        p_at_xs: Vec<Felt<C::F>>,
138    ) -> Ext<C::F, C::EF> {
139        let output: Ext<_, _> = self.uninit();
140        self.push_op(DslIr::CircuitV2BatchFRI(Box::new((output, alpha_pows, p_at_zs, p_at_xs))));
141        output
142    }
143
144    /// Applies the Poseidon2 permutation to the given array.
145    fn poseidon2_permute_v2(&mut self, array: [Felt<C::F>; WIDTH]) -> [Felt<C::F>; WIDTH] {
146        let output: [Felt<C::F>; WIDTH] = core::array::from_fn(|_| self.uninit());
147        self.push_op(DslIr::CircuitV2Poseidon2PermuteBabyBear(Box::new((output, array))));
148        output
149    }
150
151    /// Applies the Poseidon2 hash function to the given array.
152    ///
153    /// Reference: [p3_symmetric::PaddingFreeSponge]
154    fn poseidon2_hash_v2(&mut self, input: &[Felt<C::F>]) -> [Felt<C::F>; DIGEST_SIZE] {
155        // static_assert(RATE < WIDTH)
156        let mut state = core::array::from_fn(|_| self.eval(C::F::zero()));
157        for input_chunk in input.chunks(HASH_RATE) {
158            state[..input_chunk.len()].copy_from_slice(input_chunk);
159            state = self.poseidon2_permute_v2(state);
160        }
161        let state: [Felt<C::F>; DIGEST_SIZE] = state[..DIGEST_SIZE].try_into().unwrap();
162        state
163    }
164
165    /// Applies the Poseidon2 compression function to the given array.
166    ///
167    /// Reference: [p3_symmetric::TruncatedPermutation]
168    fn poseidon2_compress_v2(
169        &mut self,
170        input: impl IntoIterator<Item = Felt<C::F>>,
171    ) -> [Felt<C::F>; DIGEST_SIZE] {
172        // debug_assert!(DIGEST_SIZE * N <= WIDTH);
173        let mut pre_iter = input.into_iter().chain(repeat(self.eval(C::F::default())));
174        let pre = core::array::from_fn(move |_| pre_iter.next().unwrap());
175        let post = self.poseidon2_permute_v2(pre);
176        let post: [Felt<C::F>; DIGEST_SIZE] = post[..DIGEST_SIZE].try_into().unwrap();
177        post
178    }
179
180    /// Runs FRI fold.
181    fn fri_fold_v2(&mut self, input: CircuitV2FriFoldInput<C>) -> CircuitV2FriFoldOutput<C> {
182        let mut uninit_vec = |len| std::iter::from_fn(|| Some(self.uninit())).take(len).collect();
183        let output = CircuitV2FriFoldOutput {
184            alpha_pow_output: uninit_vec(input.alpha_pow_input.len()),
185            ro_output: uninit_vec(input.ro_input.len()),
186        };
187        self.push_op(DslIr::CircuitV2FriFold(Box::new((output.clone(), input))));
188        output
189    }
190
191    /// Decomposes an ext into its felt coordinates.
192    fn ext2felt_v2(&mut self, ext: Ext<C::F, C::EF>) -> [Felt<C::F>; D] {
193        let felts = core::array::from_fn(|_| self.uninit());
194        self.push_op(DslIr::CircuitExt2Felt(felts, ext));
195        // Verify that the decomposed extension element is correct.
196        let mut reconstructed_ext: Ext<C::F, C::EF> = self.constant(C::EF::zero());
197        for i in 0..4 {
198            let felt = felts[i];
199            let monomial: Ext<C::F, C::EF> = self.constant(C::EF::monomial(i));
200            reconstructed_ext = self.eval(reconstructed_ext + monomial * felt);
201        }
202
203        self.assert_ext_eq(reconstructed_ext, ext);
204
205        felts
206    }
207
208    /// Adds two septic elliptic curve points.
209    fn add_curve_v2(
210        &mut self,
211        point1: SepticCurve<Felt<C::F>>,
212        point2: SepticCurve<Felt<C::F>>,
213    ) -> SepticCurve<Felt<C::F>> {
214        // Hint the curve addition result.
215        let point_sum_x: [Felt<C::F>; 7] = core::array::from_fn(|_| self.uninit());
216        let point_sum_y: [Felt<C::F>; 7] = core::array::from_fn(|_| self.uninit());
217        let point =
218            SepticCurve { x: SepticExtension(point_sum_x), y: SepticExtension(point_sum_y) };
219        self.push_op(DslIr::CircuitV2HintAddCurve(Box::new((point, point1, point2))));
220
221        // Convert each point into a point over SymbolicFelt.
222        let point1_symbolic = SepticCurve::convert(point1, |x| x.into());
223        let point2_symbolic = SepticCurve::convert(point2, |x| x.into());
224        let point_symbolic = SepticCurve::convert(point, |x| x.into());
225
226        // Evaluate `sum_checker_x` and `sum_checker_y`.
227        let sum_checker_x = SepticCurve::<SymbolicFelt<C::F>>::sum_checker_x(
228            point1_symbolic,
229            point2_symbolic,
230            point_symbolic,
231        );
232
233        let sum_checker_y = SepticCurve::<SymbolicFelt<C::F>>::sum_checker_y(
234            point1_symbolic,
235            point2_symbolic,
236            point_symbolic,
237        );
238
239        // Constrain `sum_checker_x` and `sum_checker_y` to be all zero.
240        for limb in sum_checker_x.0 {
241            self.assert_felt_eq(limb, C::F::zero());
242        }
243
244        for limb in sum_checker_y.0 {
245            self.assert_felt_eq(limb, C::F::zero());
246        }
247
248        point
249    }
250
251    /// Asserts that the `digest` is the zero digest when `is_real` is non-zero.
252    fn assert_digest_zero_v2(&mut self, is_real: Felt<C::F>, digest: SepticDigest<Felt<C::F>>) {
253        let zero = SepticDigest::<SymbolicFelt<C::F>>::zero();
254        for (digest_limb_x, zero_limb_x) in digest.0.x.0.into_iter().zip_eq(zero.0.x.0.into_iter())
255        {
256            self.assert_felt_eq(is_real * digest_limb_x, is_real * zero_limb_x);
257        }
258        for (digest_limb_y, zero_limb_y) in digest.0.y.0.into_iter().zip_eq(zero.0.y.0.into_iter())
259        {
260            self.assert_felt_eq(is_real * digest_limb_y, is_real * zero_limb_y);
261        }
262    }
263
264    /// Returns the zero digest when `is_first_shard` is zero, and returns the `vk_digest` when
265    /// `is_first_shard` is one. It is assumed that `is_first_shard` is already checked to be a
266    /// boolean.
267    fn select_global_cumulative_sum(
268        &mut self,
269        is_first_shard: Felt<C::F>,
270        vk_digest: SepticDigest<Felt<C::F>>,
271    ) -> SepticDigest<Felt<C::F>> {
272        let zero = SepticDigest::<SymbolicFelt<C::F>>::zero();
273        let one: Felt<C::F> = self.constant(C::F::one());
274        let x = SepticExtension(core::array::from_fn(|i| {
275            self.eval(is_first_shard * vk_digest.0.x.0[i] + (one - is_first_shard) * zero.0.x.0[i])
276        }));
277        let y = SepticExtension(core::array::from_fn(|i| {
278            self.eval(is_first_shard * vk_digest.0.y.0[i] + (one - is_first_shard) * zero.0.y.0[i])
279        }));
280        SepticDigest(SepticCurve { x, y })
281    }
282
283    // Sums the digests into one.
284    fn sum_digest_v2(
285        &mut self,
286        digests: Vec<SepticDigest<Felt<C::F>>>,
287    ) -> SepticDigest<Felt<C::F>> {
288        let mut convert_to_felt =
289            |point: SepticCurve<C::F>| SepticCurve::convert(point, |value| self.eval(value));
290
291        let start = convert_to_felt(SepticDigest::starting_digest().0);
292        let zero_digest = convert_to_felt(SepticDigest::zero().0);
293
294        if digests.is_empty() {
295            return SepticDigest(zero_digest);
296        }
297
298        let neg_start = convert_to_felt(SepticDigest::starting_digest().0.neg());
299        let neg_zero_digest = convert_to_felt(SepticDigest::zero().0.neg());
300
301        let mut ret = start;
302        for (i, digest) in digests.clone().into_iter().enumerate() {
303            ret = self.add_curve_v2(ret, digest.0);
304            if i != digests.len() - 1 {
305                ret = self.add_curve_v2(ret, neg_zero_digest)
306            }
307        }
308        SepticDigest(self.add_curve_v2(ret, neg_start))
309    }
310
311    // Commits public values.
312    fn commit_public_values_v2(&mut self, public_values: RecursionPublicValues<Felt<C::F>>) {
313        self.push_op(DslIr::CircuitV2CommitPublicValues(Box::new(public_values)));
314    }
315
316    fn cycle_tracker_v2_enter(&mut self, name: impl Into<Cow<'static, str>>) {
317        self.push_op(DslIr::CycleTrackerV2Enter(name.into()));
318    }
319
320    fn cycle_tracker_v2_exit(&mut self) {
321        self.push_op(DslIr::CycleTrackerV2Exit);
322    }
323
324    /// Hint a single felt.
325    fn hint_felt_v2(&mut self) -> Felt<C::F> {
326        self.hint_felts_v2(1)[0]
327    }
328
329    /// Hint a single ext.
330    fn hint_ext_v2(&mut self) -> Ext<C::F, C::EF> {
331        self.hint_exts_v2(1)[0]
332    }
333
334    /// Hint a vector of felts.
335    fn hint_felts_v2(&mut self, len: usize) -> Vec<Felt<C::F>> {
336        let arr = std::iter::from_fn(|| Some(self.uninit())).take(len).collect::<Vec<_>>();
337        self.push_op(DslIr::CircuitV2HintFelts(arr[0], len));
338        arr
339    }
340
341    /// Hint a vector of exts.
342    fn hint_exts_v2(&mut self, len: usize) -> Vec<Ext<C::F, C::EF>> {
343        let arr = std::iter::from_fn(|| Some(self.uninit())).take(len).collect::<Vec<_>>();
344        self.push_op(DslIr::CircuitV2HintExts(arr[0], len));
345        arr
346    }
347}