sp1_recursion_core/chips/
batch_fri.rs

1#![allow(clippy::needless_range_loop)]
2
3use crate::{
4    air::Block, builder::SP1RecursionAirBuilder, Address, BatchFRIEvent, BatchFRIInstr,
5    ExecutionRecord, Instruction,
6};
7use core::borrow::Borrow;
8use itertools::Itertools;
9use p3_air::{Air, AirBuilder, BaseAir, PairBuilder};
10use p3_baby_bear::BabyBear;
11use p3_field::{AbstractField, PrimeField32};
12use p3_matrix::{dense::RowMajorMatrix, Matrix};
13use sp1_core_machine::utils::{next_power_of_two, pad_rows_fixed};
14use sp1_derive::AlignedBorrow;
15use sp1_stark::air::{BaseAirBuilder, BinomialExtension, ExtensionAirBuilder, MachineAir};
16
17use std::borrow::BorrowMut;
18use tracing::instrument;
19
/// Number of main-trace columns. Each field of `BatchFRICols<u8>` occupies
/// exactly one byte, so the struct's byte size equals its column count.
pub const NUM_BATCH_FRI_COLS: usize = core::mem::size_of::<BatchFRICols<u8>>();
/// Number of preprocessed-trace columns, derived the same way from
/// `BatchFRIPreprocessedCols<u8>`.
pub const NUM_BATCH_FRI_PREPROCESSED_COLS: usize =
    core::mem::size_of::<BatchFRIPreprocessedCols<u8>>();
23
/// Chip that constrains batched FRI accumulator updates.
///
/// `DEGREE` is the constraint degree the chip is normalized to via the dummy
/// constraint in its `Air` implementation.
#[derive(Clone, Debug, Copy, Default)]
pub struct BatchFRIChip<const DEGREE: usize>;
26
/// The preprocessed columns for a batch FRI invocation.
#[derive(AlignedBorrow, Debug, Clone, Copy)]
#[repr(C)]
pub struct BatchFRIPreprocessedCols<T: Copy> {
    // 1 on rows carrying a real instruction element; 0 on padding rows.
    pub is_real: T,
    // 1 on the last row of an instruction's loop; gates the accumulator write.
    pub is_end: T,
    // Address the final accumulator value is written to (on the `is_end` row).
    pub acc_addr: Address<T>,
    // Address of the extension-field `alpha_pow` value read on this row.
    pub alpha_pow_addr: Address<T>,
    // Address of the extension-field `p_at_z` value read on this row.
    pub p_at_z_addr: Address<T>,
    // Address of the base-field `p_at_x` value read on this row.
    pub p_at_x_addr: Address<T>,
}
38
/// The main columns for a batch FRI invocation.
#[derive(AlignedBorrow, Debug, Clone, Copy)]
#[repr(C)]
pub struct BatchFRICols<T: Copy> {
    // Running accumulator; constrained to acc += alpha_pow * (p_at_z - p_at_x),
    // restarting on the row after `is_end`.
    pub acc: Block<T>,
    // Extension-field value read from `alpha_pow_addr`.
    pub alpha_pow: Block<T>,
    // Extension-field value read from `p_at_z_addr`.
    pub p_at_z: Block<T>,
    // Base-field value read from `p_at_x_addr`.
    pub p_at_x: T,
}
48
impl<F, const DEGREE: usize> BaseAir<F> for BatchFRIChip<DEGREE> {
    /// Width of the main trace: one column per field of `BatchFRICols`.
    fn width(&self) -> usize {
        NUM_BATCH_FRI_COLS
    }
}
54
impl<F: PrimeField32, const DEGREE: usize> MachineAir<F> for BatchFRIChip<DEGREE> {
    type Record = ExecutionRecord<F>;

    type Program = crate::RecursionProgram<F>;

    fn name(&self) -> String {
        "BatchFRI".to_string()
    }

