Skip to main content

sp1_recursion_machine/chips/
prefix_sum_checks.rs

1use crate::builder::SP1RecursionAirBuilder;
2use core::borrow::Borrow;
3use slop_air::{Air, BaseAir, PairBuilder};
4use slop_algebra::{AbstractField, PrimeField32};
5use slop_matrix::Matrix;
6use slop_maybe_rayon::prelude::{IndexedParallelIterator, ParallelIterator, ParallelSliceMut};
7use sp1_derive::AlignedBorrow;
8use sp1_hypercube::{
9    air::{BinomialExtension, MachineAir},
10    next_multiple_of_32,
11};
12
13use sp1_primitives::SP1Field;
14use sp1_recursion_executor::{
15    Address, Block, ExecutionRecord, Instruction, PrefixSumChecksEvent, PrefixSumChecksInstr,
16    RecursionProgram,
17};
18
19use std::{borrow::BorrowMut, mem::MaybeUninit};
20
21pub const NUM_PREFIX_SUM_CHECKS_COLS: usize = core::mem::size_of::<PrefixSumChecksCols<u8>>();
22pub const NUM_PREFIX_SUM_CHECKS_PREPROCESSED_COLS: usize =
23    core::mem::size_of::<PrefixSumChecksPreprocessedCols<u8>>();
24
25#[derive(Clone, Debug, Copy, Default)]
26pub struct PrefixSumChecksChip;
27
/// The main columns for a prefix-sum-checks invocation.
///
/// One row per step of a `PrefixSumChecks` instruction. Because of `#[repr(C)]`, the field order
/// below is the column order of the main trace (`NUM_PREFIX_SUM_CHECKS_COLS` is derived from
/// this struct's size), so fields must not be reordered.
#[derive(AlignedBorrow, Debug, Clone, Copy)]
#[repr(C)]
pub struct PrefixSumChecksCols<T: Copy> {
    // Base-field input; the AIR asserts it is boolean.
    pub x1: T,
    // Extension-field input, stored as a block.
    pub x2: Block<T>,
    // Extension accumulator read at the start of this step.
    pub acc: Block<T>,
    // Accumulator after this step: acc * (1 - (x1 + x2) + 2 * x1 * x2).
    pub new_acc: Block<T>,
    // Base-field accumulator read at the start of this step.
    pub felt_acc: T,
    // Base-field accumulator after this step: x1 + 2 * felt_acc.
    pub felt_new_acc: T,
}

/// Preprocessed (program-fixed) columns for a prefix-sum-checks row: memory addresses and
/// interaction multiplicities. Field order under `#[repr(C)]` is the preprocessed column order.
#[derive(AlignedBorrow, Clone, Copy, Debug)]
#[repr(C)]
pub struct PrefixSumChecksPreprocessedCols<T: Copy> {
    // Address from which `x1` is received.
    pub x1_mem: Address<T>,
    // Address from which the `x2` block is received.
    pub x2_mem: Address<T>,
    // Address of the current accumulator: the instruction's `one` address on its first row,
    // otherwise the previous row's `next_acc_addr`.
    pub acc_addr: Address<T>,
    // Address to which `new_acc` is sent.
    pub next_acc_addr: Address<T>,
    // Multiplicity of the `new_acc` send (taken from the instruction's `acc_mults`).
    pub next_acc_mult: T,
    // Address of the current base-field accumulator: the instruction's `zero` address on its
    // first row, otherwise the previous row's `felt_next_acc_addr`.
    pub felt_acc_addr: Address<T>,
    // Address to which `felt_new_acc` is sent.
    pub felt_next_acc_addr: Address<T>,
    // Multiplicity of the `felt_new_acc` send (taken from `field_acc_mults`).
    pub felt_next_acc_mult: T,
    // 1 on rows generated from instructions, 0 on zero-filled padding rows; used as the
    // multiplicity of every receive.
    pub is_real: T,
}

