sp1_recursion_core_v2/chips/fri_fold.rs

1#![allow(clippy::needless_range_loop)]
2
3use core::borrow::Borrow;
4use itertools::Itertools;
5use sp1_core_machine::utils::pad_rows_fixed;
6use sp1_stark::air::{BinomialExtension, MachineAir};
7use std::borrow::BorrowMut;
8use tracing::instrument;
9
10use p3_air::{Air, AirBuilder, BaseAir, PairBuilder};
11use p3_field::PrimeField32;
12use p3_matrix::{dense::RowMajorMatrix, Matrix};
13use sp1_stark::air::{BaseAirBuilder, ExtensionAirBuilder};
14
15use sp1_derive::AlignedBorrow;
16use sp1_recursion_core::air::Block;
17
18use crate::{
19    builder::SP1RecursionAirBuilder,
20    runtime::{Instruction, RecursionProgram},
21    ExecutionRecord, FriFoldInstr,
22};
23
24use super::mem::MemoryAccessCols;
25
26pub const NUM_FRI_FOLD_COLS: usize = core::mem::size_of::<FriFoldCols<u8>>();
27pub const NUM_FRI_FOLD_PREPROCESSED_COLS: usize =
28    core::mem::size_of::<FriFoldPreprocessedCols<u8>>();
29
/// Chip that proves `FriFold` recursion instructions.
///
/// The preprocessed trace holds one row per element of each instruction's vector operands.
/// `DEGREE` is the degree that `Air::eval` normalizes the chip's constraints to via a dummy
/// constraint.
pub struct FriFoldChip<const DEGREE: usize> {
    /// Optional fixed log2 height forwarded to `pad_rows_fixed` when padding traces.
    pub fixed_log2_rows: Option<usize>,
    /// Whether the generated traces are padded at all.
    pub pad: bool,
}

impl<const DEGREE: usize> Default for FriFoldChip<DEGREE> {
    /// Padding enabled, with no fixed trace height.
    fn default() -> Self {
        let pad = true;
        let fixed_log2_rows = None;
        Self { pad, fixed_log2_rows }
    }
}
40
/// The preprocessed columns for a FRI fold invocation.
///
/// One row is generated per element of the instruction's vector operands. Each
/// `MemoryAccessCols` holds the address and multiplicity of one memory interaction: a
/// negative multiplicity is used for reads and a caller-supplied one for writes.
/// The struct is `repr(C)` and is borrowed directly from a flat row slice, so the field
/// order below *is* the column layout; do not reorder fields.
#[derive(AlignedBorrow, Debug, Clone, Copy)]
#[repr(C)]
pub struct FriFoldPreprocessedCols<T: Copy> {
    /// 1 on the first row of each FRI fold invocation, 0 on every other row.
    pub is_first: T,

    // Memory accesses for the single fields.
    // Their multiplicity is -1 (a read) on the first row of an invocation and 0 elsewhere,
    // since these values are shared by all rows of the invocation.
    pub z_mem: MemoryAccessCols<T>,
    pub alpha_mem: MemoryAccessCols<T>,
    pub x_mem: MemoryAccessCols<T>,

    // Memory accesses for the vector field inputs (read once per row, multiplicity -1).
    pub alpha_pow_input_mem: MemoryAccessCols<T>,
    pub ro_input_mem: MemoryAccessCols<T>,
    pub p_at_x_mem: MemoryAccessCols<T>,
    pub p_at_z_mem: MemoryAccessCols<T>,

    // Memory accesses for the vector field outputs (written with the instruction's
    // per-element multiplicities).
    pub ro_output_mem: MemoryAccessCols<T>,
    pub alpha_pow_output_mem: MemoryAccessCols<T>,

    /// 1 for rows generated from an instruction, 0 for padding rows.
    pub is_real: T,
}
64
/// The main-trace columns for one FRI fold row (one element of the vector operands).
///
/// The struct is `repr(C)` and is borrowed directly from a flat row slice, so the field
/// order below *is* the column layout; do not reorder fields.
#[derive(AlignedBorrow, Debug, Clone, Copy)]
#[repr(C)]
pub struct FriFoldCols<T: Copy> {
    /// Extension-field value `z`, constrained equal across the rows of one invocation.
    pub z: Block<T>,
    /// Extension-field value `alpha`, constrained equal across the rows of one invocation.
    pub alpha: Block<T>,
    /// Base-field value `x`, constrained equal across the rows of one invocation.
    pub x: T,