    /// No dependency events are produced by this chip.
    fn generate_dependencies(&self, _: &Self::Record, _: &mut Self::Record) {
        // This is a no-op.
    }

    fn preprocessed_width(&self) -> usize {
        NUM_BATCH_FRI_PREPROCESSED_COLS
    }

    /// Builds the preprocessed trace from the program's `BatchFRI`
    /// instructions: one row per `p_at_z` element of each instruction, padded
    /// with zero rows to a power-of-two height.
    ///
    /// Only `F = BabyBear` is supported; the `TypeId` assertion below is what
    /// makes the subsequent `transmute`s sound.
    fn generate_preprocessed_trace(&self, program: &Self::Program) -> Option<RowMajorMatrix<F>> {
        assert_eq!(
            std::any::TypeId::of::<F>(),
            std::any::TypeId::of::<BabyBear>(),
            "generate_preprocessed_trace only supports BabyBear field"
        );

        let mut rows = Vec::new();
        // SAFETY: the assertion above guarantees `F` is exactly `BabyBear`, so
        // the source and destination vector types are identical and this
        // transmute is a no-op.
        let instrs = unsafe {
            std::mem::transmute::<Vec<&Box<BatchFRIInstr<F>>>, Vec<&Box<BatchFRIInstr<BabyBear>>>>(
                program
                    .inner
                    .iter()
                    .filter_map(|instruction| match instruction {
                        Instruction::BatchFRI(x) => Some(x),
                        _ => None,
                    })
                    .collect::<Vec<_>>(),
            )
        };
        instrs.iter().for_each(|instruction| {
            let BatchFRIInstr { base_vec_addrs: _, ext_single_addrs: _, ext_vec_addrs, acc_mult } =
                instruction.as_ref();
            // One preprocessed row per `p_at_z` element of this instruction.
            let len: usize = ext_vec_addrs.p_at_z.len();
            let mut row_add = vec![[BabyBear::zero(); NUM_BATCH_FRI_PREPROCESSED_COLS]; len];
            // The accumulator is expected to be written with multiplicity 1.
            debug_assert_eq!(*acc_mult, BabyBear::one());

            row_add.iter_mut().enumerate().for_each(|(i, row)| {
                let cols: &mut BatchFRIPreprocessedCols<BabyBear> = row.as_mut_slice().borrow_mut();
                // SAFETY: `cols` is a `repr(C)` view over a row buffer of
                // exactly NUM_BATCH_FRI_PREPROCESSED_COLS BabyBear values;
                // correctness of the filled contents is delegated to the FFI
                // routine's contract (it populates row `i` of the instruction).
                unsafe {
                    crate::sys::batch_fri_instr_to_row_babybear(&instruction.into(), cols, i);
                }
            });
            rows.extend(row_add);
        });

        // Pad the trace to a power of two.
        pad_rows_fixed(
            &mut rows,
            || [BabyBear::zero(); NUM_BATCH_FRI_PREPROCESSED_COLS],
            program.fixed_log2_rows(self),
        );

        let trace = RowMajorMatrix::new(
            // SAFETY: `F` is `BabyBear` (asserted above), so `Vec<BabyBear>`
            // and `Vec<F>` are the same type.
            unsafe {
                std::mem::transmute::<Vec<BabyBear>, Vec<F>>(
                    rows.into_iter().flatten().collect::<Vec<BabyBear>>(),
                )
            },
            NUM_BATCH_FRI_PREPROCESSED_COLS,
        );
        Some(trace)
    }

    /// Height of the main trace: the event count rounded up to a power of two
    /// (or the shard's fixed log2 height, when configured).
    fn num_rows(&self, input: &Self::Record) -> Option<usize> {
        let events = &input.batch_fri_events;
        Some(next_power_of_two(events.len(), input.fixed_log2_rows(self)))
    }