54impl<F> BaseAir<F> for PrefixSumChecksChip {
55    fn width(&self) -> usize {
56        NUM_PREFIX_SUM_CHECKS_COLS
57    }
58}
59
60impl<F: PrimeField32> MachineAir<F> for PrefixSumChecksChip {
61    type Record = ExecutionRecord<F>;
62
63    type Program = RecursionProgram<F>;
64
65    fn name(&self) -> &'static str {
66        "PrefixSumChecks"
67    }
68
69    fn generate_dependencies(&self, _: &Self::Record, _: &mut Self::Record) {
70        // This is a no-op.
71    }
72
73    fn preprocessed_width(&self) -> usize {
74        NUM_PREFIX_SUM_CHECKS_PREPROCESSED_COLS
75    }
76
77    fn preprocessed_num_rows(&self, program: &Self::Program) -> Option<usize> {
78        let instrs_len = program
79            .inner
80            .iter()
81            .filter_map(|instruction| match instruction.inner() {
82                Instruction::PrefixSumChecks(instr) => Some(instr.addrs.x1.len()),
83                _ => None,
84            })
85            .sum();
86        self.preprocessed_num_rows_with_instrs_len(program, instrs_len)
87    }
88
89    fn preprocessed_num_rows_with_instrs_len(
90        &self,
91        program: &Self::Program,
92        instrs_len: usize,
93    ) -> Option<usize> {
94        let height = program.shape.as_ref().and_then(|shape| shape.height(self));
95        Some(next_multiple_of_32(instrs_len, height))
96    }
97
98    fn generate_preprocessed_trace_into(
99        &self,
100        program: &Self::Program,
101        buffer: &mut [MaybeUninit<F>],
102    ) {
103        assert_eq!(
104            std::any::TypeId::of::<F>(),
105            std::any::TypeId::of::<SP1Field>(),
106            "generate_preprocessed_trace only supports SP1Field field"
107        );
108
109        let instrs = program
110            .inner
111            .iter()
112            .filter_map(|instruction| match instruction.inner() {
113                Instruction::PrefixSumChecks(x) => Some(x),
114                _ => None,
115            })
116            .collect::<Vec<_>>();
117
118        let padded_nb_rows = self.preprocessed_num_rows(program).unwrap();
119
120        let buffer_ptr = buffer.as_mut_ptr() as *mut F;
121        let values = unsafe {
122            core::slice::from_raw_parts_mut(
123                buffer_ptr,
124                padded_nb_rows * NUM_PREFIX_SUM_CHECKS_PREPROCESSED_COLS,
125            )
126        };
127
128        let mut row_cnt = 0;
129        instrs.iter().for_each(|instruction| {
130            let PrefixSumChecksInstr { addrs, acc_mults, field_acc_mults } = instruction.as_ref();
131            let len = addrs.x1.len();
132            (0..len).for_each(|i| {
133                let start = row_cnt * NUM_PREFIX_SUM_CHECKS_PREPROCESSED_COLS;
134                let end = (row_cnt + 1) * NUM_PREFIX_SUM_CHECKS_PREPROCESSED_COLS;
135                let cols: &mut PrefixSumChecksPreprocessedCols<F> = values[start..end].borrow_mut();
136                if i == 0 {
137                    cols.acc_addr = addrs.one;
138                    cols.felt_acc_addr = addrs.zero;
139                } else {
140                    cols.acc_addr = addrs.accs[i - 1];
141                    cols.felt_acc_addr = addrs.field_accs[i - 1];
142                }
143                cols.x1_mem = addrs.x1[i];
144                cols.x2_mem = addrs.x2[i];
145                cols.next_acc_addr = addrs.accs[i];
146                cols.next_acc_mult = acc_mults[i];
147                cols.felt_next_acc_addr = addrs.field_accs[i];
148                cols.felt_next_acc_mult = field_acc_mults[i];
149                cols.is_real = F::one();
150                row_cnt += 1;
151            });
152        });
153
154        unsafe {
155            let padding_start = row_cnt * NUM_PREFIX_SUM_CHECKS_PREPROCESSED_COLS;
156            let padding_size = (padded_nb_rows - row_cnt) * NUM_PREFIX_SUM_CHECKS_PREPROCESSED_COLS;
157            if padding_size > 0 {
158                core::ptr::write_bytes(buffer[padding_start..].as_mut_ptr(), 0, padding_size);
159            }
160        }
161    }
162
163    fn num_rows(&self, input: &Self::Record) -> Option<usize> {
164        let height = input.program.shape.as_ref().and_then(|shape| shape.height(self));
165        let events = &input.prefix_sum_checks_events;
166        Some(next_multiple_of_32(events.len(), height))
167    }
168
169    fn generate_trace_into(
170        &self,
171        input: &ExecutionRecord<F>,
172        _: &mut ExecutionRecord<F>,
173        buffer: &mut [MaybeUninit<F>],
174    ) {
175        assert!(
176            std::any::TypeId::of::<F>() == std::any::TypeId::of::<SP1Field>(),
177            "generate_trace_into only supports SP1Field"
178        );
179        let padded_nb_rows = <PrefixSumChecksChip as MachineAir<F>>::num_rows(self, input).unwrap();
180        let events = unsafe {
181            std::mem::transmute::<&Vec<PrefixSumChecksEvent<F>>, &Vec<PrefixSumChecksEvent<SP1Field>>>(
182                &input.prefix_sum_checks_events,
183            )
184        };
185        let num_event_rows = events.len();
186
187        unsafe {
188            let padding_start = num_event_rows * NUM_PREFIX_SUM_CHECKS_COLS;
189            let padding_size = (padded_nb_rows - num_event_rows) * NUM_PREFIX_SUM_CHECKS_COLS;
190            if padding_size > 0 {
191                core::ptr::write_bytes(buffer[padding_start..].as_mut_ptr(), 0, padding_size);
192            }
193        }
194
195        let buffer_ptr = buffer.as_mut_ptr() as *mut F;
196        let values = unsafe {
197            core::slice::from_raw_parts_mut(buffer_ptr, num_event_rows * NUM_PREFIX_SUM_CHECKS_COLS)
198        };
199
200        // Generate the trace rows & corresponding records for each chunk of events in parallel.
201        let populate_len = events.len() * NUM_PREFIX_SUM_CHECKS_COLS;
202        values[..populate_len]
203            .par_chunks_mut(NUM_PREFIX_SUM_CHECKS_COLS)
204            .zip_eq(events)
205            .for_each(|(row, vals)| {
206                let bb_event = unsafe {
207                                    std::mem::transmute::<
208                                        &PrefixSumChecksEvent<SP1Field>,
209                                        &PrefixSumChecksEvent<F>,
210                                    >(vals)
211                                };
212                let cols: &mut PrefixSumChecksCols<_> = row.borrow_mut();
213                cols.x1 = bb_event.x1;
214                cols.x2 = bb_event.x2;
215                cols.acc = bb_event.acc;
216                cols.new_acc = bb_event.new_acc;
217                cols.felt_acc = bb_event.field_acc;
218                cols.felt_new_acc = bb_event.new_field_acc;
219            });
220    }
221
222    fn included(&self, _: &Self::Record) -> bool {
223        true
224    }
225}
226
/// Constraints for one prefix-sum-checks step per row.
///
/// Each row receives `x1`, `x2`, and the two current accumulators from memory, checks
/// `new_acc = acc * (1 - (x1 + x2) + 2 * x1 * x2)` and `felt_new_acc = x1 + 2 * felt_acc`, and
/// sends both new accumulators back to memory. An all-zero row satisfies every constraint here,
/// and its interactions all have multiplicity zero, which is what the zero padding relies on.
///
/// NOTE(review): the recursion verifier presumably consumes this AIR's symbolic form, so the
/// order of constraints/interactions and the exact expression shapes (e.g. `prod.clone() + prod`)
/// likely affect verifying keys — confirm before refactoring.
impl<AB> Air<AB> for PrefixSumChecksChip
where
    AB: SP1RecursionAirBuilder + PairBuilder,
{
    fn eval(&self, builder: &mut AB) {
        let main = builder.main();
        let local = main.row_slice(0);
        let local: &PrefixSumChecksCols<AB::Var> = (*local).borrow();
        let prep = builder.preprocessed();
        let prep_local = prep.row_slice(0);
        let prep_local: &PrefixSumChecksPreprocessedCols<_> = (*prep_local).borrow();

        // x1 lifted into the extension field, multiplied by x2.
        let x2 = local.x2.as_extension::<AB>();
        let prod = BinomialExtension::from_base(local.x1.into()) * x2.clone();
        let one: BinomialExtension<AB::Expr> = BinomialExtension::from_base(AB::Expr::one());
        let two = AB::Expr::from_canonical_u32(2);

        let sum_x_y = BinomialExtension::from_base(local.x1.into()) + x2;

        // Check that `is_real` is boolean.
        builder.assert_bool(prep_local.is_real);

        // Booleanity check for x1 (applied on every row; padding rows have x1 = 0).
        builder.assert_bool(local.x1);

        // Constrain the memory access for inputs.
        builder.receive_single(prep_local.x1_mem, local.x1, prep_local.is_real);
        builder.receive_block(prep_local.x2_mem, local.x2, prep_local.is_real);

        // Constrain the memory read for the current accumulator.
        builder.receive_block(prep_local.acc_addr, local.acc, prep_local.is_real);
        builder.receive_single(prep_local.felt_acc_addr, local.felt_acc, prep_local.is_real);

        // Constrain the next accumulator values (lagrange eval and bit2felt).
        // The extension factor 1 - (x1 + x2) + 2*x1*x2 equals x1*x2 + (1 - x1)*(1 - x2).
        // The base-field update appends the boolean x1 as the new lowest bit: x1 + 2*felt_acc.
        builder.assert_ext_eq(
            local.new_acc.as_extension::<AB>(),
            local.acc.as_extension::<AB>() * (one - sum_x_y + prod.clone() + prod),
        );
        builder.assert_eq(local.felt_new_acc, local.x1 + two * local.felt_acc);

        // Constrain the memory write for the output accumulator.
        builder.send_block(prep_local.next_acc_addr, local.new_acc, prep_local.next_acc_mult);
        builder.send_single(
            prep_local.felt_next_acc_addr,
            local.felt_new_acc,
            prep_local.felt_next_acc_mult,
        );
    }
}

