sp1_recursion_core/chips/
batch_fri.rs1#![allow(clippy::needless_range_loop)]
2
3use crate::{air::Block, builder::SP1RecursionAirBuilder, Address, ExecutionRecord};
4use core::borrow::Borrow;
5use p3_air::{Air, AirBuilder, BaseAir, PairBuilder};
6use p3_field::PrimeField32;
7use p3_matrix::{dense::RowMajorMatrix, Matrix};
8use sp1_core_machine::utils::next_power_of_two;
9use sp1_derive::AlignedBorrow;
10use sp1_stark::air::{BaseAirBuilder, BinomialExtension, ExtensionAirBuilder, MachineAir};
11
12#[cfg(feature = "sys")]
13use {
14 crate::{BatchFRIEvent, BatchFRIInstr, Instruction},
15 itertools::Itertools,
16 p3_baby_bear::BabyBear,
17 p3_field::AbstractField,
18 sp1_core_machine::utils::pad_rows_fixed,
19 std::borrow::BorrowMut,
20 tracing::instrument,
21};
22
23pub const NUM_BATCH_FRI_COLS: usize = core::mem::size_of::<BatchFRICols<u8>>();
24pub const NUM_BATCH_FRI_PREPROCESSED_COLS: usize =
25 core::mem::size_of::<BatchFRIPreprocessedCols<u8>>();
26
/// Zero-sized marker type for the batch-FRI chip.
///
/// `DEGREE` is consumed by the `Air` implementation, where it fixes how many
/// copies of `is_real` are multiplied together in an always-satisfied
/// constraint.
#[derive(Clone, Copy, Debug, Default)]
pub struct BatchFRIChip<const DEGREE: usize>;
29
30#[derive(AlignedBorrow, Debug, Clone, Copy)]
32#[repr(C)]
33pub struct BatchFRIPreprocessedCols<T: Copy> {
34 pub is_real: T,
35 pub is_end: T,
36 pub acc_addr: Address<T>,
37 pub alpha_pow_addr: Address<T>,
38 pub p_at_z_addr: Address<T>,
39 pub p_at_x_addr: Address<T>,
40}
41
42#[derive(AlignedBorrow, Debug, Clone, Copy)]
44#[repr(C)]
45pub struct BatchFRICols<T: Copy> {
46 pub acc: Block<T>,
47 pub alpha_pow: Block<T>,
48 pub p_at_z: Block<T>,
49 pub p_at_x: T,
50}
51
52impl<F, const DEGREE: usize> BaseAir<F> for BatchFRIChip<DEGREE> {
53 fn width(&self) -> usize {
54 NUM_BATCH_FRI_COLS
55 }
56}
57
58impl<F: PrimeField32, const DEGREE: usize> MachineAir<F> for BatchFRIChip<DEGREE> {
59 type Record = ExecutionRecord<F>;
60
61 type Program = crate::RecursionProgram<F>;
62
63 fn name(&self) -> String {
64 "BatchFRI".to_string()
65 }
66
67 fn generate_dependencies(&self, _: &Self::Record, _: &mut Self::Record) {
68 }
70
71 fn preprocessed_width(&self) -> usize {
72 NUM_BATCH_FRI_PREPROCESSED_COLS
73 }
74
75 #[cfg(not(feature = "sys"))]
76 fn generate_preprocessed_trace(&self, _program: &Self::Program) -> Option<RowMajorMatrix<F>> {
77 unimplemented!("To generate traces, enable feature `sp1-recursion-core/sys`");
78 }
79
80 #[cfg(feature = "sys")]
81 fn generate_preprocessed_trace(&self, program: &Self::Program) -> Option<RowMajorMatrix<F>> {
82 assert_eq!(
83 std::any::TypeId::of::<F>(),
84 std::any::TypeId::of::<BabyBear>(),
85 "generate_preprocessed_trace only supports BabyBear field"
86 );
87
88 let mut rows = Vec::new();
89 let instrs = unsafe {
90 std::mem::transmute::<Vec<&Box<BatchFRIInstr<F>>>, Vec<&Box<BatchFRIInstr<BabyBear>>>>(
91 program
92 .inner
93 .iter()
94 .filter_map(|instruction| match instruction {
95 Instruction::BatchFRI(x) => Some(x),
96 _ => None,
97 })
98 .collect::<Vec<_>>(),
99 )
100 };
101 instrs.iter().for_each(|instruction| {
102 let BatchFRIInstr { base_vec_addrs: _, ext_single_addrs: _, ext_vec_addrs, acc_mult } =
103 instruction.as_ref();
104 let len: usize = ext_vec_addrs.p_at_z.len();
105 let mut row_add = vec![[BabyBear::zero(); NUM_BATCH_FRI_PREPROCESSED_COLS]; len];
106 debug_assert_eq!(*acc_mult, BabyBear::one());
107
108 row_add.iter_mut().enumerate().for_each(|(i, row)| {
109 let cols: &mut BatchFRIPreprocessedCols<BabyBear> = row.as_mut_slice().borrow_mut();
110 unsafe {
111 crate::sys::batch_fri_instr_to_row_babybear(&instruction.into(), cols, i);
112 }
113 });
114 rows.extend(row_add);
115 });
116
117 pad_rows_fixed(
119 &mut rows,
120 || [BabyBear::zero(); NUM_BATCH_FRI_PREPROCESSED_COLS],
121 program.fixed_log2_rows(self),
122 );
123
124 let trace = RowMajorMatrix::new(
125 unsafe {
126 std::mem::transmute::<Vec<BabyBear>, Vec<F>>(
127 rows.into_iter().flatten().collect::<Vec<BabyBear>>(),
128 )
129 },
130 NUM_BATCH_FRI_PREPROCESSED_COLS,
131 );
132 Some(trace)
133 }
134
135 fn num_rows(&self, input: &Self::Record) -> Option<usize> {
136 let events = &input.batch_fri_events;
137 Some(next_power_of_two(events.len(), input.fixed_log2_rows(self)))
138 }
139
140 #[cfg(not(feature = "sys"))]
141 fn generate_trace(&self, _input: &Self::Record, _: &mut Self::Record) -> RowMajorMatrix<F> {
142 unimplemented!("To generate traces, enable feature `sp1-recursion-core/sys`");
143 }
144
145 #[cfg(feature = "sys")]
146 #[instrument(name = "generate batch fri trace", level = "debug", skip_all, fields(rows = input.batch_fri_events.len()))]
147 fn generate_trace(
148 &self,
149 input: &ExecutionRecord<F>,
150 _: &mut ExecutionRecord<F>,
151 ) -> RowMajorMatrix<F> {
152 assert_eq!(
153 std::any::TypeId::of::<F>(),
154 std::any::TypeId::of::<BabyBear>(),
155 "generate_trace only supports BabyBear field"
156 );
157
158 let mut rows = input
159 .batch_fri_events
160 .iter()
161 .map(|event| {
162 let bb_event = unsafe {
163 std::mem::transmute::<&BatchFRIEvent<F>, &BatchFRIEvent<BabyBear>>(event)
164 };
165 let mut row = [BabyBear::zero(); NUM_BATCH_FRI_COLS];
166 let cols: &mut BatchFRICols<BabyBear> = row.as_mut_slice().borrow_mut();
167 cols.acc = bb_event.ext_single.acc;
168 cols.alpha_pow = bb_event.ext_vec.alpha_pow;
169 cols.p_at_z = bb_event.ext_vec.p_at_z;
170 cols.p_at_x = bb_event.base_vec.p_at_x;
171 row
172 })
173 .collect_vec();
174
175 rows.resize(self.num_rows(input).unwrap(), [BabyBear::zero(); NUM_BATCH_FRI_COLS]);
177
178 let trace = RowMajorMatrix::new(
180 unsafe {
181 std::mem::transmute::<Vec<BabyBear>, Vec<F>>(
182 rows.into_iter().flatten().collect::<Vec<BabyBear>>(),
183 )
184 },
185 NUM_BATCH_FRI_COLS,
186 );
187
188 #[cfg(debug_assertions)]
189 eprintln!(
190 "batch fri trace dims is width: {:?}, height: {:?}",
191 trace.width(),
192 trace.height()
193 );
194
195 trace
196 }
197
198 fn included(&self, _record: &Self::Record) -> bool {
199 true
200 }
201}
202
203impl<const DEGREE: usize> BatchFRIChip<DEGREE> {
204 pub fn eval_batch_fri<AB: SP1RecursionAirBuilder>(
205 &self,
206 builder: &mut AB,
207 local: &BatchFRICols<AB::Var>,
208 next: &BatchFRICols<AB::Var>,
209 local_prepr: &BatchFRIPreprocessedCols<AB::Var>,
210 _next_prepr: &BatchFRIPreprocessedCols<AB::Var>,
211 ) {
212 builder.receive_block(local_prepr.alpha_pow_addr, local.alpha_pow, local_prepr.is_real);
214 builder.receive_block(local_prepr.p_at_z_addr, local.p_at_z, local_prepr.is_real);
215 builder.receive_single(local_prepr.p_at_x_addr, local.p_at_x, local_prepr.is_real);
216
217 builder.send_block(local_prepr.acc_addr, local.acc, local_prepr.is_end);
220
221 builder.when_first_row().assert_ext_eq(
223 local.acc.as_extension::<AB>(),
224 local.alpha_pow.as_extension::<AB>() *
225 (local.p_at_z.as_extension::<AB>() -
226 BinomialExtension::from_base(local.p_at_x.into())),
227 );
228
229 builder.when_transition().when(local_prepr.is_end).assert_ext_eq(
231 next.acc.as_extension::<AB>(),
232 next.alpha_pow.as_extension::<AB>() *
233 (next.p_at_z.as_extension::<AB>() -
234 BinomialExtension::from_base(next.p_at_x.into())),
235 );
236
237 builder.when_transition().when_not(local_prepr.is_end).assert_ext_eq(
239 next.acc.as_extension::<AB>(),
240 local.acc.as_extension::<AB>() +
241 next.alpha_pow.as_extension::<AB>() *
242 (next.p_at_z.as_extension::<AB>() -
243 BinomialExtension::from_base(next.p_at_x.into())),
244 );
245 }
246
247 pub const fn do_memory_access<T: Copy>(local: &BatchFRIPreprocessedCols<T>) -> T {
248 local.is_real
249 }
250}
251
252impl<AB, const DEGREE: usize> Air<AB> for BatchFRIChip<DEGREE>
253where
254 AB: SP1RecursionAirBuilder + PairBuilder,
255{
256 fn eval(&self, builder: &mut AB) {
257 let main = builder.main();
258 let (local, next) = (main.row_slice(0), main.row_slice(1));
259 let local: &BatchFRICols<AB::Var> = (*local).borrow();
260 let next: &BatchFRICols<AB::Var> = (*next).borrow();
261 let prepr = builder.preprocessed();
262 let (prepr_local, prepr_next) = (prepr.row_slice(0), prepr.row_slice(1));
263 let prepr_local: &BatchFRIPreprocessedCols<AB::Var> = (*prepr_local).borrow();
264 let prepr_next: &BatchFRIPreprocessedCols<AB::Var> = (*prepr_next).borrow();
265
266 let lhs = (0..DEGREE).map(|_| prepr_local.is_real.into()).product::<AB::Expr>();
268 let rhs = (0..DEGREE).map(|_| prepr_local.is_real.into()).product::<AB::Expr>();
269 builder.assert_eq(lhs, rhs);
270
271 self.eval_batch_fri::<AB>(builder, local, next, prepr_local, prepr_next);
272 }
273}
274
#[cfg(all(test, feature = "sys"))]
mod tests {
    use crate::{chips::test_fixtures, Instruction, RecursionProgram};
    use p3_baby_bear::BabyBear;
    use p3_field::AbstractField;
    use p3_matrix::dense::RowMajorMatrix;