    /// Builds the main trace: one row per `BatchFRIEvent`, copying the event's
    /// accumulator, `alpha_pow`, `p_at_z`, and `p_at_x` values, then padding
    /// with zero rows to the height reported by `num_rows`.
    ///
    /// Only `F = BabyBear` is supported (see the assertion below).
    #[instrument(name = "generate batch fri trace", level = "debug", skip_all, fields(rows = input.batch_fri_events.len()))]
    fn generate_trace(
        &self,
        input: &ExecutionRecord<F>,
        _: &mut ExecutionRecord<F>,
    ) -> RowMajorMatrix<F> {
        assert_eq!(
            std::any::TypeId::of::<F>(),
            std::any::TypeId::of::<BabyBear>(),
            "generate_trace only supports BabyBear field"
        );

        let mut rows = input
            .batch_fri_events
            .iter()
            .map(|event| {
                // SAFETY: `F` is `BabyBear` (asserted above), so this
                // reference transmute is between identical types.
                let bb_event = unsafe {
                    std::mem::transmute::<&BatchFRIEvent<F>, &BatchFRIEvent<BabyBear>>(event)
                };
                let mut row = [BabyBear::zero(); NUM_BATCH_FRI_COLS];
                let cols: &mut BatchFRICols<BabyBear> = row.as_mut_slice().borrow_mut();
                cols.acc = bb_event.ext_single.acc;
                cols.alpha_pow = bb_event.ext_vec.alpha_pow;
                cols.p_at_z = bb_event.ext_vec.p_at_z;
                cols.p_at_x = bb_event.base_vec.p_at_x;
                row
            })
            .collect_vec();

        // Pad the trace to a power of two.
        rows.resize(self.num_rows(input).unwrap(), [BabyBear::zero(); NUM_BATCH_FRI_COLS]);

        // Convert the trace to a row major matrix.
        let trace = RowMajorMatrix::new(
            // SAFETY: `F` is `BabyBear` (asserted above), so `Vec<BabyBear>`
            // and `Vec<F>` are the same type.
            unsafe {
                std::mem::transmute::<Vec<BabyBear>, Vec<F>>(
                    rows.into_iter().flatten().collect::<Vec<BabyBear>>(),
                )
            },
            NUM_BATCH_FRI_COLS,
        );

        #[cfg(debug_assertions)]
        eprintln!(
            "batch fri trace dims is width: {:?}, height: {:?}",
            trace.width(),
            trace.height()
        );

        trace
    }

    /// The chip is unconditionally included in the machine.
    fn included(&self, _record: &Self::Record) -> bool {
        true
    }
}
187
impl<const DEGREE: usize> BatchFRIChip<DEGREE> {
    /// Constrains one batch-FRI accumulation step per row.
    ///
    /// Together the constraints below enforce the running sum
    ///
    ///   acc_0 = alpha_pow_0 * (p_at_z_0 - p_at_x_0)
    ///   acc_i = acc_{i-1} + alpha_pow_i * (p_at_z_i - p_at_x_i)
    ///
    /// restarting the recurrence on the row after `is_end`, plus the memory
    /// interactions: reads for the three operands and a single write of the
    /// final accumulator on the `is_end` row.
    pub fn eval_batch_fri<AB: SP1RecursionAirBuilder>(
        &self,
        builder: &mut AB,
        local: &BatchFRICols<AB::Var>,
        next: &BatchFRICols<AB::Var>,
        local_prepr: &BatchFRIPreprocessedCols<AB::Var>,
        _next_prepr: &BatchFRIPreprocessedCols<AB::Var>,
    ) {
        // Constrain memory read for alpha_pow, p_at_z, and p_at_x.
        // Reads happen with multiplicity `is_real` (so padding rows are inert).
        builder.receive_block(local_prepr.alpha_pow_addr, local.alpha_pow, local_prepr.is_real);
        builder.receive_block(local_prepr.p_at_z_addr, local.p_at_z, local_prepr.is_real);
        builder.receive_single(local_prepr.p_at_x_addr, local.p_at_x, local_prepr.is_real);

        // Constrain memory write for the accumulator.
        // Note that we write with multiplicity 1, when `is_end` is true.
        builder.send_block(local_prepr.acc_addr, local.acc, local_prepr.is_end);

        // Constrain the accumulator value of the first row: the recurrence has
        // no previous term, so acc = alpha_pow * (p_at_z - p_at_x).
        builder.when_first_row().assert_ext_eq(
            local.acc.as_extension::<AB>(),
            local.alpha_pow.as_extension::<AB>() *
                (local.p_at_z.as_extension::<AB>() -
                    BinomialExtension::from_base(local.p_at_x.into())),
        );

        // Constrain the accumulator of the next row when the current row is the end of loop.
        // The next row starts a fresh instruction, so its accumulator restarts.
        builder.when_transition().when(local_prepr.is_end).assert_ext_eq(
            next.acc.as_extension::<AB>(),
            next.alpha_pow.as_extension::<AB>() *
                (next.p_at_z.as_extension::<AB>() -
                    BinomialExtension::from_base(next.p_at_x.into())),
        );

        // Constrain the accumulator of the next row when the current row is not the end of loop.
        // Mid-loop rows accumulate on top of the previous row's value.
        builder.when_transition().when_not(local_prepr.is_end).assert_ext_eq(
            next.acc.as_extension::<AB>(),
            local.acc.as_extension::<AB>() +
                next.alpha_pow.as_extension::<AB>() *
                    (next.p_at_z.as_extension::<AB>() -
                        BinomialExtension::from_base(next.p_at_x.into())),
        );
    }