#[cfg(test)]
mod tests {
    use crate::test::test_recursion_linear_program;
    use rand::{rngs::StdRng, Rng, SeedableRng};
    use slop_algebra::{extension::BinomialExtensionField, AbstractExtensionField, AbstractField};

    use sp1_recursion_executor::{instruction as instr, Instruction, MemAccessKind};

    use slop_matrix::Matrix;
    use sp1_hypercube::air::MachineAir;
    use sp1_recursion_executor::ExecutionRecord;

    use super::PrefixSumChecksChip;

    use crate::chips::test_fixtures;

    /// The main trace built from the fixture shard is taller than `MIN_ROWS`.
    #[tokio::test]
    async fn generate_trace() {
        let shard = test_fixtures::shard().await;
        let mut output = ExecutionRecord::default();
        let trace = PrefixSumChecksChip.generate_trace(shard, &mut output);
        assert!(trace.height() > test_fixtures::MIN_ROWS);
    }

    /// The preprocessed trace built from the fixture program is taller than `MIN_ROWS`.
    #[tokio::test]
    async fn generate_preprocessed_trace() {
        let fixture = test_fixtures::program_with_input().await;
        let trace = PrefixSumChecksChip.generate_preprocessed_trace(&fixture.0).unwrap();
        assert!(trace.height() > test_fixtures::MIN_ROWS);
    }