    use super::*;

    const DEGREE: usize = 2;

    /// Pure-Rust reference for `generate_trace`, against which the chip's
    /// output is compared.
    fn generate_trace_reference<const DEGREE: usize>(
        input: &ExecutionRecord<BabyBear>,
        _: &mut ExecutionRecord<BabyBear>,
    ) -> RowMajorMatrix<BabyBear> {
        type F = BabyBear;

        let mut rows = Vec::new();
        for event in input.batch_fri_events.iter() {
            let mut row = [F::zero(); NUM_BATCH_FRI_COLS];
            {
                let cols: &mut BatchFRICols<F> = row.as_mut_slice().borrow_mut();
                cols.acc = event.ext_single.acc;
                cols.alpha_pow = event.ext_vec.alpha_pow;
                cols.p_at_z = event.ext_vec.p_at_z;
                cols.p_at_x = event.base_vec.p_at_x;
            }
            rows.push(row);
        }

        // Pad with zero rows up to the height the chip reports.
        let height = BatchFRIChip::<DEGREE>.num_rows(input).unwrap();
        rows.resize(height, [F::zero(); NUM_BATCH_FRI_COLS]);

        let values: Vec<F> = rows.into_iter().flatten().collect();
        RowMajorMatrix::new(values, NUM_BATCH_FRI_COLS)
    }

