sp1_recursion_core/chips/
batch_fri.rs

1#![allow(clippy::needless_range_loop)]
2
3use crate::{air::Block, builder::SP1RecursionAirBuilder, Address, ExecutionRecord};
4use core::borrow::Borrow;
5use p3_air::{Air, AirBuilder, BaseAir, PairBuilder};
6use p3_field::PrimeField32;
7use p3_matrix::{dense::RowMajorMatrix, Matrix};
8use sp1_core_machine::utils::next_power_of_two;
9use sp1_derive::AlignedBorrow;
10use sp1_stark::air::{BaseAirBuilder, BinomialExtension, ExtensionAirBuilder, MachineAir};
11
12#[cfg(feature = "sys")]
13use {
14    crate::{BatchFRIEvent, BatchFRIInstr, Instruction},
15    itertools::Itertools,
16    p3_baby_bear::BabyBear,
17    p3_field::AbstractField,
18    sp1_core_machine::utils::pad_rows_fixed,
19    std::borrow::BorrowMut,
20    tracing::instrument,
21};
22
/// Width of the main trace: the byte size of [`BatchFRICols`] over a one-byte field element.
pub const NUM_BATCH_FRI_COLS: usize = core::mem::size_of::<BatchFRICols<u8>>();
/// Width of the preprocessed trace: the byte size of [`BatchFRIPreprocessedCols`] over a
/// one-byte field element.
pub const NUM_BATCH_FRI_PREPROCESSED_COLS: usize =
    core::mem::size_of::<BatchFRIPreprocessedCols<u8>>();
26
/// The chip for batch FRI accumulation. `DEGREE` is only used to pad the constraint degree
/// (see the dummy constraint in `Air::eval`); it does not affect the trace layout.
#[derive(Clone, Debug, Copy, Default)]
pub struct BatchFRIChip<const DEGREE: usize>;
29
/// The preprocessed columns for a batch FRI invocation.
#[derive(AlignedBorrow, Debug, Clone, Copy)]
#[repr(C)]
pub struct BatchFRIPreprocessedCols<T: Copy> {
    /// Whether this row carries a real (non-padding) loop iteration.
    pub is_real: T,
    /// Whether this row is the last iteration of its loop; gates the accumulator write.
    pub is_end: T,
    /// Memory address the final accumulator is written to (sent when `is_end`).
    pub acc_addr: T_PLACEHOLDER,
    /// Memory address the `alpha_pow` extension value is read from.
    pub alpha_pow_addr: Address<T>,
    /// Memory address the `p_at_z` extension value is read from.
    pub p_at_z_addr: Address<T>,
    /// Memory address the base-field `p_at_x` value is read from.
    pub p_at_x_addr: Address<T>,
}
41
/// The main columns for a batch FRI invocation.
#[derive(AlignedBorrow, Debug, Clone, Copy)]
#[repr(C)]
pub struct BatchFRICols<T: Copy> {
    /// Running accumulator: sum over rows of `alpha_pow * (p_at_z - p_at_x)`.
    pub acc: Block<T>,
    /// Extension-field power of alpha for this row (read from memory).
    pub alpha_pow: Block<T>,
    /// Extension-field evaluation `p(z)` for this row (read from memory).
    pub p_at_z: Block<T>,
    /// Base-field evaluation `p(x)` for this row (read from memory).
    pub p_at_x: T,
}
51
impl<F, const DEGREE: usize> BaseAir<F> for BatchFRIChip<DEGREE> {
    /// The number of main trace columns.
    fn width(&self) -> usize {
        NUM_BATCH_FRI_COLS
    }
}
57
impl<F: PrimeField32, const DEGREE: usize> MachineAir<F> for BatchFRIChip<DEGREE> {
    type Record = ExecutionRecord<F>;

    type Program = crate::RecursionProgram<F>;

    fn name(&self) -> String {
        "BatchFRI".to_string()
    }

    fn generate_dependencies(&self, _: &Self::Record, _: &mut Self::Record) {
        // This is a no-op.
    }