    /// Runs ten two-step prefix-sum-checks instructions on random inputs and reads back the
    /// final accumulators, comparing them against values computed directly in the test.
    #[tokio::test]
    async fn test_prefix_sum_checks() {
        use sp1_primitives::SP1Field;
        type F = SP1Field;
        type EF = BinomialExtensionField<F, 4>;

        const NUM_INVOCATIONS: usize = 10;
        const ALLOC_SIZE: u32 = 10;

        let mut ext_rng = StdRng::seed_from_u64(0xDEADBEEF);
        let mut random_extfelt = move || {
            let coeffs: [F; 4] =
                core::array::from_fn(|_| ext_rng.sample(rand::distributions::Standard));
            EF::from_base_slice(&coeffs)
        };
        let mut bit_rng = StdRng::seed_from_u64(0xDEADBEEF);
        let mut random_felt = move || -> SP1Field {
            if bit_rng.gen_bool(0.5) {
                SP1Field::one()
            } else {
                SP1Field::zero()
            }
        };

        let one = EF::from_base(SP1Field::one());
        let two = SP1Field::from_canonical_u32(2);

        let mut instructions: Vec<Instruction<SP1Field>> = Vec::new();
        let mut base: u32 = 0;
        for _ in 0..NUM_INVOCATIONS {
            let x1 = [random_felt(), random_felt()];
            let x2 = [random_extfelt(), random_extfelt()];

            // Expected extension accumulator: product of (1 - (x1 + x2) + 2*x1*x2) per step.
            let expected_acc = x1.iter().zip(x2.iter()).fold(one, |acc, (&bit, &ext)| {
                let lifted = EF::from_base(bit);
                let prod = lifted * ext;
                acc * (one - (lifted + ext) + prod + prod)
            });

            // Expected base-field accumulator: the bits of x1 folded most-significant first.
            let expected_felt = x1.iter().fold(SP1Field::zero(), |acc, &bit| bit + two * acc);

            let a: Vec<u32> = (base..base + ALLOC_SIZE).collect();
            base += ALLOC_SIZE;

            instructions.extend([
                instr::mem_single(MemAccessKind::Write, 1, a[0], x1[0]),
                instr::mem_single(MemAccessKind::Write, 1, a[1], x1[1]),
                instr::mem_ext(MemAccessKind::Write, 1, a[2], x2[0]),
                instr::mem_ext(MemAccessKind::Write, 1, a[3], x2[1]),
                instr::mem_ext(MemAccessKind::Write, 1, a[4], one),
                instr::mem_single(MemAccessKind::Write, 1, a[5], SP1Field::zero()),
                instr::prefix_sum_checks(
                    vec![1, 1],
                    vec![1, 1],
                    vec![F::from_canonical_u32(a[0]), F::from_canonical_u32(a[1])],
                    vec![F::from_canonical_u32(a[2]), F::from_canonical_u32(a[3])],
                    F::from_canonical_u32(a[5]),
                    F::from_canonical_u32(a[4]),
                    vec![F::from_canonical_u32(a[6]), F::from_canonical_u32(a[7])],
                    vec![F::from_canonical_u32(a[8]), F::from_canonical_u32(a[9])],
                ),
                instr::mem_ext(MemAccessKind::Read, 1, a[7], expected_acc),
                instr::mem_single(MemAccessKind::Read, 1, a[9], expected_felt),
            ]);
        }

        test_recursion_linear_program(instructions).await;
    }
}