sp1_recursion_machine/chips/
prefix_sum_checks.rs1use crate::builder::SP1RecursionAirBuilder;
2use core::borrow::Borrow;
3use slop_air::{Air, BaseAir, PairBuilder};
4use slop_algebra::{AbstractField, PrimeField32};
5use slop_matrix::Matrix;
6use slop_maybe_rayon::prelude::{IndexedParallelIterator, ParallelIterator, ParallelSliceMut};
7use sp1_derive::AlignedBorrow;
8use sp1_hypercube::{
9 air::{BinomialExtension, MachineAir},
10 next_multiple_of_32,
11};
12
13use sp1_primitives::SP1Field;
14use sp1_recursion_executor::{
15 Address, Block, ExecutionRecord, Instruction, PrefixSumChecksEvent, PrefixSumChecksInstr,
16 RecursionProgram,
17};
18
19use std::{borrow::BorrowMut, mem::MaybeUninit};
20
21pub const NUM_PREFIX_SUM_CHECKS_COLS: usize = core::mem::size_of::<PrefixSumChecksCols<u8>>();
22pub const NUM_PREFIX_SUM_CHECKS_PREPROCESSED_COLS: usize =
23 core::mem::size_of::<PrefixSumChecksPreprocessedCols<u8>>();
24
/// Marker type for the prefix-sum-checks chip.
///
/// It carries no data; the `BaseAir`, `MachineAir` and `Air` implementations in this file are
/// attached to it.
#[derive(Debug, Default, Clone, Copy)]
pub struct PrefixSumChecksChip;
27
28#[derive(AlignedBorrow, Debug, Clone, Copy)]
30#[repr(C)]
31pub struct PrefixSumChecksCols<T: Copy> {
32 pub x1: T,
33 pub x2: Block<T>,
34 pub acc: Block<T>,
35 pub new_acc: Block<T>,
36 pub felt_acc: T,
37 pub felt_new_acc: T,
38}
39
40#[derive(AlignedBorrow, Clone, Copy, Debug)]
41#[repr(C)]
42pub struct PrefixSumChecksPreprocessedCols<T: Copy> {
43 pub x1_mem: Address<T>,
44 pub x2_mem: Address<T>,
45 pub acc_addr: Address<T>,
46 pub next_acc_addr: Address<T>,
47 pub next_acc_mult: T,
48 pub felt_acc_addr: Address<T>,
49 pub felt_next_acc_addr: Address<T>,
50 pub felt_next_acc_mult: T,
51 pub is_real: T,
52}
53
54impl<F> BaseAir<F> for PrefixSumChecksChip {
55 fn width(&self) -> usize {
56 NUM_PREFIX_SUM_CHECKS_COLS
57 }
58}
59
60impl<F: PrimeField32> MachineAir<F> for PrefixSumChecksChip {
61 type Record = ExecutionRecord<F>;
62
63 type Program = RecursionProgram<F>;
64
65 fn name(&self) -> &'static str {
66 "PrefixSumChecks"
67 }
68
69 fn generate_dependencies(&self, _: &Self::Record, _: &mut Self::Record) {
70 }
72
73 fn preprocessed_width(&self) -> usize {
74 NUM_PREFIX_SUM_CHECKS_PREPROCESSED_COLS
75 }
76
77 fn preprocessed_num_rows(&self, program: &Self::Program) -> Option<usize> {
78 let instrs_len = program
79 .inner
80 .iter()
81 .filter_map(|instruction| match instruction.inner() {
82 Instruction::PrefixSumChecks(instr) => Some(instr.addrs.x1.len()),
83 _ => None,
84 })
85 .sum();
86 self.preprocessed_num_rows_with_instrs_len(program, instrs_len)
87 }
88
89 fn preprocessed_num_rows_with_instrs_len(
90 &self,
91 program: &Self::Program,
92 instrs_len: usize,
93 ) -> Option<usize> {
94 let height = program.shape.as_ref().and_then(|shape| shape.height(self));
95 Some(next_multiple_of_32(instrs_len, height))
96 }
97
98 fn generate_preprocessed_trace_into(
99 &self,
100 program: &Self::Program,
101 buffer: &mut [MaybeUninit<F>],
102 ) {
103 assert_eq!(
104 std::any::TypeId::of::<F>(),
105 std::any::TypeId::of::<SP1Field>(),
106 "generate_preprocessed_trace only supports SP1Field field"
107 );
108
109 let instrs = program
110 .inner
111 .iter()
112 .filter_map(|instruction| match instruction.inner() {
113 Instruction::PrefixSumChecks(x) => Some(x),
114 _ => None,
115 })
116 .collect::<Vec<_>>();
117
118 let padded_nb_rows = self.preprocessed_num_rows(program).unwrap();
119
120 let buffer_ptr = buffer.as_mut_ptr() as *mut F;
121 let values = unsafe {
122 core::slice::from_raw_parts_mut(
123 buffer_ptr,
124 padded_nb_rows * NUM_PREFIX_SUM_CHECKS_PREPROCESSED_COLS,
125 )
126 };
127
128 let mut row_cnt = 0;
129 instrs.iter().for_each(|instruction| {
130 let PrefixSumChecksInstr { addrs, acc_mults, field_acc_mults } = instruction.as_ref();
131 let len = addrs.x1.len();
132 (0..len).for_each(|i| {
133 let start = row_cnt * NUM_PREFIX_SUM_CHECKS_PREPROCESSED_COLS;
134 let end = (row_cnt + 1) * NUM_PREFIX_SUM_CHECKS_PREPROCESSED_COLS;
135 let cols: &mut PrefixSumChecksPreprocessedCols<F> = values[start..end].borrow_mut();
136 if i == 0 {
137 cols.acc_addr = addrs.one;
138 cols.felt_acc_addr = addrs.zero;
139 } else {
140 cols.acc_addr = addrs.accs[i - 1];
141 cols.felt_acc_addr = addrs.field_accs[i - 1];
142 }
143 cols.x1_mem = addrs.x1[i];
144 cols.x2_mem = addrs.x2[i];
145 cols.next_acc_addr = addrs.accs[i];
146 cols.next_acc_mult = acc_mults[i];
147 cols.felt_next_acc_addr = addrs.field_accs[i];
148 cols.felt_next_acc_mult = field_acc_mults[i];
149 cols.is_real = F::one();
150 row_cnt += 1;
151 });
152 });
153
154 unsafe {
155 let padding_start = row_cnt * NUM_PREFIX_SUM_CHECKS_PREPROCESSED_COLS;
156 let padding_size = (padded_nb_rows - row_cnt) * NUM_PREFIX_SUM_CHECKS_PREPROCESSED_COLS;
157 if padding_size > 0 {
158 core::ptr::write_bytes(buffer[padding_start..].as_mut_ptr(), 0, padding_size);
159 }
160 }
161 }
162
163 fn num_rows(&self, input: &Self::Record) -> Option<usize> {
164 let height = input.program.shape.as_ref().and_then(|shape| shape.height(self));
165 let events = &input.prefix_sum_checks_events;
166 Some(next_multiple_of_32(events.len(), height))
167 }
168
169 fn generate_trace_into(
170 &self,
171 input: &ExecutionRecord<F>,
172 _: &mut ExecutionRecord<F>,
173 buffer: &mut [MaybeUninit<F>],
174 ) {
175 assert!(
176 std::any::TypeId::of::<F>() == std::any::TypeId::of::<SP1Field>(),
177 "generate_trace_into only supports SP1Field"
178 );
179 let padded_nb_rows = <PrefixSumChecksChip as MachineAir<F>>::num_rows(self, input).unwrap();
180 let events = unsafe {
181 std::mem::transmute::<&Vec<PrefixSumChecksEvent<F>>, &Vec<PrefixSumChecksEvent<SP1Field>>>(
182 &input.prefix_sum_checks_events,
183 )
184 };
185 let num_event_rows = events.len();
186
187 unsafe {
188 let padding_start = num_event_rows * NUM_PREFIX_SUM_CHECKS_COLS;
189 let padding_size = (padded_nb_rows - num_event_rows) * NUM_PREFIX_SUM_CHECKS_COLS;
190 if padding_size > 0 {
191 core::ptr::write_bytes(buffer[padding_start..].as_mut_ptr(), 0, padding_size);
192 }
193 }
194
195 let buffer_ptr = buffer.as_mut_ptr() as *mut F;
196 let values = unsafe {
197 core::slice::from_raw_parts_mut(buffer_ptr, num_event_rows * NUM_PREFIX_SUM_CHECKS_COLS)
198 };
199
200 let populate_len = events.len() * NUM_PREFIX_SUM_CHECKS_COLS;
202 values[..populate_len]
203 .par_chunks_mut(NUM_PREFIX_SUM_CHECKS_COLS)
204 .zip_eq(events)
205 .for_each(|(row, vals)| {
206 let bb_event = unsafe {
207 std::mem::transmute::<
208 &PrefixSumChecksEvent<SP1Field>,
209 &PrefixSumChecksEvent<F>,
210 >(vals)
211 };
212 let cols: &mut PrefixSumChecksCols<_> = row.borrow_mut();
213 cols.x1 = bb_event.x1;
214 cols.x2 = bb_event.x2;
215 cols.acc = bb_event.acc;
216 cols.new_acc = bb_event.new_acc;
217 cols.felt_acc = bb_event.field_acc;
218 cols.felt_new_acc = bb_event.new_field_acc;
219 });
220 }
221
222 fn included(&self, _: &Self::Record) -> bool {
223 true
224 }
225}
226
227impl<AB> Air<AB> for PrefixSumChecksChip
228where
229 AB: SP1RecursionAirBuilder + PairBuilder,
230{
231 fn eval(&self, builder: &mut AB) {
232 let main = builder.main();
233 let local = main.row_slice(0);
234 let local: &PrefixSumChecksCols<AB::Var> = (*local).borrow();
235 let prep = builder.preprocessed();
236 let prep_local = prep.row_slice(0);
237 let prep_local: &PrefixSumChecksPreprocessedCols<_> = (*prep_local).borrow();
238
239 let x2 = local.x2.as_extension::<AB>();
240 let prod = BinomialExtension::from_base(local.x1.into()) * x2.clone();
241 let one: BinomialExtension<AB::Expr> = BinomialExtension::from_base(AB::Expr::one());
242 let two = AB::Expr::from_canonical_u32(2);
243
244 let sum_x_y = BinomialExtension::from_base(local.x1.into()) + x2;
245
246 builder.assert_bool(prep_local.is_real);
248
249 builder.assert_bool(local.x1);
251
252 builder.receive_single(prep_local.x1_mem, local.x1, prep_local.is_real);
254 builder.receive_block(prep_local.x2_mem, local.x2, prep_local.is_real);
255
256 builder.receive_block(prep_local.acc_addr, local.acc, prep_local.is_real);
258 builder.receive_single(prep_local.felt_acc_addr, local.felt_acc, prep_local.is_real);
259
260 builder.assert_ext_eq(
262 local.new_acc.as_extension::<AB>(),
263 local.acc.as_extension::<AB>() * (one - sum_x_y + prod.clone() + prod),
264 );
265 builder.assert_eq(local.felt_new_acc, local.x1 + two * local.felt_acc);
266
267 builder.send_block(prep_local.next_acc_addr, local.new_acc, prep_local.next_acc_mult);
269 builder.send_single(
270 prep_local.felt_next_acc_addr,
271 local.felt_new_acc,
272 prep_local.felt_next_acc_mult,
273 );
274 }
275}
276
#[cfg(test)]
mod tests {
    use super::PrefixSumChecksChip;