    /// This row's element of the `mat_opening` input vector.
    pub p_at_x: Block<T>,
    /// This row's element of the `ps_at_z` input vector.
    pub p_at_z: Block<T>,
    /// Incoming power of alpha for this row.
    pub alpha_pow_input: Block<T>,
    /// Incoming accumulator for this row.
    pub ro_input: Block<T>,

    /// Outgoing power of alpha: `alpha_pow_input * alpha`.
    pub alpha_pow_output: Block<T>,
    /// Outgoing accumulator: `ro_input + alpha_pow_input * (p_at_x - p_at_z) / (x - z)`.
    pub ro_output: Block<T>,
}
80
81impl<F, const DEGREE: usize> BaseAir<F> for FriFoldChip<DEGREE> {
82    fn width(&self) -> usize {
83        NUM_FRI_FOLD_COLS
84    }
85}
86
87impl<F: PrimeField32, const DEGREE: usize> MachineAir<F> for FriFoldChip<DEGREE> {
88    type Record = ExecutionRecord<F>;
89
90    type Program = RecursionProgram<F>;
91
92    fn name(&self) -> String {
93        "FriFold".to_string()
94    }
95
96    fn generate_dependencies(&self, _: &Self::Record, _: &mut Self::Record) {
97        // This is a no-op.
98    }
99
100    fn preprocessed_width(&self) -> usize {
101        NUM_FRI_FOLD_PREPROCESSED_COLS
102    }
103    fn generate_preprocessed_trace(&self, program: &Self::Program) -> Option<RowMajorMatrix<F>> {
104        let mut rows: Vec<[F; NUM_FRI_FOLD_PREPROCESSED_COLS]> = Vec::new();
105        program
106            .instructions
107            .iter()
108            .filter_map(|instruction| {
109                if let Instruction::FriFold(instr) = instruction {
110                    Some(instr)
111                } else {
112                    None
113                }
114            })
115            .for_each(|instruction| {
116                let FriFoldInstr {
117                    base_single_addrs,
118                    ext_single_addrs,
119                    ext_vec_addrs,
120                    alpha_pow_mults,
121                    ro_mults,
122                } = instruction.as_ref();
123                let mut row_add =
124                    vec![[F::zero(); NUM_FRI_FOLD_PREPROCESSED_COLS]; ext_vec_addrs.ps_at_z.len()];
125
126                row_add.iter_mut().enumerate().for_each(|(i, row)| {
127                    let row: &mut FriFoldPreprocessedCols<F> = row.as_mut_slice().borrow_mut();
128                    row.is_first = F::from_bool(i == 0);
129
130                    // Only need to read z, x, and alpha on the first iteration, hence the
131                    // multiplicities are i==0.
132                    row.z_mem =
133                        MemoryAccessCols { addr: ext_single_addrs.z, mult: -F::from_bool(i == 0) };
134                    row.x_mem =
135                        MemoryAccessCols { addr: base_single_addrs.x, mult: -F::from_bool(i == 0) };
136                    row.alpha_mem = MemoryAccessCols {
137                        addr: ext_single_addrs.alpha,
138                        mult: -F::from_bool(i == 0),
139                    };
140
141                    // Read the memory for the input vectors.
142                    row.alpha_pow_input_mem = MemoryAccessCols {
143                        addr: ext_vec_addrs.alpha_pow_input[i],
144                        mult: F::neg_one(),
145                    };
146                    row.ro_input_mem =
147                        MemoryAccessCols { addr: ext_vec_addrs.ro_input[i], mult: F::neg_one() };
148                    row.p_at_z_mem =
149                        MemoryAccessCols { addr: ext_vec_addrs.ps_at_z[i], mult: F::neg_one() };
150                    row.p_at_x_mem =
151                        MemoryAccessCols { addr: ext_vec_addrs.mat_opening[i], mult: F::neg_one() };
152
153                    // Write the memory for the output vectors.
154                    row.alpha_pow_output_mem = MemoryAccessCols {
155                        addr: ext_vec_addrs.alpha_pow_output[i],
156                        mult: alpha_pow_mults[i],
157                    };
158                    row.ro_output_mem =
159                        MemoryAccessCols { addr: ext_vec_addrs.ro_output[i], mult: ro_mults[i] };
160
161                    row.is_real = F::one();
162                });
163                rows.extend(row_add);
164            });
165
166        // Pad the trace to a power of two.
167        if self.pad {
168            pad_rows_fixed(
169                &mut rows,
170                || [F::zero(); NUM_FRI_FOLD_PREPROCESSED_COLS],
171                self.fixed_log2_rows,
172            );
173        }
174
175        let trace = RowMajorMatrix::new(
176            rows.into_iter().flatten().collect(),
177            NUM_FRI_FOLD_PREPROCESSED_COLS,
178        );
179        Some(trace)
180    }
181
182    #[instrument(name = "generate fri fold trace", level = "debug", skip_all, fields(rows = input.fri_fold_events.len()))]
183    fn generate_trace(
184        &self,
185        input: &ExecutionRecord<F>,
186        _: &mut ExecutionRecord<F>,
187    ) -> RowMajorMatrix<F> {
188        let mut rows = input
189            .fri_fold_events
190            .iter()
191            .map(|event| {
192                let mut row = [F::zero(); NUM_FRI_FOLD_COLS];
193
194                let cols: &mut FriFoldCols<F> = row.as_mut_slice().borrow_mut();
195
196                cols.x = event.base_single.x;
197                cols.z = event.ext_single.z;
198                cols.alpha = event.ext_single.alpha;
199
200                cols.p_at_z = event.ext_vec.ps_at_z;
201                cols.p_at_x = event.ext_vec.mat_opening;
202                cols.alpha_pow_input = event.ext_vec.alpha_pow_input;
203                cols.ro_input = event.ext_vec.ro_input;
204
205                cols.alpha_pow_output = event.ext_vec.alpha_pow_output;
206                cols.ro_output = event.ext_vec.ro_output;
207
208                row
209            })
210            .collect_vec();
211
212        // Pad the trace to a power of two.
213        if self.pad {
214            pad_rows_fixed(&mut rows, || [F::zero(); NUM_FRI_FOLD_COLS], self.fixed_log2_rows);
215        }
216
217        // Convert the trace to a row major matrix.
218        let trace = RowMajorMatrix::new(rows.into_iter().flatten().collect(), NUM_FRI_FOLD_COLS);
219
220        #[cfg(debug_assertions)]
221        println!("fri fold trace dims is width: {:?}, height: {:?}", trace.width(), trace.height());
222
223        trace
224    }
225
226    fn included(&self, _record: &Self::Record) -> bool {
227        true
228    }
229}
230
231impl<const DEGREE: usize> FriFoldChip<DEGREE> {
232    pub fn eval_fri_fold<AB: SP1RecursionAirBuilder>(
233        &self,
234        builder: &mut AB,
235        local: &FriFoldCols<AB::Var>,
236        next: &FriFoldCols<AB::Var>,
237        local_prepr: &FriFoldPreprocessedCols<AB::Var>,
238        next_prepr: &FriFoldPreprocessedCols<AB::Var>,
239    ) {
240        // Constrain mem read for x.  Read at the first fri fold row.
241        builder.send_single(local_prepr.x_mem.addr, local.x, local_prepr.x_mem.mult);
242
243        // Ensure that the x value is the same for all rows within a fri fold invocation.
244        builder
245            .when_transition()
246            .when(next_prepr.is_real)
247            .when_not(next_prepr.is_first)
248            .assert_eq(local.x, next.x);
249
250        // Constrain mem read for z.  Read at the first fri fold row.
251        builder.send_block(local_prepr.z_mem.addr, local.z, local_prepr.z_mem.mult);
252
253        // Ensure that the z value is the same for all rows within a fri fold invocation.
254        builder
255            .when_transition()
256            .when(next_prepr.is_real)
257            .when_not(next_prepr.is_first)
258            .assert_ext_eq(local.z.as_extension::<AB>(), next.z.as_extension::<AB>());
259
260        // Constrain mem read for alpha.  Read at the first fri fold row.
261        builder.send_block(local_prepr.alpha_mem.addr, local.alpha, local_prepr.alpha_mem.mult);
262
263        // Ensure that the alpha value is the same for all rows within a fri fold invocation.
264        builder
265            .when_transition()
266            .when(next_prepr.is_real)
267            .when_not(next_prepr.is_first)
268            .assert_ext_eq(local.alpha.as_extension::<AB>(), next.alpha.as_extension::<AB>());
269
270        // Constrain read for alpha_pow_input.
271        builder.send_block(
272            local_prepr.alpha_pow_input_mem.addr,
273            local.alpha_pow_input,
274            local_prepr.alpha_pow_input_mem.mult,
275        );
276
277        // Constrain read for ro_input.
278        builder.send_block(
279            local_prepr.ro_input_mem.addr,
280            local.ro_input,
281            local_prepr.ro_input_mem.mult,
282        );
283
284        // Constrain read for p_at_z.
285        builder.send_block(local_prepr.p_at_z_mem.addr, local.p_at_z, local_prepr.p_at_z_mem.mult);
286
287        // Constrain read for p_at_x.
288        builder.send_block(local_prepr.p_at_x_mem.addr, local.p_at_x, local_prepr.p_at_x_mem.mult);
289
290        // Constrain write for alpha_pow_output.
291        builder.send_block(
292            local_prepr.alpha_pow_output_mem.addr,
293            local.alpha_pow_output,
294            local_prepr.alpha_pow_output_mem.mult,
295        );
296
297        // Constrain write for ro_output.
298        builder.send_block(
299            local_prepr.ro_output_mem.addr,
300            local.ro_output,
301            local_prepr.ro_output_mem.mult,
302        );
303
304        // 1. Constrain new_value = old_value * alpha.
305        let alpha = local.alpha.as_extension::<AB>();
306        let old_alpha_pow = local.alpha_pow_input.as_extension::<AB>();
307        let new_alpha_pow = local.alpha_pow_output.as_extension::<AB>();
308        builder.assert_ext_eq(old_alpha_pow.clone() * alpha, new_alpha_pow.clone());
309
310        // 2. Constrain new_value = old_alpha_pow * quotient + old_ro,
311        // where quotient = (p_at_x - p_at_z) / (x - z)
312        // <=> (new_ro - old_ro) * (z - x) = old_alpha_pow * (p_at_x - p_at_z)
313        let p_at_z = local.p_at_z.as_extension::<AB>();
314        let p_at_x = local.p_at_x.as_extension::<AB>();
315        let z = local.z.as_extension::<AB>();
316        let x = local.x.into();
317        let old_ro = local.ro_input.as_extension::<AB>();
318        let new_ro = local.ro_output.as_extension::<AB>();
319        builder.assert_ext_eq(
320            (new_ro.clone() - old_ro) * (BinomialExtension::from_base(x) - z),
321            (p_at_x - p_at_z) * old_alpha_pow,
322        );
323    }
324
325    pub const fn do_memory_access<T: Copy>(local: &FriFoldPreprocessedCols<T>) -> T {
326        local.is_real
327    }
328}
329
330impl<AB, const DEGREE: usize> Air<AB> for FriFoldChip<DEGREE>
331where
332    AB: SP1RecursionAirBuilder + PairBuilder,
333{
334    fn eval(&self, builder: &mut AB) {
335        let main = builder.main();
336        let (local, next) = (main.row_slice(0), main.row_slice(1));
337        let local: &FriFoldCols<AB::Var> = (*local).borrow();
338        let next: &FriFoldCols<AB::Var> = (*next).borrow();
339        let prepr = builder.preprocessed();
340        let (prepr_local, prepr_next) = (prepr.row_slice(0), prepr.row_slice(1));
341        let prepr_local: &FriFoldPreprocessedCols<AB::Var> = (*prepr_local).borrow();
342        let prepr_next: &FriFoldPreprocessedCols<AB::Var> = (*prepr_next).borrow();
343
344        // Dummy constraints to normalize to DEGREE.
345        let lhs = (0..DEGREE).map(|_| prepr_local.is_real.into()).product::<AB::Expr>();
346        let rhs = (0..DEGREE).map(|_| prepr_local.is_real.into()).product::<AB::Expr>();
347        builder.assert_eq(lhs, rhs);
348
349        self.eval_fri_fold::<AB>(builder, local, next, prepr_local, prepr_next);
350    }
351}
352
#[cfg(test)]
mod tests {
    use core::borrow::Borrow;
    use p3_field::AbstractExtensionField;
    use rand::{rngs::StdRng, Rng, SeedableRng};
    use sp1_core_machine::utils::setup_logger;
    use sp1_recursion_core::{air::Block, stark::config::BabyBearPoseidon2Outer};
    use sp1_stark::{air::MachineAir, StarkGenericConfig};
    use std::mem::size_of;