    fn preprocessed_width(&self) -> usize {
        NUM_BATCH_FRI_PREPROCESSED_COLS
    }

    // Without the `sys` feature the FFI row writer is unavailable, so trace generation
    // cannot run at all.
    #[cfg(not(feature = "sys"))]
    fn generate_preprocessed_trace(&self, _program: &Self::Program) -> Option<RowMajorMatrix<F>> {
        unimplemented!("To generate traces, enable feature `sp1-recursion-core/sys`");
    }

    /// Builds the preprocessed trace: one row per `(alpha_pow, p_at_z, p_at_x)` address triple
    /// of every `BatchFRI` instruction in the program, padded to a power of two.
    #[cfg(feature = "sys")]
    fn generate_preprocessed_trace(&self, program: &Self::Program) -> Option<RowMajorMatrix<F>> {
        // Runtime guard that makes the transmutes below sound: F must be exactly BabyBear.
        assert_eq!(
            std::any::TypeId::of::<F>(),
            std::any::TypeId::of::<BabyBear>(),
            "generate_preprocessed_trace only supports BabyBear field"
        );

        let mut rows = Vec::new();
        // SAFETY: the TypeId assertion above guarantees `F == BabyBear`, so
        // `BatchFRIInstr<F>` and `BatchFRIInstr<BabyBear>` are the same type and the
        // transmute is a no-op reinterpretation.
        let instrs = unsafe {
            std::mem::transmute::<Vec<&Box<BatchFRIInstr<F>>>, Vec<&Box<BatchFRIInstr<BabyBear>>>>(
                program
                    .inner
                    .iter()
                    .filter_map(|instruction| match instruction {
                        Instruction::BatchFRI(x) => Some(x),
                        _ => None,
                    })
                    .collect::<Vec<_>>(),
            )
        };
        instrs.iter().for_each(|instruction| {
            let BatchFRIInstr { base_vec_addrs: _, ext_single_addrs: _, ext_vec_addrs, acc_mult } =
                instruction.as_ref();
            // One preprocessed row per p_at_z entry of this instruction.
            let len: usize = ext_vec_addrs.p_at_z.len();
            let mut row_add = vec![[BabyBear::zero(); NUM_BATCH_FRI_PREPROCESSED_COLS]; len];
            // The accumulator is expected to be written exactly once per instruction.
            debug_assert_eq!(*acc_mult, BabyBear::one());

            row_add.iter_mut().enumerate().for_each(|(i, row)| {
                let cols: &mut BatchFRIPreprocessedCols<BabyBear> = row.as_mut_slice().borrow_mut();
                // SAFETY (review): FFI call fills row `i` of the preprocessed columns for
                // this instruction; `cols` points at a correctly-sized, repr(C) row buffer.
                unsafe {
                    crate::sys::batch_fri_instr_to_row_babybear(&instruction.into(), cols, i);
                }
            });
            rows.extend(row_add);
        });

        // Pad the trace to a power of two.
        pad_rows_fixed(
            &mut rows,
            || [BabyBear::zero(); NUM_BATCH_FRI_PREPROCESSED_COLS],
            program.fixed_log2_rows(self),
        );

        let trace = RowMajorMatrix::new(
            // SAFETY: `F == BabyBear` (asserted above), so `Vec<BabyBear>` and `Vec<F>`
            // are the same type.
            unsafe {
                std::mem::transmute::<Vec<BabyBear>, Vec<F>>(
                    rows.into_iter().flatten().collect::<Vec<BabyBear>>(),
                )
            },
            NUM_BATCH_FRI_PREPROCESSED_COLS,
        );
        Some(trace)
    }

    // Height of the main trace: event count rounded up to a power of two (or the
    // fixed height configured on the record).
    fn num_rows(&self, input: &Self::Record) -> Option<usize> {
        let events = &input.batch_fri_events;
        Some(next_power_of_two(events.len(), input.fixed_log2_rows(self)))
    }