    /// Memory accesses occur exactly on real (non-padding) rows.
    pub const fn do_memory_access<T: Copy>(local: &BatchFRIPreprocessedCols<T>) -> T {
        local.is_real
    }
}
236
237impl<AB, const DEGREE: usize> Air<AB> for BatchFRIChip<DEGREE>
238where
239    AB: SP1RecursionAirBuilder + PairBuilder,
240{
241    fn eval(&self, builder: &mut AB) {
242        let main = builder.main();
243        let (local, next) = (main.row_slice(0), main.row_slice(1));
244        let local: &BatchFRICols<AB::Var> = (*local).borrow();
245        let next: &BatchFRICols<AB::Var> = (*next).borrow();
246        let prepr = builder.preprocessed();
247        let (prepr_local, prepr_next) = (prepr.row_slice(0), prepr.row_slice(1));
248        let prepr_local: &BatchFRIPreprocessedCols<AB::Var> = (*prepr_local).borrow();
249        let prepr_next: &BatchFRIPreprocessedCols<AB::Var> = (*prepr_next).borrow();
250
251        // Dummy constraints to normalize to DEGREE.
252        let lhs = (0..DEGREE).map(|_| prepr_local.is_real.into()).product::<AB::Expr>();
253        let rhs = (0..DEGREE).map(|_| prepr_local.is_real.into()).product::<AB::Expr>();
254        builder.assert_eq(lhs, rhs);
255
256        self.eval_batch_fri::<AB>(builder, local, next, prepr_local, prepr_next);
257    }
258}
259
#[cfg(test)]
mod tests {
    use crate::{chips::test_fixtures, Instruction, RecursionProgram};
    use p3_baby_bear::BabyBear;
    use p3_field::AbstractField;
    use p3_matrix::dense::RowMajorMatrix;

    use super::*;

    const DEGREE: usize = 2;

