1#![allow(clippy::needless_range_loop)]
2
3use core::borrow::Borrow;
4use itertools::Itertools;
5use sp1_core_machine::utils::pad_rows_fixed;
6use sp1_stark::air::{BinomialExtension, MachineAir};
7use std::borrow::BorrowMut;
8use tracing::instrument;
9
10use p3_air::{Air, AirBuilder, BaseAir, PairBuilder};
11use p3_field::PrimeField32;
12use p3_matrix::{dense::RowMajorMatrix, Matrix};
13use sp1_stark::air::{BaseAirBuilder, ExtensionAirBuilder};
14
15use sp1_derive::AlignedBorrow;
16use sp1_recursion_core::air::Block;
17
18use crate::{
19 builder::SP1RecursionAirBuilder,
20 runtime::{Instruction, RecursionProgram},
21 ExecutionRecord, FriFoldInstr,
22};
23
24use super::mem::MemoryAccessCols;
25
/// Width of the main trace, i.e. the length of one flattened [`FriFoldCols`] row.
///
/// Measured as the size in bytes of `FriFoldCols<u8>`, so it counts one unit per `T`-typed
/// column of the struct.
pub const NUM_FRI_FOLD_COLS: usize = core::mem::size_of::<FriFoldCols<u8>>();
/// Width of the preprocessed trace, i.e. the length of one flattened
/// [`FriFoldPreprocessedCols`] row, measured the same way as [`NUM_FRI_FOLD_COLS`].
pub const NUM_FRI_FOLD_PREPROCESSED_COLS: usize =
    core::mem::size_of::<FriFoldPreprocessedCols<u8>>();
29
/// Chip that handles the `FriFold` instructions of a recursion program.
///
/// `DEGREE` is used by the chip's `Air` implementation, which emits one constraint built from
/// a `DEGREE`-fold product.
pub struct FriFoldChip<const DEGREE: usize> {
    /// Forwarded to `pad_rows_fixed` when traces are padded.
    ///
    /// NOTE(review): presumably the log2 of the row count to pad to, with `None` leaving the
    /// choice to `pad_rows_fixed` — confirm against that helper.
    pub fixed_log2_rows: Option<usize>,
    /// When `true`, both the main and the preprocessed trace are padded with all-zero rows via
    /// `pad_rows_fixed`; when `false`, the rows are emitted as generated.
    pub pad: bool,
}
34
35impl<const DEGREE: usize> Default for FriFoldChip<DEGREE> {
36 fn default() -> Self {
37 Self { fixed_log2_rows: None, pad: true }
38 }
39}
40
/// Preprocessed (program-determined) columns of one FRI fold row.
///
/// Every `FriFold` instruction contributes one row per element of its vector operands. Rows are
/// stored as flat arrays and reinterpreted as this struct, so the `#[repr(C)]` field order is
/// the column order and must not be changed.
#[derive(AlignedBorrow, Debug, Clone, Copy)]
#[repr(C)]
pub struct FriFoldPreprocessedCols<T: Copy> {
    /// 1 on the first row of an instruction, 0 on every other row.
    pub is_first: T,

    // Operands shared by all rows of an instruction. Their multiplicity is -1 on the
    // instruction's first row and 0 on the remaining rows.
    /// Address and multiplicity of the `z` operand.
    pub z_mem: MemoryAccessCols<T>,
    /// Address and multiplicity of the `alpha` operand.
    pub alpha_mem: MemoryAccessCols<T>,
    /// Address and multiplicity of the `x` operand.
    pub x_mem: MemoryAccessCols<T>,

    // Per-row vector inputs. Their multiplicity is -1 on every real row.
    /// Address and multiplicity of this row's `alpha_pow_input` element.
    pub alpha_pow_input_mem: MemoryAccessCols<T>,
    /// Address and multiplicity of this row's `ro_input` element.
    pub ro_input_mem: MemoryAccessCols<T>,
    /// Address and multiplicity of this row's `mat_opening` element (the value called `p_at_x`).
    pub p_at_x_mem: MemoryAccessCols<T>,
    /// Address and multiplicity of this row's `ps_at_z` element.
    pub p_at_z_mem: MemoryAccessCols<T>,