    use p3_baby_bear::BabyBear;
    use p3_field::AbstractField;
    use p3_matrix::{dense::RowMajorMatrix, Matrix};

    use crate::{
        chips::fri_fold::{FriFoldChip, FriFoldCols, NUM_FRI_FOLD_COLS},
        machine::tests::run_recursion_test_machines,
        runtime::{instruction as instr, ExecutionRecord},
        FriFoldBaseIo, FriFoldEvent, FriFoldExtSingleIo, FriFoldExtVecIo, Instruction,
        MemAccessKind, RecursionProgram,
    };

    /// Runs a program of FRI fold instructions, with vector lengths 2..17, through the
    /// recursion test machines. Each instruction's outputs are checked against values
    /// computed natively in the extension field.
    #[test]
    fn prove_babybear_circuit_fri_fold() {
        setup_logger();
        type SC = BabyBearPoseidon2Outer;
        type F = <SC as StarkGenericConfig>::Val;
        type EF = <SC as StarkGenericConfig>::Challenge;

        let mut rng = StdRng::seed_from_u64(0xDEADBEEF);
        let mut random_felt = move || -> F { F::from_canonical_u32(rng.gen_range(0..1 << 16)) };
        let mut rng = StdRng::seed_from_u64(0xDEADBEEF);
        // Draw each limb independently so the extension-field arithmetic is exercised on
        // general elements, not only on ones whose four limbs are identical.
        let mut random_block = move || {
            Block::from(core::array::from_fn::<F, 4, _>(|_| {
                F::from_canonical_u32(rng.gen_range(0..1 << 16))
            }))
        };
        let mut addr = 0;

        let num_ext_vecs: u32 = size_of::<FriFoldExtVecIo<u8>>() as u32;
        let num_singles: u32 =
            size_of::<FriFoldBaseIo<u8>>() as u32 + size_of::<FriFoldExtSingleIo<u8>>() as u32;

        let instructions = (2..17)
            .flat_map(|i: u32| {
                let alloc_size = i * (num_ext_vecs + 2) + num_singles;

                // Allocate the memory for a FRI fold instruction. Here, i is the lengths
                // of the vectors for the vector fields of the instruction.
                let mat_opening_a = (0..i).map(|x| x + addr).collect::<Vec<_>>();
                let ps_at_z_a = (0..i).map(|x| x + i + addr).collect::<Vec<_>>();

                let alpha_pow_input_a = (0..i).map(|x: u32| x + addr + 2 * i).collect::<Vec<_>>();
                let ro_input_a = (0..i).map(|x: u32| x + addr + 3 * i).collect::<Vec<_>>();

                let alpha_pow_output_a = (0..i).map(|x: u32| x + addr + 4 * i).collect::<Vec<_>>();
                let ro_output_a = (0..i).map(|x: u32| x + addr + 5 * i).collect::<Vec<_>>();

                let x_a = addr + 6 * i;
                let z_a = addr + 6 * i + 1;
                let alpha_a = addr + 6 * i + 2;

                addr += alloc_size;

                // Generate random values for the inputs.
                let x = random_felt();
                let z = random_block();
                let alpha = random_block();

                let alpha_pow_input = (0..i).map(|_| random_block()).collect::<Vec<_>>();
                let ro_input = (0..i).map(|_| random_block()).collect::<Vec<_>>();

                let ps_at_z = (0..i).map(|_| random_block()).collect::<Vec<_>>();
                let mat_opening = (0..i).map(|_| random_block()).collect::<Vec<_>>();

                // Compute the outputs from the inputs.
                let alpha_pow_output = (0..i)
                    .map(|i| alpha_pow_input[i as usize].ext::<EF>() * alpha.ext::<EF>())
                    .collect::<Vec<EF>>();
                let ro_output = (0..i)
                    .map(|i| {
                        let i = i as usize;
                        ro_input[i].ext::<EF>()
                            + alpha_pow_input[i].ext::<EF>()
                                * (-ps_at_z[i].ext::<EF>() + mat_opening[i].ext::<EF>())
                                / (-z.ext::<EF>() + x)
                    })
                    .collect::<Vec<EF>>();

                // Write the inputs to memory.
                let mut instructions = vec![instr::mem_single(MemAccessKind::Write, 1, x_a, x)];

                instructions.push(instr::mem_block(MemAccessKind::Write, 1, z_a, z));

                instructions.push(instr::mem_block(MemAccessKind::Write, 1, alpha_a, alpha));

                (0..i).for_each(|j_32| {
                    let j = j_32 as usize;
                    instructions.push(instr::mem_block(
                        MemAccessKind::Write,
                        1,
                        mat_opening_a[j],
                        mat_opening[j],
                    ));
                    instructions.push(instr::mem_block(
                        MemAccessKind::Write,
                        1,
                        ps_at_z_a[j],
                        ps_at_z[j],
                    ));

                    instructions.push(instr::mem_block(
                        MemAccessKind::Write,
                        1,
                        alpha_pow_input_a[j],
                        alpha_pow_input[j],
                    ));
                    instructions.push(instr::mem_block(
                        MemAccessKind::Write,
                        1,
                        ro_input_a[j],
                        ro_input[j],
                    ));
                });

                // Generate the FRI fold instruction.
                instructions.push(instr::fri_fold(
                    z_a,
                    alpha_a,
                    x_a,
                    mat_opening_a.clone(),
                    ps_at_z_a.clone(),
                    alpha_pow_input_a.clone(),
                    ro_input_a.clone(),
                    alpha_pow_output_a.clone(),
                    ro_output_a.clone(),
                    vec![1; i as usize],
                    vec![1; i as usize],
                ));

                // Read all the outputs.
                (0..i).for_each(|j| {
                    let j = j as usize;
                    instructions.push(instr::mem_block(
                        MemAccessKind::Read,
                        1,
                        alpha_pow_output_a[j],
                        Block::from(alpha_pow_output[j].as_base_slice()),
                    ));
                    instructions.push(instr::mem_block(
                        MemAccessKind::Read,
                        1,
                        ro_output_a[j],
                        Block::from(ro_output[j].as_base_slice()),
                    ));
                });

                instructions
            })
            .collect::<Vec<Instruction<F>>>();

        let program = RecursionProgram { instructions, ..Default::default() };

        run_recursion_test_machines(program);
    }

