// sp1_recursion_core/chips/batch_fri.rs
#![allow(clippy::needless_range_loop)]
2
3use crate::{
4 air::Block, builder::SP1RecursionAirBuilder, Address, BatchFRIEvent, BatchFRIInstr,
5 ExecutionRecord, Instruction,
6};
7use core::borrow::Borrow;
8use itertools::Itertools;
9use p3_air::{Air, AirBuilder, BaseAir, PairBuilder};
10use p3_baby_bear::BabyBear;
11use p3_field::{AbstractField, PrimeField32};
12use p3_matrix::{dense::RowMajorMatrix, Matrix};
13use sp1_core_machine::utils::{next_power_of_two, pad_rows_fixed};
14use sp1_derive::AlignedBorrow;
15use sp1_stark::air::{BaseAirBuilder, BinomialExtension, ExtensionAirBuilder, MachineAir};
16
17use std::borrow::BorrowMut;
18use tracing::instrument;
19
20pub const NUM_BATCH_FRI_COLS: usize = core::mem::size_of::<BatchFRICols<u8>>();
21pub const NUM_BATCH_FRI_PREPROCESSED_COLS: usize =
22 core::mem::size_of::<BatchFRIPreprocessedCols<u8>>();
23
/// Zero-sized chip type for the batch FRI operation.
///
/// The `DEGREE` parameter carries no data; in this file it is only read by the
/// chip's `Air::eval` implementation, where it sets the number of `is_real`
/// factors in the constraint emitted there.
#[derive(Clone, Copy, Debug, Default)]
pub struct BatchFRIChip<const DEGREE: usize>;
26
27#[derive(AlignedBorrow, Debug, Clone, Copy)]
29#[repr(C)]
30pub struct BatchFRIPreprocessedCols<T: Copy> {
31 pub is_real: T,
32 pub is_end: T,
33 pub acc_addr: Address<T>,
34 pub alpha_pow_addr: Address<T>,
35 pub p_at_z_addr: Address<T>,
36 pub p_at_x_addr: Address<T>,
37}
38
39#[derive(AlignedBorrow, Debug, Clone, Copy)]
41#[repr(C)]
42pub struct BatchFRICols<T: Copy> {
43 pub acc: Block<T>,
44 pub alpha_pow: Block<T>,
45 pub p_at_z: Block<T>,
46 pub p_at_x: T,
47}
48
49impl<F, const DEGREE: usize> BaseAir<F> for BatchFRIChip<DEGREE> {
50 fn width(&self) -> usize {
51 NUM_BATCH_FRI_COLS
52 }
53}
54
55impl<F: PrimeField32, const DEGREE: usize> MachineAir<F> for BatchFRIChip<DEGREE> {
56 type Record = ExecutionRecord<F>;
57
58 type Program = crate::RecursionProgram<F>;
59
60 fn name(&self) -> String {
61 "BatchFRI".to_string()
62 }
63
64 fn generate_dependencies(&self, _: &Self::Record, _: &mut Self::Record) {
65 }
67
68 fn preprocessed_width(&self) -> usize {
69 NUM_BATCH_FRI_PREPROCESSED_COLS
70 }
71
72 fn generate_preprocessed_trace(&self, program: &Self::Program) -> Option<RowMajorMatrix<F>> {
73 assert_eq!(
74 std::any::TypeId::of::<F>(),
75 std::any::TypeId::of::<BabyBear>(),
76 "generate_preprocessed_trace only supports BabyBear field"
77 );
78
79 let mut rows = Vec::new();
80 let instrs = unsafe {
81 std::mem::transmute::<Vec<&Box<BatchFRIInstr<F>>>, Vec<&Box<BatchFRIInstr<BabyBear>>>>(
82 program
83 .inner
84 .iter()
85 .filter_map(|instruction| match instruction {
86 Instruction::BatchFRI(x) => Some(x),
87 _ => None,
88 })
89 .collect::<Vec<_>>(),
90 )
91 };
92 instrs.iter().for_each(|instruction| {
93 let BatchFRIInstr { base_vec_addrs: _, ext_single_addrs: _, ext_vec_addrs, acc_mult } =
94 instruction.as_ref();
95 let len: usize = ext_vec_addrs.p_at_z.len();
96 let mut row_add = vec![[BabyBear::zero(); NUM_BATCH_FRI_PREPROCESSED_COLS]; len];
97 debug_assert_eq!(*acc_mult, BabyBear::one());
98
99 row_add.iter_mut().enumerate().for_each(|(i, row)| {
100 let cols: &mut BatchFRIPreprocessedCols<BabyBear> = row.as_mut_slice().borrow_mut();
101 unsafe {
102 crate::sys::batch_fri_instr_to_row_babybear(&instruction.into(), cols, i);
103 }
104 });
105 rows.extend(row_add);
106 });
107
108 pad_rows_fixed(
110 &mut rows,
111 || [BabyBear::zero(); NUM_BATCH_FRI_PREPROCESSED_COLS],
112 program.fixed_log2_rows(self),
113 );
114
115 let trace = RowMajorMatrix::new(
116 unsafe {
117 std::mem::transmute::<Vec<BabyBear>, Vec<F>>(
118 rows.into_iter().flatten().collect::<Vec<BabyBear>>(),
119 )
120 },
121 NUM_BATCH_FRI_PREPROCESSED_COLS,
122 );
123 Some(trace)
124 }
125
126 fn num_rows(&self, input: &Self::Record) -> Option<usize> {
127 let events = &input.batch_fri_events;
128 Some(next_power_of_two(events.len(), input.fixed_log2_rows(self)))
129 }
130
131 #[instrument(name = "generate batch fri trace", level = "debug", skip_all, fields(rows = input.batch_fri_events.len()))]
132 fn generate_trace(
133 &self,
134 input: &ExecutionRecord<F>,
135 _: &mut ExecutionRecord<F>,
136 ) -> RowMajorMatrix<F> {
137 assert_eq!(
138 std::any::TypeId::of::<F>(),
139 std::any::TypeId::of::<BabyBear>(),
140 "generate_trace only supports BabyBear field"
141 );
142
143 let mut rows = input
144 .batch_fri_events
145 .iter()
146 .map(|event| {
147 let bb_event = unsafe {
148 std::mem::transmute::<&BatchFRIEvent<F>, &BatchFRIEvent<BabyBear>>(event)
149 };
150 let mut row = [BabyBear::zero(); NUM_BATCH_FRI_COLS];
151 let cols: &mut BatchFRICols<BabyBear> = row.as_mut_slice().borrow_mut();
152 cols.acc = bb_event.ext_single.acc;
153 cols.alpha_pow = bb_event.ext_vec.alpha_pow;
154 cols.p_at_z = bb_event.ext_vec.p_at_z;
155 cols.p_at_x = bb_event.base_vec.p_at_x;
156 row
157 })
158 .collect_vec();
159
160 rows.resize(self.num_rows(input).unwrap(), [BabyBear::zero(); NUM_BATCH_FRI_COLS]);
162
163 let trace = RowMajorMatrix::new(
165 unsafe {
166 std::mem::transmute::<Vec<BabyBear>, Vec<F>>(
167 rows.into_iter().flatten().collect::<Vec<BabyBear>>(),
168 )
169 },
170 NUM_BATCH_FRI_COLS,
171 );
172
173 #[cfg(debug_assertions)]
174 eprintln!(
175 "batch fri trace dims is width: {:?}, height: {:?}",
176 trace.width(),
177 trace.height()
178 );
179
180 trace
181 }
182
183 fn included(&self, _record: &Self::Record) -> bool {
184 true
185 }
186}
187
188impl<const DEGREE: usize> BatchFRIChip<DEGREE> {
189 pub fn eval_batch_fri<AB: SP1RecursionAirBuilder>(
190 &self,
191 builder: &mut AB,
192 local: &BatchFRICols<AB::Var>,
193 next: &BatchFRICols<AB::Var>,
194 local_prepr: &BatchFRIPreprocessedCols<AB::Var>,
195 _next_prepr: &BatchFRIPreprocessedCols<AB::Var>,
196 ) {
197 builder.receive_block(local_prepr.alpha_pow_addr, local.alpha_pow, local_prepr.is_real);
199 builder.receive_block(local_prepr.p_at_z_addr, local.p_at_z, local_prepr.is_real);
200 builder.receive_single(local_prepr.p_at_x_addr, local.p_at_x, local_prepr.is_real);
201
202 builder.send_block(local_prepr.acc_addr, local.acc, local_prepr.is_end);
205
206 builder.when_first_row().assert_ext_eq(
208 local.acc.as_extension::<AB>(),
209 local.alpha_pow.as_extension::<AB>() *
210 (local.p_at_z.as_extension::<AB>() -
211 BinomialExtension::from_base(local.p_at_x.into())),
212 );
213
214 builder.when_transition().when(local_prepr.is_end).assert_ext_eq(
216 next.acc.as_extension::<AB>(),
217 next.alpha_pow.as_extension::<AB>() *
218 (next.p_at_z.as_extension::<AB>() -
219 BinomialExtension::from_base(next.p_at_x.into())),
220 );
221
222 builder.when_transition().when_not(local_prepr.is_end).assert_ext_eq(
224 next.acc.as_extension::<AB>(),
225 local.acc.as_extension::<AB>() +
226 next.alpha_pow.as_extension::<AB>() *
227 (next.p_at_z.as_extension::<AB>() -
228 BinomialExtension::from_base(next.p_at_x.into())),
229 );
230 }
231
232 pub const fn do_memory_access<T: Copy>(local: &BatchFRIPreprocessedCols<T>) -> T {
233 local.is_real
234 }
235}
236
237impl<AB, const DEGREE: usize> Air<AB> for BatchFRIChip<DEGREE>
238where
239 AB: SP1RecursionAirBuilder + PairBuilder,
240{
241 fn eval(&self, builder: &mut AB) {
242 let main = builder.main();
243 let (local, next) = (main.row_slice(0), main.row_slice(1));
244 let local: &BatchFRICols<AB::Var> = (*local).borrow();
245 let next: &BatchFRICols<AB::Var> = (*next).borrow();
246 let prepr = builder.preprocessed();
247 let (prepr_local, prepr_next) = (prepr.row_slice(0), prepr.row_slice(1));
248 let prepr_local: &BatchFRIPreprocessedCols<AB::Var> = (*prepr_local).borrow();
249 let prepr_next: &BatchFRIPreprocessedCols<AB::Var> = (*prepr_next).borrow();
250
251 let lhs = (0..DEGREE).map(|_| prepr_local.is_real.into()).product::<AB::Expr>();
253 let rhs = (0..DEGREE).map(|_| prepr_local.is_real.into()).product::<AB::Expr>();
254 builder.assert_eq(lhs, rhs);
255
256 self.eval_batch_fri::<AB>(builder, local, next, prepr_local, prepr_next);
257 }
258}
259
#[cfg(test)]
mod tests {
    use crate::{chips::test_fixtures, Instruction, RecursionProgram};
    use p3_baby_bear::BabyBear;
    use p3_field::AbstractField;
    use p3_matrix::dense::RowMajorMatrix;