    // Per-row vector outputs. Their multiplicities are taken from the instruction.
    /// Address of this row's `ro_output` element, with multiplicity from `ro_mults`.
    pub ro_output_mem: MemoryAccessCols<T>,
    /// Address of this row's `alpha_pow_output` element, with multiplicity from
    /// `alpha_pow_mults`.
    pub alpha_pow_output_mem: MemoryAccessCols<T>,

    /// 1 on rows generated from an instruction, 0 on padding rows.
    pub is_real: T,
}
64
/// Main-trace columns of one FRI fold row: the values read from and written to memory.
///
/// Rows are stored as flat arrays and reinterpreted as this struct, so the `#[repr(C)]` field
/// order is the column order and must not be changed.
#[derive(AlignedBorrow, Debug, Clone, Copy)]
#[repr(C)]
pub struct FriFoldCols<T: Copy> {
    /// Extension-field operand `z`; constrained to stay the same across rows of an instruction.
    pub z: Block<T>,
    /// Extension-field operand `alpha`; constrained to stay the same across rows of an
    /// instruction.
    pub alpha: Block<T>,
    /// Base-field operand `x`; constrained to stay the same across rows of an instruction.
    pub x: T,

    /// This row's `mat_opening` value.
    pub p_at_x: Block<T>,
    /// This row's `ps_at_z` value.
    pub p_at_z: Block<T>,
    /// This row's input power of `alpha`.
    pub alpha_pow_input: Block<T>,
    /// This row's input reduced opening.
    pub ro_input: Block<T>,