    #[cfg(not(feature = "sys"))]
    fn generate_trace(&self, _input: &Self::Record, _: &mut Self::Record) -> RowMajorMatrix<F> {
        unimplemented!("To generate traces, enable feature `sp1-recursion-core/sys`");
    }

    /// Builds the main trace: one row per batch FRI event, padded with zero rows to
    /// the height reported by `num_rows`.
    #[cfg(feature = "sys")]
    #[instrument(name = "generate batch fri trace", level = "debug", skip_all, fields(rows = input.batch_fri_events.len()))]
    fn generate_trace(
        &self,
        input: &ExecutionRecord<F>,
        _: &mut ExecutionRecord<F>,
    ) -> RowMajorMatrix<F> {
        // Runtime guard that makes the transmutes below sound: F must be exactly BabyBear.
        assert_eq!(
            std::any::TypeId::of::<F>(),
            std::any::TypeId::of::<BabyBear>(),
            "generate_trace only supports BabyBear field"
        );

        let mut rows = input
            .batch_fri_events
            .iter()
            .map(|event| {
                // SAFETY: `F == BabyBear` (asserted above), so the two reference types
                // are identical.
                let bb_event = unsafe {
                    std::mem::transmute::<&BatchFRIEvent<F>, &BatchFRIEvent<BabyBear>>(event)
                };
                let mut row = [BabyBear::zero(); NUM_BATCH_FRI_COLS];
                let cols: &mut BatchFRICols<BabyBear> = row.as_mut_slice().borrow_mut();
                cols.acc = bb_event.ext_single.acc;
                cols.alpha_pow = bb_event.ext_vec.alpha_pow;
                cols.p_at_z = bb_event.ext_vec.p_at_z;
                cols.p_at_x = bb_event.base_vec.p_at_x;
                row
            })
            .collect_vec();

        // Pad the trace to a power of two.
        rows.resize(self.num_rows(input).unwrap(), [BabyBear::zero(); NUM_BATCH_FRI_COLS]);

        // Convert the trace to a row major matrix.
        let trace = RowMajorMatrix::new(
            // SAFETY: `F == BabyBear` (asserted above), so `Vec<BabyBear>` and `Vec<F>`
            // are the same type.
            unsafe {
                std::mem::transmute::<Vec<BabyBear>, Vec<F>>(
                    rows.into_iter().flatten().collect::<Vec<BabyBear>>(),
                )
            },
            NUM_BATCH_FRI_COLS,
        );

        #[cfg(debug_assertions)]
        eprintln!(
            "batch fri trace dims is width: {:?}, height: {:?}",
            trace.width(),
            trace.height()
        );

        trace
    }