    /// Reference implementation of `generate_trace`, written directly over
    /// `BabyBear` (no transmutes) so the chip's output can be checked
    /// against it.
    fn generate_trace_reference<const DEGREE: usize>(
        input: &ExecutionRecord<BabyBear>,
        _: &mut ExecutionRecord<BabyBear>,
    ) -> RowMajorMatrix<BabyBear> {
        type F = BabyBear;

        // One row per event, copying the four operand columns.
        let mut rows = input
            .batch_fri_events
            .iter()
            .map(|event| {
                let mut row = [F::zero(); NUM_BATCH_FRI_COLS];
                let cols: &mut BatchFRICols<F> = row.as_mut_slice().borrow_mut();
                cols.acc = event.ext_single.acc;
                cols.alpha_pow = event.ext_vec.alpha_pow;
                cols.p_at_z = event.ext_vec.p_at_z;
                cols.p_at_x = event.base_vec.p_at_x;
                row
            })
            .collect_vec();

        // Pad with zero rows to the same height the chip computes.
        rows.resize(
            BatchFRIChip::<DEGREE>.num_rows(input).unwrap(),
            [F::zero(); NUM_BATCH_FRI_COLS],
        );

        RowMajorMatrix::new(rows.into_iter().flatten().collect(), NUM_BATCH_FRI_COLS)
    }

    #[test]
    fn generate_trace() {
        let shard = test_fixtures::shard();
        let mut execution_record = test_fixtures::default_execution_record();
        let trace = BatchFRIChip::<DEGREE>.generate_trace(&shard, &mut execution_record);
        assert!(trace.height() >= test_fixtures::MIN_TEST_CASES);

        assert_eq!(trace, generate_trace_reference::<DEGREE>(&shard, &mut execution_record));
    }

    /// Reference implementation of `generate_preprocessed_trace`, populating
    /// the rows in pure Rust instead of calling the FFI routine.
    fn generate_preprocessed_trace_reference<const DEGREE: usize>(
        program: &RecursionProgram<BabyBear>,
    ) -> RowMajorMatrix<BabyBear> {
        type F = BabyBear;

        let mut rows: Vec<[F; NUM_BATCH_FRI_PREPROCESSED_COLS]> = Vec::new();
        program
            .inner
            .iter()
            .filter_map(|instruction| match instruction {
                Instruction::BatchFRI(instr) => Some(instr),
                _ => None,
            })
            .for_each(|instruction| {
                let BatchFRIInstr { base_vec_addrs, ext_single_addrs, ext_vec_addrs, acc_mult } =
                    instruction.as_ref();
                let len = ext_vec_addrs.p_at_z.len();
                let mut row_add = vec![[F::zero(); NUM_BATCH_FRI_PREPROCESSED_COLS]; len];
                debug_assert_eq!(*acc_mult, F::one());

                // `i` (renamed from the misleading `_i` — it IS used) selects
                // the per-element addresses and marks the loop's final row.
                row_add.iter_mut().enumerate().for_each(|(i, row)| {
                    let row: &mut BatchFRIPreprocessedCols<F> = row.as_mut_slice().borrow_mut();
                    row.is_real = F::one();
                    row.is_end = F::from_bool(i == len - 1);
                    row.acc_addr = ext_single_addrs.acc;
                    row.alpha_pow_addr = ext_vec_addrs.alpha_pow[i];
                    row.p_at_z_addr = ext_vec_addrs.p_at_z[i];
                    row.p_at_x_addr = base_vec_addrs.p_at_x[i];
                });
                rows.extend(row_add);
            });

        pad_rows_fixed(
            &mut rows,
            || [F::zero(); NUM_BATCH_FRI_PREPROCESSED_COLS],
            program.fixed_log2_rows(&BatchFRIChip::<DEGREE>),
        );

        RowMajorMatrix::new(rows.into_iter().flatten().collect(), NUM_BATCH_FRI_PREPROCESSED_COLS)
    }

    #[test]
    #[ignore = "Failing due to merge conflicts. Will be fixed shortly."]
    fn generate_preprocessed_trace() {
        let program = test_fixtures::program();
        let trace = BatchFRIChip::<DEGREE>.generate_preprocessed_trace(&program).unwrap();
        assert!(trace.height() >= test_fixtures::MIN_TEST_CASES);

        assert_eq!(trace, generate_preprocessed_trace_reference::<DEGREE>(&program));
    }
}
359}