    use super::*;

    const DEGREE: usize = 2;

    /// Pure-Rust reference for `generate_trace`: one row per batch FRI event,
    /// resized with zero rows to the chip's `num_rows`.
    fn generate_trace_reference<const DEGREE: usize>(
        input: &ExecutionRecord<BabyBear>,
        _: &mut ExecutionRecord<BabyBear>,
    ) -> RowMajorMatrix<BabyBear> {
        type F = BabyBear;

        let mut rows: Vec<[F; NUM_BATCH_FRI_COLS]> = Vec::new();
        for event in input.batch_fri_events.iter() {
            let mut row = [F::zero(); NUM_BATCH_FRI_COLS];
            let cols: &mut BatchFRICols<F> = row.as_mut_slice().borrow_mut();
            cols.acc = event.ext_single.acc;
            cols.alpha_pow = event.ext_vec.alpha_pow;
            cols.p_at_z = event.ext_vec.p_at_z;
            cols.p_at_x = event.base_vec.p_at_x;
            rows.push(row);
        }

        let height = BatchFRIChip::<DEGREE>.num_rows(input).unwrap();
        rows.resize(height, [F::zero(); NUM_BATCH_FRI_COLS]);

        RowMajorMatrix::new(rows.into_iter().flatten().collect(), NUM_BATCH_FRI_COLS)
    }