    /// Constrained to equal `alpha_pow_input * alpha`.
    pub alpha_pow_output: Block<T>,
    /// Constrained so that
    /// `(ro_output - ro_input) * (x - z) == (p_at_x - p_at_z) * alpha_pow_input`.
    pub ro_output: Block<T>,
}
80
81impl<F, const DEGREE: usize> BaseAir<F> for FriFoldChip<DEGREE> {
82 fn width(&self) -> usize {
83 NUM_FRI_FOLD_COLS
84 }
85}
86
87impl<F: PrimeField32, const DEGREE: usize> MachineAir<F> for FriFoldChip<DEGREE> {
88 type Record = ExecutionRecord<F>;
89
90 type Program = RecursionProgram<F>;
91
92 fn name(&self) -> String {
93 "FriFold".to_string()
94 }
95
96 fn generate_dependencies(&self, _: &Self::Record, _: &mut Self::Record) {
97 }
99
100 fn preprocessed_width(&self) -> usize {
101 NUM_FRI_FOLD_PREPROCESSED_COLS
102 }
103 fn generate_preprocessed_trace(&self, program: &Self::Program) -> Option<RowMajorMatrix<F>> {
104 let mut rows: Vec<[F; NUM_FRI_FOLD_PREPROCESSED_COLS]> = Vec::new();
105 program
106 .instructions
107 .iter()
108 .filter_map(|instruction| {
109 if let Instruction::FriFold(instr) = instruction {
110 Some(instr)
111 } else {
112 None
113 }
114 })
115 .for_each(|instruction| {
116 let FriFoldInstr {
117 base_single_addrs,
118 ext_single_addrs,
119 ext_vec_addrs,
120 alpha_pow_mults,
121 ro_mults,
122 } = instruction.as_ref();
123 let mut row_add =
124 vec![[F::zero(); NUM_FRI_FOLD_PREPROCESSED_COLS]; ext_vec_addrs.ps_at_z.len()];
125
126 row_add.iter_mut().enumerate().for_each(|(i, row)| {
127 let row: &mut FriFoldPreprocessedCols<F> = row.as_mut_slice().borrow_mut();
128 row.is_first = F::from_bool(i == 0);
129
130 row.z_mem =
133 MemoryAccessCols { addr: ext_single_addrs.z, mult: -F::from_bool(i == 0) };
134 row.x_mem =
135 MemoryAccessCols { addr: base_single_addrs.x, mult: -F::from_bool(i == 0) };
136 row.alpha_mem = MemoryAccessCols {
137 addr: ext_single_addrs.alpha,
138 mult: -F::from_bool(i == 0),
139 };
140
141 row.alpha_pow_input_mem = MemoryAccessCols {
143 addr: ext_vec_addrs.alpha_pow_input[i],
144 mult: F::neg_one(),
145 };
146 row.ro_input_mem =
147 MemoryAccessCols { addr: ext_vec_addrs.ro_input[i], mult: F::neg_one() };
148 row.p_at_z_mem =
149 MemoryAccessCols { addr: ext_vec_addrs.ps_at_z[i], mult: F::neg_one() };
150 row.p_at_x_mem =
151 MemoryAccessCols { addr: ext_vec_addrs.mat_opening[i], mult: F::neg_one() };
152
153 row.alpha_pow_output_mem = MemoryAccessCols {
155 addr: ext_vec_addrs.alpha_pow_output[i],
156 mult: alpha_pow_mults[i],
157 };
158 row.ro_output_mem =
159 MemoryAccessCols { addr: ext_vec_addrs.ro_output[i], mult: ro_mults[i] };
160
161 row.is_real = F::one();
162 });
163 rows.extend(row_add);
164 });
165
166 if self.pad {
168 pad_rows_fixed(
169 &mut rows,
170 || [F::zero(); NUM_FRI_FOLD_PREPROCESSED_COLS],
171 self.fixed_log2_rows,
172 );
173 }
174
175 let trace = RowMajorMatrix::new(
176 rows.into_iter().flatten().collect(),
177 NUM_FRI_FOLD_PREPROCESSED_COLS,
178 );
179 Some(trace)
180 }
181
182 #[instrument(name = "generate fri fold trace", level = "debug", skip_all, fields(rows = input.fri_fold_events.len()))]
183 fn generate_trace(
184 &self,
185 input: &ExecutionRecord<F>,
186 _: &mut ExecutionRecord<F>,
187 ) -> RowMajorMatrix<F> {
188 let mut rows = input
189 .fri_fold_events
190 .iter()
191 .map(|event| {
192 let mut row = [F::zero(); NUM_FRI_FOLD_COLS];
193
194 let cols: &mut FriFoldCols<F> = row.as_mut_slice().borrow_mut();
195
196 cols.x = event.base_single.x;
197 cols.z = event.ext_single.z;
198 cols.alpha = event.ext_single.alpha;
199
200 cols.p_at_z = event.ext_vec.ps_at_z;
201 cols.p_at_x = event.ext_vec.mat_opening;
202 cols.alpha_pow_input = event.ext_vec.alpha_pow_input;
203 cols.ro_input = event.ext_vec.ro_input;
204
205 cols.alpha_pow_output = event.ext_vec.alpha_pow_output;
206 cols.ro_output = event.ext_vec.ro_output;
207
208 row
209 })
210 .collect_vec();
211
212 if self.pad {
214 pad_rows_fixed(&mut rows, || [F::zero(); NUM_FRI_FOLD_COLS], self.fixed_log2_rows);
215 }
216
217 let trace = RowMajorMatrix::new(rows.into_iter().flatten().collect(), NUM_FRI_FOLD_COLS);
219
220 #[cfg(debug_assertions)]
221 println!("fri fold trace dims is width: {:?}, height: {:?}", trace.width(), trace.height());
222
223 trace
224 }
225
226 fn included(&self, _record: &Self::Record) -> bool {
227 true
228 }
229}
230
231impl<const DEGREE: usize> FriFoldChip<DEGREE> {
232 pub fn eval_fri_fold<AB: SP1RecursionAirBuilder>(
233 &self,
234 builder: &mut AB,
235 local: &FriFoldCols<AB::Var>,
236 next: &FriFoldCols<AB::Var>,
237 local_prepr: &FriFoldPreprocessedCols<AB::Var>,
238 next_prepr: &FriFoldPreprocessedCols<AB::Var>,
239 ) {
240 builder.send_single(local_prepr.x_mem.addr, local.x, local_prepr.x_mem.mult);
242
243 builder
245 .when_transition()
246 .when(next_prepr.is_real)
247 .when_not(next_prepr.is_first)
248 .assert_eq(local.x, next.x);
249
250 builder.send_block(local_prepr.z_mem.addr, local.z, local_prepr.z_mem.mult);
252
253 builder
255 .when_transition()
256 .when(next_prepr.is_real)
257 .when_not(next_prepr.is_first)
258 .assert_ext_eq(local.z.as_extension::<AB>(), next.z.as_extension::<AB>());
259
260 builder.send_block(local_prepr.alpha_mem.addr, local.alpha, local_prepr.alpha_mem.mult);
262
263 builder
265 .when_transition()
266 .when(next_prepr.is_real)
267 .when_not(next_prepr.is_first)
268 .assert_ext_eq(local.alpha.as_extension::<AB>(), next.alpha.as_extension::<AB>());
269
270 builder.send_block(
272 local_prepr.alpha_pow_input_mem.addr,
273 local.alpha_pow_input,
274 local_prepr.alpha_pow_input_mem.mult,
275 );
276
277 builder.send_block(
279 local_prepr.ro_input_mem.addr,
280 local.ro_input,
281 local_prepr.ro_input_mem.mult,
282 );
283
284 builder.send_block(local_prepr.p_at_z_mem.addr, local.p_at_z, local_prepr.p_at_z_mem.mult);
286
287 builder.send_block(local_prepr.p_at_x_mem.addr, local.p_at_x, local_prepr.p_at_x_mem.mult);
289
290 builder.send_block(
292 local_prepr.alpha_pow_output_mem.addr,
293 local.alpha_pow_output,
294 local_prepr.alpha_pow_output_mem.mult,
295 );
296
297 builder.send_block(
299 local_prepr.ro_output_mem.addr,
300 local.ro_output,
301 local_prepr.ro_output_mem.mult,
302 );
303
304 let alpha = local.alpha.as_extension::<AB>();
306 let old_alpha_pow = local.alpha_pow_input.as_extension::<AB>();
307 let new_alpha_pow = local.alpha_pow_output.as_extension::<AB>();
308 builder.assert_ext_eq(old_alpha_pow.clone() * alpha, new_alpha_pow.clone());
309
310 let p_at_z = local.p_at_z.as_extension::<AB>();
314 let p_at_x = local.p_at_x.as_extension::<AB>();
315 let z = local.z.as_extension::<AB>();
316 let x = local.x.into();
317 let old_ro = local.ro_input.as_extension::<AB>();
318 let new_ro = local.ro_output.as_extension::<AB>();
319 builder.assert_ext_eq(
320 (new_ro.clone() - old_ro) * (BinomialExtension::from_base(x) - z),
321 (p_at_x - p_at_z) * old_alpha_pow,
322 );
323 }
324
325 pub const fn do_memory_access<T: Copy>(local: &FriFoldPreprocessedCols<T>) -> T {
326 local.is_real
327 }
328}
329
330impl<AB, const DEGREE: usize> Air<AB> for FriFoldChip<DEGREE>
331where
332 AB: SP1RecursionAirBuilder + PairBuilder,
333{
334 fn eval(&self, builder: &mut AB) {
335 let main = builder.main();
336 let (local, next) = (main.row_slice(0), main.row_slice(1));
337 let local: &FriFoldCols<AB::Var> = (*local).borrow();
338 let next: &FriFoldCols<AB::Var> = (*next).borrow();
339 let prepr = builder.preprocessed();
340 let (prepr_local, prepr_next) = (prepr.row_slice(0), prepr.row_slice(1));
341 let prepr_local: &FriFoldPreprocessedCols<AB::Var> = (*prepr_local).borrow();
342 let prepr_next: &FriFoldPreprocessedCols<AB::Var> = (*prepr_next).borrow();
343
344 let lhs = (0..DEGREE).map(|_| prepr_local.is_real.into()).product::<AB::Expr>();
346 let rhs = (0..DEGREE).map(|_| prepr_local.is_real.into()).product::<AB::Expr>();
347 builder.assert_eq(lhs, rhs);
348
349 self.eval_fri_fold::<AB>(builder, local, next, prepr_local, prepr_next);
350 }
351}
352
#[cfg(test)]
mod tests {
    use p3_field::AbstractExtensionField;
    use rand::{rngs::StdRng, Rng, SeedableRng};
    use sp1_core_machine::utils::setup_logger;
    use sp1_recursion_core::{air::Block, stark::config::BabyBearPoseidon2Outer};
    use sp1_stark::{air::MachineAir, StarkGenericConfig};
    use std::mem::size_of;