    // This chip is always part of the machine, regardless of record contents.
    fn included(&self, _record: &Self::Record) -> bool {
        true
    }
}
202
impl<const DEGREE: usize> BatchFRIChip<DEGREE> {
    /// Evaluates the batch FRI constraints.
    ///
    /// Each real row reads `alpha_pow`, `p_at_z`, and `p_at_x` from memory and contributes
    /// `alpha_pow * (p_at_z - p_at_x)` to the running accumulator `acc`; the accumulated
    /// value is written back to memory on the row where `is_end` is set. The accumulator
    /// restarts from its own row term whenever the previous row ended a loop.
    pub fn eval_batch_fri<AB: SP1RecursionAirBuilder>(
        &self,
        builder: &mut AB,
        local: &BatchFRICols<AB::Var>,
        next: &BatchFRICols<AB::Var>,
        local_prepr: &BatchFRIPreprocessedCols<AB::Var>,
        _next_prepr: &BatchFRIPreprocessedCols<AB::Var>,
    ) {
        // Constrain memory read for alpha_pow, p_at_z, and p_at_x.
        builder.receive_block(local_prepr.alpha_pow_addr, local.alpha_pow, local_prepr.is_real);
        builder.receive_block(local_prepr.p_at_z_addr, local.p_at_z, local_prepr.is_real);
        builder.receive_single(local_prepr.p_at_x_addr, local.p_at_x, local_prepr.is_real);

        // Constrain memory write for the accumulator.
        // Note that we write with multiplicity 1, when `is_end` is true.
        builder.send_block(local_prepr.acc_addr, local.acc, local_prepr.is_end);

        // Constrain the accumulator value of the first row.
        builder.when_first_row().assert_ext_eq(
            local.acc.as_extension::<AB>(),
            local.alpha_pow.as_extension::<AB>() *
                (local.p_at_z.as_extension::<AB>() -
                    BinomialExtension::from_base(local.p_at_x.into())),
        );

        // Constrain the accumulator of the next row when the current row is the end of loop.
        // The next row then starts a new accumulation from its own term only.
        builder.when_transition().when(local_prepr.is_end).assert_ext_eq(
            next.acc.as_extension::<AB>(),
            next.alpha_pow.as_extension::<AB>() *
                (next.p_at_z.as_extension::<AB>() -
                    BinomialExtension::from_base(next.p_at_x.into())),
        );

        // Constrain the accumulator of the next row when the current row is not the end of loop.
        // The next row extends the current accumulation with its own term.
        builder.when_transition().when_not(local_prepr.is_end).assert_ext_eq(
            next.acc.as_extension::<AB>(),
            local.acc.as_extension::<AB>() +
                next.alpha_pow.as_extension::<AB>() *
                    (next.p_at_z.as_extension::<AB>() -
                        BinomialExtension::from_base(next.p_at_x.into())),
        );
    }

    /// Selector for whether a row performs memory accesses: its `is_real` flag.
    pub const fn do_memory_access<T: Copy>(local: &BatchFRIPreprocessedCols<T>) -> T {
        local.is_real
    }
}
251
impl<AB, const DEGREE: usize> Air<AB> for BatchFRIChip<DEGREE>
where
    AB: SP1RecursionAirBuilder + PairBuilder,
{
    fn eval(&self, builder: &mut AB) {
        let main = builder.main();
        let (local, next) = (main.row_slice(0), main.row_slice(1));
        let local: &BatchFRICols<AB::Var> = (*local).borrow();
        let next: &BatchFRICols<AB::Var> = (*next).borrow();
        let prepr = builder.preprocessed();
        let (prepr_local, prepr_next) = (prepr.row_slice(0), prepr.row_slice(1));
        let prepr_local: &BatchFRIPreprocessedCols<AB::Var> = (*prepr_local).borrow();
        let prepr_next: &BatchFRIPreprocessedCols<AB::Var> = (*prepr_next).borrow();

        // Dummy constraints to normalize to DEGREE.
        // `lhs` and `rhs` are intentionally the same degree-DEGREE product: the trivially
        // satisfied equality forces the chip's maximum constraint degree up to DEGREE
        // without affecting soundness. Do not deduplicate.
        let lhs = (0..DEGREE).map(|_| prepr_local.is_real.into()).product::<AB::Expr>();
        let rhs = (0..DEGREE).map(|_| prepr_local.is_real.into()).product::<AB::Expr>();
        builder.assert_eq(lhs, rhs);

        self.eval_batch_fri::<AB>(builder, local, next, prepr_local, prepr_next);
    }
}
274
#[cfg(all(test, feature = "sys"))]
mod tests {
    use crate::{chips::test_fixtures, Instruction, RecursionProgram};
    use p3_baby_bear::BabyBear;
    use p3_field::AbstractField;
    use p3_matrix::dense::RowMajorMatrix;

    use super::*;

    const DEGREE: usize = 2;

