sp1_recursion_compiler/circuit/
builder.rs1use std::{borrow::Cow, iter::repeat};
4
5use crate::prelude::*;
6use itertools::Itertools;
7use p3_baby_bear::BabyBear;
8use p3_field::{AbstractExtensionField, AbstractField};
9use sp1_recursion_core::{
10 air::RecursionPublicValues, chips::poseidon2_skinny::WIDTH, D, DIGEST_SIZE, HASH_RATE,
11};
12use sp1_stark::{
13 septic_curve::SepticCurve, septic_digest::SepticDigest, septic_extension::SepticExtension,
14};
15
16pub trait CircuitV2Builder<C: Config> {
17 fn bits2num_v2_f(
18 &mut self,
19 bits: impl IntoIterator<Item = Felt<<C as Config>::F>>,
20 ) -> Felt<C::F>;
21 fn num2bits_v2_f(&mut self, num: Felt<C::F>, num_bits: usize) -> Vec<Felt<C::F>>;
22 fn exp_reverse_bits_v2(&mut self, input: Felt<C::F>, power_bits: Vec<Felt<C::F>>)
23 -> Felt<C::F>;
24 fn batch_fri_v2(
25 &mut self,
26 alphas: Vec<Ext<C::F, C::EF>>,
27 p_at_zs: Vec<Ext<C::F, C::EF>>,
28 p_at_xs: Vec<Felt<C::F>>,
29 ) -> Ext<C::F, C::EF>;
30 fn poseidon2_permute_v2(&mut self, state: [Felt<C::F>; WIDTH]) -> [Felt<C::F>; WIDTH];
31 fn poseidon2_hash_v2(&mut self, array: &[Felt<C::F>]) -> [Felt<C::F>; DIGEST_SIZE];
32 fn poseidon2_compress_v2(
33 &mut self,
34 input: impl IntoIterator<Item = Felt<C::F>>,
35 ) -> [Felt<C::F>; DIGEST_SIZE];
36 fn fri_fold_v2(&mut self, input: CircuitV2FriFoldInput<C>) -> CircuitV2FriFoldOutput<C>;
37 fn ext2felt_v2(&mut self, ext: Ext<C::F, C::EF>) -> [Felt<C::F>; D];
38 fn add_curve_v2(
39 &mut self,
40 point1: SepticCurve<Felt<C::F>>,
41 point2: SepticCurve<Felt<C::F>>,
42 ) -> SepticCurve<Felt<C::F>>;
43 fn assert_digest_zero_v2(&mut self, is_real: Felt<C::F>, digest: SepticDigest<Felt<C::F>>);
44 fn sum_digest_v2(&mut self, digests: Vec<SepticDigest<Felt<C::F>>>)
45 -> SepticDigest<Felt<C::F>>;
46 fn select_global_cumulative_sum(
47 &mut self,
48 is_first_shard: Felt<C::F>,
49 vk_digest: SepticDigest<Felt<C::F>>,
50 ) -> SepticDigest<Felt<C::F>>;
51 fn commit_public_values_v2(&mut self, public_values: RecursionPublicValues<Felt<C::F>>);
52 fn cycle_tracker_v2_enter(&mut self, name: impl Into<Cow<'static, str>>);
53 fn cycle_tracker_v2_exit(&mut self);
54 fn hint_ext_v2(&mut self) -> Ext<C::F, C::EF>;
55 fn hint_felt_v2(&mut self) -> Felt<C::F>;
56 fn hint_exts_v2(&mut self, len: usize) -> Vec<Ext<C::F, C::EF>>;
57 fn hint_felts_v2(&mut self, len: usize) -> Vec<Felt<C::F>>;
58}
59
60impl<C: Config<F = BabyBear>> CircuitV2Builder<C> for Builder<C> {
61 fn bits2num_v2_f(
62 &mut self,
63 bits: impl IntoIterator<Item = Felt<<C as Config>::F>>,
64 ) -> Felt<<C as Config>::F> {
65 let mut num: Felt<_> = self.eval(C::F::zero());
66 for (i, bit) in bits.into_iter().enumerate() {
67 num = self.eval(num + bit * C::F::from_wrapped_u32(1 << i));
69 }
70 num
71 }
72
73 fn num2bits_v2_f(&mut self, num: Felt<C::F>, num_bits: usize) -> Vec<Felt<C::F>> {
75 let output = std::iter::from_fn(|| Some(self.uninit())).take(num_bits).collect::<Vec<_>>();
76 self.push_op(DslIr::CircuitV2HintBitsF(output.clone(), num));
77
78 let x: SymbolicFelt<_> = output
79 .iter()
80 .enumerate()
81 .map(|(i, &bit)| {
82 self.assert_felt_eq(bit * (bit - C::F::one()), C::F::zero());
83 bit * C::F::from_wrapped_u32(1 << i)
84 })
85 .sum();
86
87 assert!(num_bits <= 31, "num_bits must be less than or equal to 31");
90
91 if num_bits > 30 {
93 let are_all_top_bits_one: Felt<_> = self.eval(
100 output
101 .iter()
102 .rev()
103 .take(4)
104 .copied()
105 .map(SymbolicFelt::from)
106 .product::<SymbolicFelt<_>>(),
107 );
108
109 for bit in output.iter().take(27).copied() {
111 self.assert_felt_eq(bit * are_all_top_bits_one, C::F::zero());
112 }
113 }
114
115 self.assert_felt_eq(x, num);
117
118 output
119 }
120
121 fn exp_reverse_bits_v2(
123 &mut self,
124 input: Felt<C::F>,
125 power_bits: Vec<Felt<C::F>>,
126 ) -> Felt<C::F> {
127 let output: Felt<_> = self.uninit();
128 self.push_op(DslIr::CircuitV2ExpReverseBits(output, input, power_bits));
129 output
130 }
131
132 fn batch_fri_v2(
134 &mut self,
135 alpha_pows: Vec<Ext<C::F, C::EF>>,
136 p_at_zs: Vec<Ext<C::F, C::EF>>,
137 p_at_xs: Vec<Felt<C::F>>,
138 ) -> Ext<C::F, C::EF> {
139 let output: Ext<_, _> = self.uninit();
140 self.push_op(DslIr::CircuitV2BatchFRI(Box::new((output, alpha_pows, p_at_zs, p_at_xs))));
141 output
142 }
143
144 fn poseidon2_permute_v2(&mut self, array: [Felt<C::F>; WIDTH]) -> [Felt<C::F>; WIDTH] {
146 let output: [Felt<C::F>; WIDTH] = core::array::from_fn(|_| self.uninit());
147 self.push_op(DslIr::CircuitV2Poseidon2PermuteBabyBear(Box::new((output, array))));
148 output
149 }
150
151 fn poseidon2_hash_v2(&mut self, input: &[Felt<C::F>]) -> [Felt<C::F>; DIGEST_SIZE] {
155 let mut state = core::array::from_fn(|_| self.eval(C::F::zero()));
157 for input_chunk in input.chunks(HASH_RATE) {
158 state[..input_chunk.len()].copy_from_slice(input_chunk);
159 state = self.poseidon2_permute_v2(state);
160 }
161 let state: [Felt<C::F>; DIGEST_SIZE] = state[..DIGEST_SIZE].try_into().unwrap();
162 state
163 }
164
165 fn poseidon2_compress_v2(
169 &mut self,
170 input: impl IntoIterator<Item = Felt<C::F>>,
171 ) -> [Felt<C::F>; DIGEST_SIZE] {
172 let mut pre_iter = input.into_iter().chain(repeat(self.eval(C::F::default())));
174 let pre = core::array::from_fn(move |_| pre_iter.next().unwrap());
175 let post = self.poseidon2_permute_v2(pre);
176 let post: [Felt<C::F>; DIGEST_SIZE] = post[..DIGEST_SIZE].try_into().unwrap();
177 post
178 }
179
180 fn fri_fold_v2(&mut self, input: CircuitV2FriFoldInput<C>) -> CircuitV2FriFoldOutput<C> {
182 let mut uninit_vec = |len| std::iter::from_fn(|| Some(self.uninit())).take(len).collect();
183 let output = CircuitV2FriFoldOutput {
184 alpha_pow_output: uninit_vec(input.alpha_pow_input.len()),
185 ro_output: uninit_vec(input.ro_input.len()),
186 };
187 self.push_op(DslIr::CircuitV2FriFold(Box::new((output.clone(), input))));
188 output
189 }
190
191 fn ext2felt_v2(&mut self, ext: Ext<C::F, C::EF>) -> [Felt<C::F>; D] {
193 let felts = core::array::from_fn(|_| self.uninit());
194 self.push_op(DslIr::CircuitExt2Felt(felts, ext));
195 let mut reconstructed_ext: Ext<C::F, C::EF> = self.constant(C::EF::zero());
197 for i in 0..4 {
198 let felt = felts[i];
199 let monomial: Ext<C::F, C::EF> = self.constant(C::EF::monomial(i));
200 reconstructed_ext = self.eval(reconstructed_ext + monomial * felt);
201 }
202
203 self.assert_ext_eq(reconstructed_ext, ext);
204
205 felts
206 }
207
208 fn add_curve_v2(
210 &mut self,
211 point1: SepticCurve<Felt<C::F>>,
212 point2: SepticCurve<Felt<C::F>>,
213 ) -> SepticCurve<Felt<C::F>> {
214 let point_sum_x: [Felt<C::F>; 7] = core::array::from_fn(|_| self.uninit());
216 let point_sum_y: [Felt<C::F>; 7] = core::array::from_fn(|_| self.uninit());
217 let point =
218 SepticCurve { x: SepticExtension(point_sum_x), y: SepticExtension(point_sum_y) };
219 self.push_op(DslIr::CircuitV2HintAddCurve(Box::new((point, point1, point2))));
220
221 let point1_symbolic = SepticCurve::convert(point1, |x| x.into());
223 let point2_symbolic = SepticCurve::convert(point2, |x| x.into());
224 let point_symbolic = SepticCurve::convert(point, |x| x.into());
225
226 let sum_checker_x = SepticCurve::<SymbolicFelt<C::F>>::sum_checker_x(
228 point1_symbolic,
229 point2_symbolic,
230 point_symbolic,
231 );
232
233 let sum_checker_y = SepticCurve::<SymbolicFelt<C::F>>::sum_checker_y(
234 point1_symbolic,
235 point2_symbolic,
236 point_symbolic,
237 );
238
239 for limb in sum_checker_x.0 {
241 self.assert_felt_eq(limb, C::F::zero());
242 }
243
244 for limb in sum_checker_y.0 {
245 self.assert_felt_eq(limb, C::F::zero());
246 }
247
248 point
249 }
250
251 fn assert_digest_zero_v2(&mut self, is_real: Felt<C::F>, digest: SepticDigest<Felt<C::F>>) {
253 let zero = SepticDigest::<SymbolicFelt<C::F>>::zero();
254 for (digest_limb_x, zero_limb_x) in digest.0.x.0.into_iter().zip_eq(zero.0.x.0.into_iter())
255 {
256 self.assert_felt_eq(is_real * digest_limb_x, is_real * zero_limb_x);
257 }
258 for (digest_limb_y, zero_limb_y) in digest.0.y.0.into_iter().zip_eq(zero.0.y.0.into_iter())
259 {
260 self.assert_felt_eq(is_real * digest_limb_y, is_real * zero_limb_y);
261 }
262 }
263
264 fn select_global_cumulative_sum(
268 &mut self,
269 is_first_shard: Felt<C::F>,
270 vk_digest: SepticDigest<Felt<C::F>>,
271 ) -> SepticDigest<Felt<C::F>> {
272 let zero = SepticDigest::<SymbolicFelt<C::F>>::zero();
273 let one: Felt<C::F> = self.constant(C::F::one());
274 let x = SepticExtension(core::array::from_fn(|i| {
275 self.eval(is_first_shard * vk_digest.0.x.0[i] + (one - is_first_shard) * zero.0.x.0[i])
276 }));
277 let y = SepticExtension(core::array::from_fn(|i| {
278 self.eval(is_first_shard * vk_digest.0.y.0[i] + (one - is_first_shard) * zero.0.y.0[i])
279 }));
280 SepticDigest(SepticCurve { x, y })
281 }
282
283 fn sum_digest_v2(
285 &mut self,
286 digests: Vec<SepticDigest<Felt<C::F>>>,
287 ) -> SepticDigest<Felt<C::F>> {
288 let mut convert_to_felt =
289 |point: SepticCurve<C::F>| SepticCurve::convert(point, |value| self.eval(value));
290
291 let start = convert_to_felt(SepticDigest::starting_digest().0);
292 let zero_digest = convert_to_felt(SepticDigest::zero().0);
293
294 if digests.is_empty() {
295 return SepticDigest(zero_digest);
296 }
297
298 let neg_start = convert_to_felt(SepticDigest::starting_digest().0.neg());
299 let neg_zero_digest = convert_to_felt(SepticDigest::zero().0.neg());
300
301 let mut ret = start;
302 for (i, digest) in digests.clone().into_iter().enumerate() {
303 ret = self.add_curve_v2(ret, digest.0);
304 if i != digests.len() - 1 {
305 ret = self.add_curve_v2(ret, neg_zero_digest)
306 }
307 }
308 SepticDigest(self.add_curve_v2(ret, neg_start))
309 }
310
311 fn commit_public_values_v2(&mut self, public_values: RecursionPublicValues<Felt<C::F>>) {
313 self.push_op(DslIr::CircuitV2CommitPublicValues(Box::new(public_values)));
314 }
315
316 fn cycle_tracker_v2_enter(&mut self, name: impl Into<Cow<'static, str>>) {
317 self.push_op(DslIr::CycleTrackerV2Enter(name.into()));
318 }
319
320 fn cycle_tracker_v2_exit(&mut self) {
321 self.push_op(DslIr::CycleTrackerV2Exit);
322 }
323
324 fn hint_felt_v2(&mut self) -> Felt<C::F> {
326 self.hint_felts_v2(1)[0]
327 }
328
329 fn hint_ext_v2(&mut self) -> Ext<C::F, C::EF> {
331 self.hint_exts_v2(1)[0]
332 }
333
334 fn hint_felts_v2(&mut self, len: usize) -> Vec<Felt<C::F>> {
336 let arr = std::iter::from_fn(|| Some(self.uninit())).take(len).collect::<Vec<_>>();
337 self.push_op(DslIr::CircuitV2HintFelts(arr[0], len));
338 arr
339 }
340
341 fn hint_exts_v2(&mut self, len: usize) -> Vec<Ext<C::F, C::EF>> {
343 let arr = std::iter::from_fn(|| Some(self.uninit())).take(len).collect::<Vec<_>>();
344 self.push_op(DslIr::CircuitV2HintExts(arr[0], len));
345 arr
346 }
347}