    use crate::chips::test_fixtures;
    use crate::test::test_recursion_linear_program;

    use rand::{rngs::StdRng, Rng, SeedableRng};
    use slop_algebra::{extension::BinomialExtensionField, AbstractExtensionField, AbstractField};
    use slop_matrix::Matrix;
    use sp1_hypercube::air::MachineAir;
    use sp1_recursion_executor::ExecutionRecord;
    use sp1_recursion_executor::{instruction as instr, Instruction, MemAccessKind};

    /// The main trace built from the fixture shard is taller than the fixture minimum.
    #[tokio::test]
    async fn generate_trace() {
        let shard = test_fixtures::shard().await;
        let mut output = ExecutionRecord::default();
        let trace = PrefixSumChecksChip.generate_trace(shard, &mut output);
        assert!(trace.height() > test_fixtures::MIN_ROWS);
    }

    /// The preprocessed trace built from the fixture program is taller than the fixture minimum.
    #[tokio::test]
    async fn generate_preprocessed_trace() {
        let program = &test_fixtures::program_with_input().await.0;
        let trace = PrefixSumChecksChip.generate_preprocessed_trace(program).unwrap();
        assert!(trace.height() > test_fixtures::MIN_ROWS);
    }

    /// Runs ten two-step prefix-sum checks on random inputs and reads both final accumulators
    /// back from memory, comparing them with values computed directly in the test.
    #[tokio::test]
    async fn test_prefix_sum_checks() {
        use sp1_primitives::SP1Field;
        type F = SP1Field;

        // Random extension elements for the `x2` inputs.
        let mut ext_rng = StdRng::seed_from_u64(0xDEADBEEF);
        let mut random_extfelt = move || {
            let coeffs: [F; 4] =
                core::array::from_fn(|_| ext_rng.sample(rand::distributions::Standard));
            BinomialExtensionField::<F, 4>::from_base_slice(&coeffs)
        };

        // Random bits (as field elements) for the `x1` inputs, drawn from their own generator.
        let mut bit_rng = StdRng::seed_from_u64(0xDEADBEEF);
        let mut random_felt = move || -> SP1Field {
            if bit_rng.gen_bool(0.5) {
                SP1Field::one()
            } else {
                SP1Field::zero()
            }
        };

        // Next free memory address; every iteration reserves a fresh window.
        let mut addr = 0;

        let instructions = (0..10)
            .flat_map(|_| {
                let x1 = [random_felt(), random_felt()];
                let one = BinomialExtensionField::<F, 4>::from_base(SP1Field::one());
                let x2 = [random_extfelt(), random_extfelt()];

                // Expected extension accumulator after both steps.
                let mut result = one;
                for (&bit, &point) in x1.iter().zip(x2.iter()) {
                    let prod = BinomialExtensionField::<F, 4>::from_base(bit) * point;
                    result *= one - (BinomialExtensionField::<F, 4>::from_base(bit) + point)
                        + prod
                        + prod;
                }

                // Expected base-field accumulator after both steps.
                let two = SP1Field::from_canonical_u32(2);
                let felt = x1.iter().fold(SP1Field::zero(), |acc, &bit| bit + two * acc);

                let alloc_size = 10;
                let a = (addr..addr + alloc_size).collect::<Vec<_>>();
                addr += alloc_size;
                [
                    instr::mem_single(MemAccessKind::Write, 1, a[0], x1[0]),
                    instr::mem_single(MemAccessKind::Write, 1, a[1], x1[1]),
                    instr::mem_ext(MemAccessKind::Write, 1, a[2], x2[0]),
                    instr::mem_ext(MemAccessKind::Write, 1, a[3], x2[1]),
                    instr::mem_ext(MemAccessKind::Write, 1, a[4], one),
                    instr::mem_single(MemAccessKind::Write, 1, a[5], SP1Field::zero()),
                    instr::prefix_sum_checks(
                        vec![1, 1],
                        vec![1, 1],
                        vec![F::from_canonical_u32(a[0]), F::from_canonical_u32(a[1])],
                        vec![F::from_canonical_u32(a[2]), F::from_canonical_u32(a[3])],
                        F::from_canonical_u32(a[5]),
                        F::from_canonical_u32(a[4]),
                        vec![F::from_canonical_u32(a[6]), F::from_canonical_u32(a[7])],
                        vec![F::from_canonical_u32(a[8]), F::from_canonical_u32(a[9])],
                    ),
                    instr::mem_ext(MemAccessKind::Read, 1, a[7], result),
                    instr::mem_single(MemAccessKind::Read, 1, a[9], felt),
                ]
            })
            .collect::<Vec<Instruction<SP1Field>>>();

        test_recursion_linear_program(instructions).await;
    }
}