    /// Pure-Rust reference implementation of `generate_trace`, used to cross-check the
    /// FFI-backed trace generation: one row per batch FRI event, zero-padded to the
    /// chip's row count.
    fn generate_trace_reference<const DEGREE: usize>(
        input: &ExecutionRecord<BabyBear>,
        _: &mut ExecutionRecord<BabyBear>,
    ) -> RowMajorMatrix<BabyBear> {
        type F = BabyBear;

        let mut rows = input
            .batch_fri_events
            .iter()
            .map(|event| {
                let mut row = [F::zero(); NUM_BATCH_FRI_COLS];
                let cols: &mut BatchFRICols<F> = row.as_mut_slice().borrow_mut();
                cols.acc = event.ext_single.acc;
                cols.alpha_pow = event.ext_vec.alpha_pow;
                cols.p_at_z = event.ext_vec.p_at_z;
                cols.p_at_x = event.base_vec.p_at_x;
                row
            })
            .collect_vec();

        // Pad with zero rows to the same height the chip reports.
        rows.resize(
            BatchFRIChip::<DEGREE>.num_rows(input).unwrap(),
            [F::zero(); NUM_BATCH_FRI_COLS],
        );

        RowMajorMatrix::new(rows.into_iter().flatten().collect(), NUM_BATCH_FRI_COLS)
    }

    #[test]
    fn generate_trace() {
        let shard = test_fixtures::shard();
        let mut execution_record = test_fixtures::default_execution_record();
        let trace = BatchFRIChip::<DEGREE>.generate_trace(&shard, &mut execution_record);
        assert!(trace.height() >= test_fixtures::MIN_TEST_CASES);

        assert_eq!(trace, generate_trace_reference::<DEGREE>(&shard, &mut execution_record));
    }

    /// Pure-Rust reference implementation of `generate_preprocessed_trace`: one row per
    /// `p_at_z` entry of every `BatchFRI` instruction, padded to the fixed height.
    fn generate_preprocessed_trace_reference<const DEGREE: usize>(
        program: &RecursionProgram<BabyBear>,
    ) -> RowMajorMatrix<BabyBear> {
        type F = BabyBear;

        let mut rows: Vec<[F; NUM_BATCH_FRI_PREPROCESSED_COLS]> = Vec::new();
        program
            .inner
            .iter()
            .filter_map(|instruction| match instruction {
                Instruction::BatchFRI(instr) => Some(instr),
                _ => None,
            })
            .for_each(|instruction| {
                let BatchFRIInstr { base_vec_addrs, ext_single_addrs, ext_vec_addrs, acc_mult } =
                    instruction.as_ref();
                let len = ext_vec_addrs.p_at_z.len();
                let mut row_add = vec![[F::zero(); NUM_BATCH_FRI_PREPROCESSED_COLS]; len];
                // The accumulator is expected to be written exactly once per instruction.
                debug_assert_eq!(*acc_mult, F::one());

                row_add.iter_mut().enumerate().for_each(|(i, row)| {
                    let cols: &mut BatchFRIPreprocessedCols<F> = row.as_mut_slice().borrow_mut();
                    cols.is_real = F::one();
                    // Only the final iteration of the loop writes the accumulator.
                    cols.is_end = F::from_bool(i == len - 1);
                    cols.acc_addr = ext_single_addrs.acc;
                    cols.alpha_pow_addr = ext_vec_addrs.alpha_pow[i];
                    cols.p_at_z_addr = ext_vec_addrs.p_at_z[i];
                    cols.p_at_x_addr = base_vec_addrs.p_at_x[i];
                });
                rows.extend(row_add);
            });

        pad_rows_fixed(
            &mut rows,
            || [F::zero(); NUM_BATCH_FRI_PREPROCESSED_COLS],
            program.fixed_log2_rows(&BatchFRIChip::<DEGREE>),
        );

        RowMajorMatrix::new(rows.into_iter().flatten().collect(), NUM_BATCH_FRI_PREPROCESSED_COLS)
    }

    #[test]
    #[ignore = "Failing due to merge conflicts. Will be fixed shortly."]
    fn generate_preprocessed_trace() {
        let program = test_fixtures::program();
        let trace = BatchFRIChip::<DEGREE>.generate_preprocessed_trace(&program).unwrap();
        assert!(trace.height() >= test_fixtures::MIN_TEST_CASES);

        assert_eq!(trace, generate_preprocessed_trace_reference::<DEGREE>(&program));
    }
}