sp1_recursion_compiler/circuit/
builder.rs1use std::{borrow::Cow, iter::repeat};
4
5use crate::prelude::*;
6use itertools::Itertools;
7use p3_baby_bear::BabyBear;
8use p3_field::{AbstractExtensionField, AbstractField};
9use sp1_recursion_core::air::RecursionPublicValues;
10use sp1_recursion_core::{chips::poseidon2_skinny::WIDTH, D, DIGEST_SIZE, HASH_RATE};
11use sp1_stark::septic_curve::SepticCurve;
12use sp1_stark::septic_digest::SepticDigest;
13use sp1_stark::septic_extension::SepticExtension;
14
15pub trait CircuitV2Builder<C: Config> {
16 fn bits2num_v2_f(
17 &mut self,
18 bits: impl IntoIterator<Item = Felt<<C as Config>::F>>,
19 ) -> Felt<C::F>;
20 fn num2bits_v2_f(&mut self, num: Felt<C::F>, num_bits: usize) -> Vec<Felt<C::F>>;
21 fn exp_reverse_bits_v2(&mut self, input: Felt<C::F>, power_bits: Vec<Felt<C::F>>)
22 -> Felt<C::F>;
23 fn batch_fri_v2(
24 &mut self,
25 alphas: Vec<Ext<C::F, C::EF>>,
26 p_at_zs: Vec<Ext<C::F, C::EF>>,
27 p_at_xs: Vec<Felt<C::F>>,
28 ) -> Ext<C::F, C::EF>;
29 fn poseidon2_permute_v2(&mut self, state: [Felt<C::F>; WIDTH]) -> [Felt<C::F>; WIDTH];
30 fn poseidon2_hash_v2(&mut self, array: &[Felt<C::F>]) -> [Felt<C::F>; DIGEST_SIZE];
31 fn poseidon2_compress_v2(
32 &mut self,
33 input: impl IntoIterator<Item = Felt<C::F>>,
34 ) -> [Felt<C::F>; DIGEST_SIZE];
35 fn fri_fold_v2(&mut self, input: CircuitV2FriFoldInput<C>) -> CircuitV2FriFoldOutput<C>;
36 fn ext2felt_v2(&mut self, ext: Ext<C::F, C::EF>) -> [Felt<C::F>; D];
37 fn add_curve_v2(
38 &mut self,
39 point1: SepticCurve<Felt<C::F>>,
40 point2: SepticCurve<Felt<C::F>>,
41 ) -> SepticCurve<Felt<C::F>>;
42 fn assert_digest_zero_v2(&mut self, is_real: Felt<C::F>, digest: SepticDigest<Felt<C::F>>);
43 fn sum_digest_v2(&mut self, digests: Vec<SepticDigest<Felt<C::F>>>)
44 -> SepticDigest<Felt<C::F>>;
45 fn select_global_cumulative_sum(
46 &mut self,
47 is_first_shard: Felt<C::F>,
48 vk_digest: SepticDigest<Felt<C::F>>,
49 ) -> SepticDigest<Felt<C::F>>;
50 fn commit_public_values_v2(&mut self, public_values: RecursionPublicValues<Felt<C::F>>);
51 fn cycle_tracker_v2_enter(&mut self, name: impl Into<Cow<'static, str>>);
52 fn cycle_tracker_v2_exit(&mut self);
53 fn hint_ext_v2(&mut self) -> Ext<C::F, C::EF>;
54 fn hint_felt_v2(&mut self) -> Felt<C::F>;
55 fn hint_exts_v2(&mut self, len: usize) -> Vec<Ext<C::F, C::EF>>;
56 fn hint_felts_v2(&mut self, len: usize) -> Vec<Felt<C::F>>;
57}
58
59impl<C: Config<F = BabyBear>> CircuitV2Builder<C> for Builder<C> {
60 fn bits2num_v2_f(
61 &mut self,
62 bits: impl IntoIterator<Item = Felt<<C as Config>::F>>,
63 ) -> Felt<<C as Config>::F> {
64 let mut num: Felt<_> = self.eval(C::F::zero());
65 for (i, bit) in bits.into_iter().enumerate() {
66 num = self.eval(num + bit * C::F::from_wrapped_u32(1 << i));
68 }
69 num
70 }
71
72 fn num2bits_v2_f(&mut self, num: Felt<C::F>, num_bits: usize) -> Vec<Felt<C::F>> {
74 let output = std::iter::from_fn(|| Some(self.uninit())).take(num_bits).collect::<Vec<_>>();
75 self.push_op(DslIr::CircuitV2HintBitsF(output.clone(), num));
76
77 let x: SymbolicFelt<_> = output
78 .iter()
79 .enumerate()
80 .map(|(i, &bit)| {
81 self.assert_felt_eq(bit * (bit - C::F::one()), C::F::zero());
82 bit * C::F::from_wrapped_u32(1 << i)
83 })
84 .sum();
85
86 assert!(num_bits <= 31, "num_bits must be less than or equal to 31");
89
90 if num_bits > 30 {
92 let are_all_top_bits_one: Felt<_> = self.eval(
99 output
100 .iter()
101 .rev()
102 .take(4)
103 .copied()
104 .map(SymbolicFelt::from)
105 .product::<SymbolicFelt<_>>(),
106 );
107
108 for bit in output.iter().take(27).copied() {
110 self.assert_felt_eq(bit * are_all_top_bits_one, C::F::zero());
111 }
112 }
113
114 self.assert_felt_eq(x, num);
116
117 output
118 }
119
120 fn exp_reverse_bits_v2(
122 &mut self,
123 input: Felt<C::F>,
124 power_bits: Vec<Felt<C::F>>,
125 ) -> Felt<C::F> {
126 let output: Felt<_> = self.uninit();
127 self.push_op(DslIr::CircuitV2ExpReverseBits(output, input, power_bits));
128 output
129 }
130
131 fn batch_fri_v2(
133 &mut self,
134 alpha_pows: Vec<Ext<C::F, C::EF>>,
135 p_at_zs: Vec<Ext<C::F, C::EF>>,
136 p_at_xs: Vec<Felt<C::F>>,
137 ) -> Ext<C::F, C::EF> {
138 let output: Ext<_, _> = self.uninit();
139 self.push_op(DslIr::CircuitV2BatchFRI(Box::new((output, alpha_pows, p_at_zs, p_at_xs))));
140 output
141 }
142
143 fn poseidon2_permute_v2(&mut self, array: [Felt<C::F>; WIDTH]) -> [Felt<C::F>; WIDTH] {
145 let output: [Felt<C::F>; WIDTH] = core::array::from_fn(|_| self.uninit());
146 self.push_op(DslIr::CircuitV2Poseidon2PermuteBabyBear(Box::new((output, array))));
147 output
148 }
149
150 fn poseidon2_hash_v2(&mut self, input: &[Felt<C::F>]) -> [Felt<C::F>; DIGEST_SIZE] {
154 let mut state = core::array::from_fn(|_| self.eval(C::F::zero()));
156 for input_chunk in input.chunks(HASH_RATE) {
157 state[..input_chunk.len()].copy_from_slice(input_chunk);
158 state = self.poseidon2_permute_v2(state);
159 }
160 let state: [Felt<C::F>; DIGEST_SIZE] = state[..DIGEST_SIZE].try_into().unwrap();
161 state
162 }
163
164 fn poseidon2_compress_v2(
168 &mut self,
169 input: impl IntoIterator<Item = Felt<C::F>>,
170 ) -> [Felt<C::F>; DIGEST_SIZE] {
171 let mut pre_iter = input.into_iter().chain(repeat(self.eval(C::F::default())));
173 let pre = core::array::from_fn(move |_| pre_iter.next().unwrap());
174 let post = self.poseidon2_permute_v2(pre);
175 let post: [Felt<C::F>; DIGEST_SIZE] = post[..DIGEST_SIZE].try_into().unwrap();
176 post
177 }
178
179 fn fri_fold_v2(&mut self, input: CircuitV2FriFoldInput<C>) -> CircuitV2FriFoldOutput<C> {
181 let mut uninit_vec = |len| std::iter::from_fn(|| Some(self.uninit())).take(len).collect();
182 let output = CircuitV2FriFoldOutput {
183 alpha_pow_output: uninit_vec(input.alpha_pow_input.len()),
184 ro_output: uninit_vec(input.ro_input.len()),
185 };
186 self.push_op(DslIr::CircuitV2FriFold(Box::new((output.clone(), input))));
187 output
188 }
189
190 fn ext2felt_v2(&mut self, ext: Ext<C::F, C::EF>) -> [Felt<C::F>; D] {
192 let felts = core::array::from_fn(|_| self.uninit());
193 self.push_op(DslIr::CircuitExt2Felt(felts, ext));
194 let mut reconstructed_ext: Ext<C::F, C::EF> = self.constant(C::EF::zero());
196 for i in 0..4 {
197 let felt = felts[i];
198 let monomial: Ext<C::F, C::EF> = self.constant(C::EF::monomial(i));
199 reconstructed_ext = self.eval(reconstructed_ext + monomial * felt);
200 }
201
202 self.assert_ext_eq(reconstructed_ext, ext);
203
204 felts
205 }
206
207 fn add_curve_v2(
209 &mut self,
210 point1: SepticCurve<Felt<C::F>>,
211 point2: SepticCurve<Felt<C::F>>,
212 ) -> SepticCurve<Felt<C::F>> {
213 let point_sum_x: [Felt<C::F>; 7] = core::array::from_fn(|_| self.uninit());
215 let point_sum_y: [Felt<C::F>; 7] = core::array::from_fn(|_| self.uninit());
216 let point =
217 SepticCurve { x: SepticExtension(point_sum_x), y: SepticExtension(point_sum_y) };
218 self.push_op(DslIr::CircuitV2HintAddCurve(Box::new((point, point1, point2))));
219
220 let point1_symbolic = SepticCurve::convert(point1, |x| x.into());
222 let point2_symbolic = SepticCurve::convert(point2, |x| x.into());
223 let point_symbolic = SepticCurve::convert(point, |x| x.into());
224
225 let sum_checker_x = SepticCurve::<SymbolicFelt<C::F>>::sum_checker_x(
227 point1_symbolic,
228 point2_symbolic,
229 point_symbolic,
230 );
231
232 let sum_checker_y = SepticCurve::<SymbolicFelt<C::F>>::sum_checker_y(
233 point1_symbolic,
234 point2_symbolic,
235 point_symbolic,
236 );
237
238 for limb in sum_checker_x.0 {
240 self.assert_felt_eq(limb, C::F::zero());
241 }
242
243 for limb in sum_checker_y.0 {
244 self.assert_felt_eq(limb, C::F::zero());
245 }
246
247 point
248 }
249
250 fn assert_digest_zero_v2(&mut self, is_real: Felt<C::F>, digest: SepticDigest<Felt<C::F>>) {
252 let zero = SepticDigest::<SymbolicFelt<C::F>>::zero();
253 for (digest_limb_x, zero_limb_x) in digest.0.x.0.into_iter().zip_eq(zero.0.x.0.into_iter())
254 {
255 self.assert_felt_eq(is_real * digest_limb_x, is_real * zero_limb_x);
256 }
257 for (digest_limb_y, zero_limb_y) in digest.0.y.0.into_iter().zip_eq(zero.0.y.0.into_iter())
258 {
259 self.assert_felt_eq(is_real * digest_limb_y, is_real * zero_limb_y);
260 }
261 }
262
263 fn select_global_cumulative_sum(
266 &mut self,
267 is_first_shard: Felt<C::F>,
268 vk_digest: SepticDigest<Felt<C::F>>,
269 ) -> SepticDigest<Felt<C::F>> {
270 let zero = SepticDigest::<SymbolicFelt<C::F>>::zero();
271 let one: Felt<C::F> = self.constant(C::F::one());
272 let x = SepticExtension(core::array::from_fn(|i| {
273 self.eval(is_first_shard * vk_digest.0.x.0[i] + (one - is_first_shard) * zero.0.x.0[i])
274 }));
275 let y = SepticExtension(core::array::from_fn(|i| {
276 self.eval(is_first_shard * vk_digest.0.y.0[i] + (one - is_first_shard) * zero.0.y.0[i])
277 }));
278 SepticDigest(SepticCurve { x, y })
279 }
280
281 fn sum_digest_v2(
283 &mut self,
284 digests: Vec<SepticDigest<Felt<C::F>>>,
285 ) -> SepticDigest<Felt<C::F>> {
286 let mut convert_to_felt =
287 |point: SepticCurve<C::F>| SepticCurve::convert(point, |value| self.eval(value));
288
289 let start = convert_to_felt(SepticDigest::starting_digest().0);
290 let zero_digest = convert_to_felt(SepticDigest::zero().0);
291
292 if digests.is_empty() {
293 return SepticDigest(zero_digest);
294 }
295
296 let neg_start = convert_to_felt(SepticDigest::starting_digest().0.neg());
297 let neg_zero_digest = convert_to_felt(SepticDigest::zero().0.neg());
298
299 let mut ret = start;
300 for (i, digest) in digests.clone().into_iter().enumerate() {
301 ret = self.add_curve_v2(ret, digest.0);
302 if i != digests.len() - 1 {
303 ret = self.add_curve_v2(ret, neg_zero_digest)
304 }
305 }
306 SepticDigest(self.add_curve_v2(ret, neg_start))
307 }
308
309 fn commit_public_values_v2(&mut self, public_values: RecursionPublicValues<Felt<C::F>>) {
311 self.push_op(DslIr::CircuitV2CommitPublicValues(Box::new(public_values)));
312 }
313
314 fn cycle_tracker_v2_enter(&mut self, name: impl Into<Cow<'static, str>>) {
315 self.push_op(DslIr::CycleTrackerV2Enter(name.into()));
316 }
317
318 fn cycle_tracker_v2_exit(&mut self) {
319 self.push_op(DslIr::CycleTrackerV2Exit);
320 }
321
322 fn hint_felt_v2(&mut self) -> Felt<C::F> {
324 self.hint_felts_v2(1)[0]
325 }
326
327 fn hint_ext_v2(&mut self) -> Ext<C::F, C::EF> {
329 self.hint_exts_v2(1)[0]
330 }
331
332 fn hint_felts_v2(&mut self, len: usize) -> Vec<Felt<C::F>> {
334 let arr = std::iter::from_fn(|| Some(self.uninit())).take(len).collect::<Vec<_>>();
335 self.push_op(DslIr::CircuitV2HintFelts(arr[0], len));
336 arr
337 }
338
339 fn hint_exts_v2(&mut self, len: usize) -> Vec<Ext<C::F, C::EF>> {
341 let arr = std::iter::from_fn(|| Some(self.uninit())).take(len).collect::<Vec<_>>();
342 self.push_op(DslIr::CircuitV2HintExts(arr[0], len));
343 arr
344 }
345}