    use p3_baby_bear::BabyBear;
    use p3_field::AbstractField;
    use p3_matrix::dense::RowMajorMatrix;

    use crate::{
        chips::fri_fold::FriFoldChip,
        machine::tests::run_recursion_test_machines,
        runtime::{instruction as instr, ExecutionRecord},
        FriFoldBaseIo, FriFoldEvent, FriFoldExtSingleIo, FriFoldExtVecIo, Instruction,
        MemAccessKind, RecursionProgram,
    };

    /// Builds a program with FRI fold instructions of vector lengths 2 through 16, each
    /// preceded by writes of its inputs and followed by reads of the expected outputs, and runs
    /// it through the recursion test machines.
    #[test]
    fn prove_babybear_circuit_fri_fold() {
        setup_logger();
        type SC = BabyBearPoseidon2Outer;
        type F = <SC as StarkGenericConfig>::Val;
        type EF = <SC as StarkGenericConfig>::Challenge;

        // Both generators use the same seed, so the two closures draw from identical streams.
        let mut rng = StdRng::seed_from_u64(0xDEADBEEF);
        let mut random_felt = move || -> F { F::from_canonical_u32(rng.gen_range(0..1 << 16)) };
        let mut rng = StdRng::seed_from_u64(0xDEADBEEF);
        // The array-repeat expression draws once, so all four entries of a block are equal.
        let mut random_block =
            move || Block::from([F::from_canonical_u32(rng.gen_range(0..1 << 16)); 4]);
        // Next free memory address; advanced once per generated FRI fold instruction.
        let mut addr = 0;

        // Field counts of the IO structs, obtained as their size when instantiated with `u8`.
        let num_ext_vecs: u32 = size_of::<FriFoldExtVecIo<u8>>() as u32;
        let num_singles: u32 =
            size_of::<FriFoldBaseIo<u8>>() as u32 + size_of::<FriFoldExtSingleIo<u8>>() as u32;

        let instructions = (2..17)
            .flat_map(|i: u32| {
                // `i` is the vector length of this instruction. The reservation is at least as
                // large as the `6 * i + 3` addresses assigned below.
                let alloc_size = i * (num_ext_vecs + 2) + num_singles;

                // Addresses: six consecutive ranges of `i` slots for the vector operands...
                let mat_opening_a = (0..i).map(|x| x + addr).collect::<Vec<_>>();
                let ps_at_z_a = (0..i).map(|x| x + i + addr).collect::<Vec<_>>();

                let alpha_pow_input_a = (0..i).map(|x: u32| x + addr + 2 * i).collect::<Vec<_>>();
                let ro_input_a = (0..i).map(|x: u32| x + addr + 3 * i).collect::<Vec<_>>();

                let alpha_pow_output_a = (0..i).map(|x: u32| x + addr + 4 * i).collect::<Vec<_>>();
                let ro_output_a = (0..i).map(|x: u32| x + addr + 5 * i).collect::<Vec<_>>();

                // ...followed by three slots for the single operands.
                let x_a = addr + 6 * i;
                let z_a = addr + 6 * i + 1;
                let alpha_a = addr + 6 * i + 2;

                addr += alloc_size;

                // Random input values. The order of these draws determines the test data.
                let x = random_felt();
                let z = random_block();
                let alpha = random_block();

                let alpha_pow_input = (0..i).map(|_| random_block()).collect::<Vec<_>>();
                let ro_input = (0..i).map(|_| random_block()).collect::<Vec<_>>();

                let ps_at_z = (0..i).map(|_| random_block()).collect::<Vec<_>>();
                let mat_opening = (0..i).map(|_| random_block()).collect::<Vec<_>>();

                // Expected outputs:
                //   alpha_pow_output = alpha_pow_input * alpha
                //   ro_output = ro_input + alpha_pow_input * (mat_opening - ps_at_z) / (x - z)
                let alpha_pow_output = (0..i)
                    .map(|i| alpha_pow_input[i as usize].ext::<EF>() * alpha.ext::<EF>())
                    .collect::<Vec<EF>>();
                let ro_output = (0..i)
                    .map(|i| {
                        let i = i as usize;
                        ro_input[i].ext::<EF>()
                            + alpha_pow_input[i].ext::<EF>()
                                * (-ps_at_z[i].ext::<EF>() + mat_opening[i].ext::<EF>())
                                / (-z.ext::<EF>() + x)
                    })
                    .collect::<Vec<EF>>();

                // Write the single operands.
                let mut instructions = vec![instr::mem_single(MemAccessKind::Write, 1, x_a, x)];

                instructions.push(instr::mem_block(MemAccessKind::Write, 1, z_a, z));

                instructions.push(instr::mem_block(MemAccessKind::Write, 1, alpha_a, alpha));

                // Write the vector inputs, element by element.
                (0..i).for_each(|j_32| {
                    let j = j_32 as usize;
                    instructions.push(instr::mem_block(
                        MemAccessKind::Write,
                        1,
                        mat_opening_a[j],
                        mat_opening[j],
                    ));
                    instructions.push(instr::mem_block(
                        MemAccessKind::Write,
                        1,
                        ps_at_z_a[j],
                        ps_at_z[j],
                    ));

                    instructions.push(instr::mem_block(
                        MemAccessKind::Write,
                        1,
                        alpha_pow_input_a[j],
                        alpha_pow_input[j],
                    ));
                    instructions.push(instr::mem_block(
                        MemAccessKind::Write,
                        1,
                        ro_input_a[j],
                        ro_input[j],
                    ));
                });

                // The FRI fold instruction itself; the last two arguments are vectors of `i`
                // ones.
                instructions.push(instr::fri_fold(
                    z_a,
                    alpha_a,
                    x_a,
                    mat_opening_a.clone(),
                    ps_at_z_a.clone(),
                    alpha_pow_input_a.clone(),
                    ro_input_a.clone(),
                    alpha_pow_output_a.clone(),
                    ro_output_a.clone(),
                    vec![1; i as usize],
                    vec![1; i as usize],
                ));

                // Read back both outputs and compare them with the expected values.
                (0..i).for_each(|j| {
                    let j = j as usize;
                    instructions.push(instr::mem_block(
                        MemAccessKind::Read,
                        1,
                        alpha_pow_output_a[j],
                        Block::from(alpha_pow_output[j].as_base_slice()),
                    ));
                    instructions.push(instr::mem_block(
                        MemAccessKind::Read,
                        1,
                        ro_output_a[j],
                        Block::from(ro_output[j].as_base_slice()),
                    ));
                });

                instructions
            })
            .collect::<Vec<Instruction<F>>>();

        let program = RecursionProgram { instructions, ..Default::default() };

        run_recursion_test_machines(program);
    }