    /// The chip's main trace must match the reference implementation.
    #[test]
    fn generate_trace() {
        let shard = test_fixtures::shard();
        let mut execution_record = test_fixtures::default_execution_record();
        let trace = BatchFRIChip::<DEGREE>.generate_trace(&shard, &mut execution_record);
        assert!(trace.height() >= test_fixtures::MIN_TEST_CASES);

        assert_eq!(trace, generate_trace_reference::<DEGREE>(&shard, &mut execution_record));
    }

    /// Pure-Rust reference for `generate_preprocessed_trace`: one row per `p_at_z`
    /// address of every `BatchFRI` instruction, padded with zero rows.
    fn generate_preprocessed_trace_reference<const DEGREE: usize>(
        program: &RecursionProgram<BabyBear>,
    ) -> RowMajorMatrix<BabyBear> {
        type F = BabyBear;

        let mut rows: Vec<[F; NUM_BATCH_FRI_PREPROCESSED_COLS]> = Vec::new();
        for instruction in program.inner.iter() {
            // Only batch FRI instructions contribute rows.
            if let Instruction::BatchFRI(instr) = instruction {
                let BatchFRIInstr { base_vec_addrs, ext_single_addrs, ext_vec_addrs, acc_mult } =
                    instr.as_ref();
                let len = ext_vec_addrs.p_at_z.len();
                debug_assert_eq!(*acc_mult, F::one());

                rows.extend((0..len).map(|idx| {
                    let mut row = [F::zero(); NUM_BATCH_FRI_PREPROCESSED_COLS];
                    let cols: &mut BatchFRIPreprocessedCols<F> = row.as_mut_slice().borrow_mut();
                    cols.is_real = F::one();
                    // Only the final row of an instruction is flagged as its end.
                    cols.is_end = F::from_bool(idx == len - 1);
                    cols.acc_addr = ext_single_addrs.acc;
                    cols.alpha_pow_addr = ext_vec_addrs.alpha_pow[idx];
                    cols.p_at_z_addr = ext_vec_addrs.p_at_z[idx];
                    cols.p_at_x_addr = base_vec_addrs.p_at_x[idx];
                    row
                }));
            }
        }

        pad_rows_fixed(
            &mut rows,
            || [F::zero(); NUM_BATCH_FRI_PREPROCESSED_COLS],
            program.fixed_log2_rows(&BatchFRIChip::<DEGREE>),
        );

        RowMajorMatrix::new(rows.into_iter().flatten().collect(), NUM_BATCH_FRI_PREPROCESSED_COLS)
    }

    /// The chip's preprocessed trace must match the reference implementation.
    #[test]
    #[ignore = "Failing due to merge conflicts. Will be fixed shortly."]
    fn generate_preprocessed_trace() {
        let program = test_fixtures::program();
        let trace = BatchFRIChip::<DEGREE>.generate_preprocessed_trace(&program).unwrap();
        assert!(trace.height() >= test_fixtures::MIN_TEST_CASES);

        assert_eq!(trace, generate_preprocessed_trace_reference::<DEGREE>(&program));
    }
}