    #[test]
    fn generate_trace() {
        let shard = test_fixtures::shard();
        let mut execution_record = test_fixtures::default_execution_record();

        let actual = BatchFRIChip::<DEGREE>.generate_trace(&shard, &mut execution_record);
        assert!(actual.height() >= test_fixtures::MIN_TEST_CASES);

        let expected = generate_trace_reference::<DEGREE>(&shard, &mut execution_record);
        assert_eq!(actual, expected);
    }

    /// Pure-Rust reference for `generate_preprocessed_trace`, filling each row
    /// field by field instead of going through the `sys` routine.
    fn generate_preprocessed_trace_reference<const DEGREE: usize>(
        program: &RecursionProgram<BabyBear>,
    ) -> RowMajorMatrix<BabyBear> {
        type F = BabyBear;

        let mut rows: Vec<[F; NUM_BATCH_FRI_PREPROCESSED_COLS]> = Vec::new();
        for instruction in program.inner.iter() {
            // Only BatchFRI instructions contribute rows.
            if let Instruction::BatchFRI(instr) = instruction {
                let BatchFRIInstr { base_vec_addrs, ext_single_addrs, ext_vec_addrs, acc_mult } =
                    instr.as_ref();
                let len = ext_vec_addrs.p_at_z.len();
                debug_assert_eq!(*acc_mult, F::one());

                rows.extend((0..len).map(|idx| {
                    let mut row = [F::zero(); NUM_BATCH_FRI_PREPROCESSED_COLS];
                    let cols: &mut BatchFRIPreprocessedCols<F> = row.as_mut_slice().borrow_mut();
                    cols.is_real = F::one();
                    // Only the last row of an instruction is flagged as the end.
                    cols.is_end = F::from_bool(idx == len - 1);
                    cols.acc_addr = ext_single_addrs.acc;
                    cols.alpha_pow_addr = ext_vec_addrs.alpha_pow[idx];
                    cols.p_at_z_addr = ext_vec_addrs.p_at_z[idx];
                    cols.p_at_x_addr = base_vec_addrs.p_at_x[idx];
                    row
                }));
            }
        }

        pad_rows_fixed(
            &mut rows,
            || [F::zero(); NUM_BATCH_FRI_PREPROCESSED_COLS],
            program.fixed_log2_rows(&BatchFRIChip::<DEGREE>),
        );

        let values: Vec<F> = rows.into_iter().flatten().collect();
        RowMajorMatrix::new(values, NUM_BATCH_FRI_PREPROCESSED_COLS)
    }

    #[test]
    #[ignore = "Failing due to merge conflicts. Will be fixed shortly."]
    fn generate_preprocessed_trace() {
        let program = test_fixtures::program();

        let actual = BatchFRIChip::<DEGREE>.generate_preprocessed_trace(&program).unwrap();
        assert!(actual.height() >= test_fixtures::MIN_TEST_CASES);

        let expected = generate_preprocessed_trace_reference::<DEGREE>(&program);
        assert_eq!(actual, expected);
    }
}