    /// Generates a main trace from random events and checks its shape and contents.
    #[test]
    fn generate_fri_fold_circuit_trace() {
        type F = BabyBear;
        const NUM_EVENTS: usize = 17;

        let mut rng = StdRng::seed_from_u64(0xDEADBEEF);
        let mut rng2 = StdRng::seed_from_u64(0xDEADBEEF);
        let mut random_felt = move || -> F { F::from_canonical_u32(rng.gen_range(0..1 << 16)) };
        let mut random_block =
            move || Block::from(core::array::from_fn::<F, 4, _>(|_| random_felt()));

        let shard = ExecutionRecord {
            fri_fold_events: (0..NUM_EVENTS)
                .map(|_| FriFoldEvent {
                    base_single: FriFoldBaseIo {
                        x: F::from_canonical_u32(rng2.gen_range(0..1 << 16)),
                    },
                    ext_single: FriFoldExtSingleIo { z: random_block(), alpha: random_block() },
                    ext_vec: FriFoldExtVecIo {
                        mat_opening: random_block(),
                        ps_at_z: random_block(),
                        alpha_pow_input: random_block(),
                        ro_input: random_block(),
                        alpha_pow_output: random_block(),
                        ro_output: random_block(),
                    },
                })
                .collect(),
            ..Default::default()
        };
        let chip = FriFoldChip::<3>::default();
        let trace: RowMajorMatrix<F> = chip.generate_trace(&shard, &mut ExecutionRecord::default());

        // One column per scalar of `FriFoldCols`; height padded to a power of two.
        assert_eq!(trace.width(), NUM_FRI_FOLD_COLS);
        assert!(trace.height().is_power_of_two());
        assert!(trace.height() >= NUM_EVENTS);

        // Every real row mirrors the event it was generated from.
        for (i, event) in shard.fri_fold_events.iter().enumerate() {
            let start = i * NUM_FRI_FOLD_COLS;
            let cols: &FriFoldCols<F> = trace.values[start..start + NUM_FRI_FOLD_COLS].borrow();
            assert_eq!(cols.x, event.base_single.x);
            assert_eq!(cols.z, event.ext_single.z);
            assert_eq!(cols.alpha, event.ext_single.alpha);
            assert_eq!(cols.p_at_x, event.ext_vec.mat_opening);
            assert_eq!(cols.p_at_z, event.ext_vec.ps_at_z);
            assert_eq!(cols.alpha_pow_input, event.ext_vec.alpha_pow_input);
            assert_eq!(cols.ro_input, event.ext_vec.ro_input);
            assert_eq!(cols.alpha_pow_output, event.ext_vec.alpha_pow_output);
            assert_eq!(cols.ro_output, event.ext_vec.ro_output);
        }

        // Padding rows are all zero.
        assert!(trace.values[NUM_EVENTS * NUM_FRI_FOLD_COLS..].iter().all(|v| *v == F::zero()));
    }
}