    /// Smoke test: generates the main trace for 17 random events and prints its values. The
    /// trace contents are not asserted.
    #[test]
    fn generate_fri_fold_circuit_trace() {
        type F = BabyBear;

        // Two generators with the same seed: `rng` feeds the blocks, `rng2` feeds `x`.
        let mut rng = StdRng::seed_from_u64(0xDEADBEEF);
        let mut rng2 = StdRng::seed_from_u64(0xDEADBEEF);
        let mut random_felt = move || -> F { F::from_canonical_u32(rng.gen_range(0..1 << 16)) };
        // The array-repeat expression draws once, so all four entries of a block are equal.
        let mut random_block = move || Block::from([random_felt(); 4]);

        let shard = ExecutionRecord {
            fri_fold_events: (0..17)
                .map(|_| FriFoldEvent {
                    base_single: FriFoldBaseIo {
                        x: F::from_canonical_u32(rng2.gen_range(0..1 << 16)),
                    },
                    ext_single: FriFoldExtSingleIo { z: random_block(), alpha: random_block() },
                    ext_vec: crate::FriFoldExtVecIo {
                        mat_opening: random_block(),
                        ps_at_z: random_block(),
                        alpha_pow_input: random_block(),
                        ro_input: random_block(),
                        alpha_pow_output: random_block(),
                        ro_output: random_block(),
                    },
                })
                .collect(),
            ..Default::default()
        };
        let chip = FriFoldChip::<3>::default();
        let trace: RowMajorMatrix<F> = chip.generate_trace(&shard, &mut ExecutionRecord::default());
        println!("{:?}", trace.values)
    }
}