sp1_core_machine/alu/add_sub/mod.rs

use core::{
    borrow::{Borrow, BorrowMut},
    mem::size_of,
};

use hashbrown::HashMap;
use itertools::Itertools;
use p3_air::{Air, AirBuilder, BaseAir};
use p3_field::{AbstractField, PrimeField, PrimeField32};
use p3_matrix::{dense::RowMajorMatrix, Matrix};
use p3_maybe_rayon::prelude::{ParallelBridge, ParallelIterator};
use sp1_core_executor::{
    events::{AluEvent, ByteLookupEvent, ByteRecord},
    ExecutionRecord, Opcode, Program, DEFAULT_PC_INC,
};
use sp1_derive::AlignedBorrow;
use sp1_stark::{
    air::{MachineAir, SP1AirBuilder},
    Word,
};

use crate::{
    operations::AddOperation,
    utils::{next_power_of_two, zeroed_f_vec},
};

/// The number of main trace columns for `AddSubChip`.
pub const NUM_ADD_SUB_COLS: usize = size_of::<AddSubCols<u8>>();

/// A chip that implements addition for the opcodes ADD and SUB.
///
/// SUB is essentially an ADD with a re-arrangement of the operands and result.
/// Given the standard ALU naming convention `a` = `b` OP `c`,
/// `a` = `b` + `c` should be verified for ADD, and `b` = `a` + `c` (i.e. `a` = `b` - `c`)
/// should be verified for SUB.
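///
/// # Example
///
/// An illustrative sketch of the operand re-arrangement (the `AluEvent::new` argument order
/// `(pc, opcode, a, b, c, op_a_0)` follows the tests at the bottom of this file):
///
/// ```ignore
/// // ADD: a = b + c, so the row's AddOperation proves operand_1 + operand_2 = a.
/// // Columns: operand_1 = b = 8, operand_2 = c = 6, add_operation.value = a = 14.
/// let add = AluEvent::new(0, Opcode::ADD, 14, 8, 6, false);
///
/// // SUB: a = b - c is rewritten as b = a + c, so the same AddOperation proves
/// // operand_1 + operand_2 = b.
/// // Columns: operand_1 = a = 8, operand_2 = c = 6, add_operation.value = b = 14.
/// let sub = AluEvent::new(0, Opcode::SUB, 8, 14, 6, false);
/// ```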
#[derive(Default)]
pub struct AddSubChip;

/// The column layout for the chip.
#[derive(AlignedBorrow, Default, Clone, Copy)]
#[repr(C)]
pub struct AddSubCols<T> {
    /// The program counter.
    pub pc: T,

    /// Instance of `AddOperation` to handle addition logic in `AddSubChip`'s ALU operations.
    /// Its result will be `a` for the add operation and `b` for the sub operation.
    pub add_operation: AddOperation<T>,

    /// The first input operand.  This will be `b` for add operations and `a` for sub operations.
    pub operand_1: Word<T>,

    /// The second input operand.  This will be `c` for both operations.
    pub operand_2: Word<T>,

    /// Whether `op_a` (the instruction's first operand) is not register 0.
    pub op_a_not_0: T,

    /// Boolean to indicate whether the row is for an add operation.
    pub is_add: T,

    /// Boolean to indicate whether the row is for a sub operation.
    pub is_sub: T,
}

impl<F: PrimeField32> MachineAir<F> for AddSubChip {
    type Record = ExecutionRecord;

    type Program = Program;

    fn name(&self) -> String {
        "AddSub".to_string()
    }

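    /// Padded height of the trace: the combined ADD/SUB event count passed through
    /// `next_power_of_two` together with the shard's optional `fixed_log2_rows` hint
    /// (assumed here to override the computed power-of-two height when present).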
    fn num_rows(&self, input: &Self::Record) -> Option<usize> {
        let nb_rows = next_power_of_two(
            input.add_events.len() + input.sub_events.len(),
            input.fixed_log2_rows::<F, _>(self),
        );
        Some(nb_rows)
    }

    fn generate_trace(
        &self,
        input: &ExecutionRecord,
        _: &mut ExecutionRecord,
    ) -> RowMajorMatrix<F> {
        // Generate the rows for the trace.
        let chunk_size =
            std::cmp::max((input.add_events.len() + input.sub_events.len()) / num_cpus::get(), 1);
        let merged_events =
            input.add_events.iter().chain(input.sub_events.iter()).collect::<Vec<_>>();
        let padded_nb_rows = <AddSubChip as MachineAir<F>>::num_rows(self, input).unwrap();
        let mut values = zeroed_f_vec(padded_nb_rows * NUM_ADD_SUB_COLS);

        values.chunks_mut(chunk_size * NUM_ADD_SUB_COLS).enumerate().par_bridge().for_each(
            |(i, rows)| {
                rows.chunks_mut(NUM_ADD_SUB_COLS).enumerate().for_each(|(j, row)| {
                    let idx = i * chunk_size + j;
                    let cols: &mut AddSubCols<F> = row.borrow_mut();

                    if idx < merged_events.len() {
                        let mut byte_lookup_events = Vec::new();
                        let event = &merged_events[idx];
                        self.event_to_row(event, cols, &mut byte_lookup_events);
                    }
                });
            },
        );

        // Convert the trace to a row major matrix.
        RowMajorMatrix::new(values, NUM_ADD_SUB_COLS)
    }

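    /// Re-runs row generation purely to collect the byte lookup events that each ADD/SUB event
    /// emits; the generated rows themselves are discarded.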
    fn generate_dependencies(&self, input: &Self::Record, output: &mut Self::Record) {
        let chunk_size =
            std::cmp::max((input.add_events.len() + input.sub_events.len()) / num_cpus::get(), 1);

        let event_iter =
            input.add_events.chunks(chunk_size).chain(input.sub_events.chunks(chunk_size));

        let blu_batches = event_iter
            .par_bridge()
            .map(|events| {
                let mut blu: HashMap<ByteLookupEvent, usize> = HashMap::new();
                events.iter().for_each(|event| {
                    let mut row = [F::zero(); NUM_ADD_SUB_COLS];
                    let cols: &mut AddSubCols<F> = row.as_mut_slice().borrow_mut();
                    self.event_to_row(event, cols, &mut blu);
                });
                blu
            })
            .collect::<Vec<_>>();

        output.add_byte_lookup_events_from_maps(blu_batches.iter().collect_vec());
    }

    fn included(&self, shard: &Self::Record) -> bool {
        if let Some(shape) = shard.shape.as_ref() {
            shape.included::<F, _>(self)
        } else {
            !shard.add_events.is_empty()
        }
    }

    fn local_only(&self) -> bool {
        true
    }
}

impl AddSubChip {
    /// Create a row from an event.
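    ///
    /// For ADD the inputs are (`b`, `c`); for SUB they are (`a`, `c`), so the shared
    /// `AddOperation` always proves `operand_1 + operand_2`. The `blu` record collects the byte
    /// lookup events emitted by `AddOperation::populate`.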
    fn event_to_row<F: PrimeField>(
        &self,
        event: &AluEvent,
        cols: &mut AddSubCols<F>,
        blu: &mut impl ByteRecord,
    ) {
        cols.pc = F::from_canonical_u32(event.pc);

        let is_add = event.opcode == Opcode::ADD;
        cols.is_add = F::from_bool(is_add);
        cols.is_sub = F::from_bool(!is_add);

        let operand_1 = if is_add { event.b } else { event.a };
        let operand_2 = event.c;

        cols.add_operation.populate(blu, operand_1, operand_2);
        cols.operand_1 = Word::from(operand_1);
        cols.operand_2 = Word::from(operand_2);
        cols.op_a_not_0 = F::from_bool(!event.op_a_0);
    }
}

impl<F> BaseAir<F> for AddSubChip {
    fn width(&self) -> usize {
        NUM_ADD_SUB_COLS
    }
}

impl<AB> Air<AB> for AddSubChip
where
    AB: SP1AirBuilder,
{
    fn eval(&self, builder: &mut AB) {
        let main = builder.main();
        let local = main.row_slice(0);
        let local: &AddSubCols<AB::Var> = (*local).borrow();

        // SAFETY: All selectors `is_add` and `is_sub` are checked to be boolean.
        // Each "real" row has exactly one selector turned on, as `is_real = is_add + is_sub` is
        // boolean. Therefore, the `opcode` matches the corresponding opcode of the
        // instruction.
        let is_real = local.is_add + local.is_sub;
        builder.assert_bool(local.is_add);
        builder.assert_bool(local.is_sub);
        builder.assert_bool(is_real.clone());

        let opcode = AB::Expr::from_f(Opcode::ADD.as_field()) * local.is_add +
            AB::Expr::from_f(Opcode::SUB.as_field()) * local.is_sub;

        // Evaluate the addition operation.
        // This is enforced only when `op_a_not_0 == 1`.
        // `op_a_val` doesn't need to be constrained when `op_a_not_0 == 0`.
        AddOperation::<AB::F>::eval(
            builder,
            local.operand_1,
            local.operand_2,
            local.add_operation,
            local.op_a_not_0.into(),
        );

        // SAFETY: We check that a padding row has `op_a_not_0 == 0`, to prevent a padding row
        // sending byte lookups.
        builder.when(local.op_a_not_0).assert_one(is_real.clone());

        // Receive the arguments.  There are separate receives for ADD and SUB.
        // For add, `add_operation.value` is `a`, `operand_1` is `b`, and `operand_2` is `c`.
        // SAFETY: This checks the following. Note that in this case `opcode = Opcode::ADD`
        // - `next_pc = pc + 4`
        // - `num_extra_cycles = 0`
        // - `op_a_val` is constrained by the `AddOperation` when `op_a_not_0 == 1`
        // - `op_a_not_0` is correct, due to the sent `op_a_0` being equal to `1 - op_a_not_0`
        // - `op_a_immutable = 0`
        // - `is_memory = 0`
        // - `is_syscall = 0`
        // - `is_halt = 0`
        builder.receive_instruction(
            AB::Expr::zero(),
            AB::Expr::zero(),
            local.pc,
            local.pc + AB::Expr::from_canonical_u32(DEFAULT_PC_INC),
            AB::Expr::zero(),
            opcode.clone(),
            local.add_operation.value,
            local.operand_1,
            local.operand_2,
            AB::Expr::one() - local.op_a_not_0,
            AB::Expr::zero(),
            AB::Expr::zero(),
            AB::Expr::zero(),
            AB::Expr::zero(),
            local.is_add,
        );

        // For sub, `operand_1` is `a`, `add_operation.value` is `b`, and `operand_2` is `c`.
        // SAFETY: This checks the following. Note that in this case `opcode = Opcode::SUB`
        // - `next_pc = pc + 4`
        // - `num_extra_cycles = 0`
        // - `op_a_val` is constrained by the `AddOperation` when `op_a_not_0 == 1`
        // - `op_a_not_0` is correct, due to the sent `op_a_0` being equal to `1 - op_a_not_0`
        // - `op_a_immutable = 0`
        // - `is_memory = 0`
        // - `is_syscall = 0`
        // - `is_halt = 0`
        builder.receive_instruction(
            AB::Expr::zero(),
            AB::Expr::zero(),
            local.pc,
            local.pc + AB::Expr::from_canonical_u32(DEFAULT_PC_INC),
            AB::Expr::zero(),
            opcode,
            local.operand_1,
            local.add_operation.value,
            local.operand_2,
            AB::Expr::one() - local.op_a_not_0,
            AB::Expr::zero(),
            AB::Expr::zero(),
            AB::Expr::zero(),
            AB::Expr::zero(),
            local.is_sub,
        );
    }
}

#[cfg(test)]
mod tests {
    #![allow(clippy::print_stdout)]

    use p3_baby_bear::BabyBear;
    use p3_matrix::dense::RowMajorMatrix;
    use rand::{thread_rng, Rng};
    use sp1_core_executor::{
        events::{AluEvent, MemoryRecordEnum},
        ExecutionRecord, Instruction, Opcode, DEFAULT_PC_INC,
    };
    use sp1_stark::{
        air::MachineAir, baby_bear_poseidon2::BabyBearPoseidon2, chip_name, CpuProver,
        MachineProver, StarkGenericConfig, Val,
    };
    use std::sync::LazyLock;

    use super::*;
    use crate::{
        io::SP1Stdin,
        riscv::RiscvAir,
        utils::{run_malicious_test, uni_stark_prove as prove, uni_stark_verify as verify},
    };

    /// Lazily initialized record for use across multiple tests.
    /// Consists of `ADD` events; the random `SUB` events generated below are currently unused
    /// and are not included in the record.
    static SHARD: LazyLock<ExecutionRecord> = LazyLock::new(|| {
        let add_events = (0..1)
            .flat_map(|i| {
                [{
                    let operand_1 = 1u32;
                    let operand_2 = 2u32;
                    let result = operand_1.wrapping_add(operand_2);
                    AluEvent::new(i % 2, Opcode::ADD, result, operand_1, operand_2, false)
                }]
            })
            .collect::<Vec<_>>();
        let _sub_events = (0..255)
            .flat_map(|i| {
                [{
                    let operand_1 = thread_rng().gen_range(0..u32::MAX);
                    let operand_2 = thread_rng().gen_range(0..u32::MAX);
                    let result = operand_1.wrapping_add(operand_2);
                    AluEvent::new(i % 2, Opcode::SUB, result, operand_1, operand_2, false)
                }]
            })
            .collect::<Vec<_>>();
        ExecutionRecord { add_events, ..Default::default() }
    });

    #[test]
    fn generate_trace() {
        let mut shard = ExecutionRecord::default();
        shard.add_events = vec![AluEvent::new(0, Opcode::ADD, 14, 8, 6, false)];
        let chip = AddSubChip::default();
        let trace: RowMajorMatrix<BabyBear> =
            chip.generate_trace(&shard, &mut ExecutionRecord::default());
        println!("{:?}", trace.values)
    }

    #[test]
    fn prove_babybear() {
        let config = BabyBearPoseidon2::new();
        let mut challenger = config.challenger();

        let mut shard = ExecutionRecord::default();
        for i in 0..1 {
            let operand_1 = thread_rng().gen_range(0..u32::MAX);
            let operand_2 = thread_rng().gen_range(0..u32::MAX);
            let result = operand_1.wrapping_add(operand_2);
            shard.add_events.push(AluEvent::new(
                i * DEFAULT_PC_INC,
                Opcode::ADD,
                result,
                operand_1,
                operand_2,
                false,
            ));
        }
        for i in 0..255 {
            let operand_1 = thread_rng().gen_range(0..u32::MAX);
            let operand_2 = thread_rng().gen_range(0..u32::MAX);
            let result = operand_1.wrapping_sub(operand_2);
            shard.add_events.push(AluEvent::new(
                i * DEFAULT_PC_INC,
                Opcode::SUB,
                result,
                operand_1,
                operand_2,
                false,
            ));
        }

        let chip = AddSubChip::default();
        let trace: RowMajorMatrix<BabyBear> =
            chip.generate_trace(&shard, &mut ExecutionRecord::default());
        let proof = prove::<BabyBearPoseidon2, _>(&config, &chip, &mut challenger, trace);

        let mut challenger = config.challenger();
        verify(&config, &chip, &mut challenger, &proof).unwrap();
    }

    #[cfg(feature = "sys")]
    #[test]
    fn test_generate_trace_ffi_eq_rust() {
        let shard = LazyLock::force(&SHARD);

        let chip = AddSubChip::default();
        let trace: RowMajorMatrix<BabyBear> =
            chip.generate_trace(shard, &mut ExecutionRecord::default());
        let trace_ffi = generate_trace_ffi(shard);

        assert_eq!(trace_ffi, trace);
    }

    #[cfg(feature = "sys")]
    fn generate_trace_ffi(input: &ExecutionRecord) -> RowMajorMatrix<BabyBear> {
        use rayon::slice::ParallelSlice;

        use crate::utils::pad_rows_fixed;

        type F = BabyBear;

        let chunk_size =
            std::cmp::max((input.add_events.len() + input.sub_events.len()) / num_cpus::get(), 1);

        let events = input.add_events.iter().chain(input.sub_events.iter()).collect::<Vec<_>>();
        let row_batches = events
            .par_chunks(chunk_size)
            .map(|events| {
                let rows = events
                    .iter()
                    .map(|event| {
                        let mut row = [F::zero(); NUM_ADD_SUB_COLS];
                        let cols: &mut AddSubCols<F> = row.as_mut_slice().borrow_mut();
                        unsafe {
                            crate::sys::add_sub_event_to_row_babybear(event, cols);
                        }
                        row
                    })
                    .collect::<Vec<_>>();
                rows
            })
            .collect::<Vec<_>>();

        let mut rows: Vec<[F; NUM_ADD_SUB_COLS]> = vec![];
        for row_batch in row_batches {
            rows.extend(row_batch);
        }

        pad_rows_fixed(&mut rows, || [F::zero(); NUM_ADD_SUB_COLS], None);

        // Convert the trace to a row major matrix.
        RowMajorMatrix::new(rows.into_iter().flatten().collect::<Vec<_>>(), NUM_ADD_SUB_COLS)
    }

    #[test]
    fn test_malicious_add_sub() {
        const NUM_TESTS: usize = 5;

        for opcode in [Opcode::ADD, Opcode::SUB] {
            for _ in 0..NUM_TESTS {
                let op_a = thread_rng().gen_range(0..u32::MAX);
                let op_b = thread_rng().gen_range(0..u32::MAX);
                let op_c = thread_rng().gen_range(0..u32::MAX);

                let correct_op_a = if opcode == Opcode::ADD {
                    op_b.wrapping_add(op_c)
                } else {
                    op_b.wrapping_sub(op_c)
                };

                assert!(op_a != correct_op_a);

                let instructions = vec![
                    Instruction::new(opcode, 5, op_b, op_c, true, true),
                    Instruction::new(Opcode::ADD, 10, 0, 0, false, false),
                ];
                let program = Program::new(instructions, 0, 0);
                let stdin = SP1Stdin::new();

                type P = CpuProver<BabyBearPoseidon2, RiscvAir<BabyBear>>;

                let malicious_trace_pv_generator = move |prover: &P,
                                                         record: &mut ExecutionRecord|
                      -> Vec<(
                    String,
                    RowMajorMatrix<Val<BabyBearPoseidon2>>,
                )> {
                    let mut malicious_record = record.clone();
                    malicious_record.cpu_events[0].a = op_a;
                    if let Some(MemoryRecordEnum::Write(mut write_record)) =
                        malicious_record.cpu_events[0].a_record
                    {
                        write_record.value = op_a;
                    }
                    if opcode == Opcode::ADD {
                        malicious_record.add_events[0].a = op_a;
                    } else if opcode == Opcode::SUB {
                        malicious_record.sub_events[0].a = op_a;
                    } else {
                        unreachable!()
                    }

                    let mut traces = prover.generate_traces(&malicious_record);

                    let add_sub_chip_name = chip_name!(AddSubChip, BabyBear);
                    for (chip_name, trace) in traces.iter_mut() {
                        if *chip_name == add_sub_chip_name {
                            // The add instructions are added to the trace before the sub
                            // instructions.
                            let index = if opcode == Opcode::ADD { 0 } else { 1 };

                            let first_row = trace.row_mut(index);
                            let first_row: &mut AddSubCols<BabyBear> = first_row.borrow_mut();
                            if opcode == Opcode::ADD {
                                first_row.add_operation.value = op_a.into();
                            } else {
                                first_row.add_operation.value = op_b.into();
                            }
                        }
                    }

                    traces
                };

                let result =
                    run_malicious_test::<P>(program, stdin, Box::new(malicious_trace_pv_generator));
                println!("Result for {opcode:?}: {result:?}");
                let add_sub_chip_name = chip_name!(AddSubChip, BabyBear);
                assert!(
                    result.is_err() &&
                        result.unwrap_err().is_constraints_failing(&add_sub_chip_name)
                );
            }
        }
    }
}