1use core::{
2 borrow::{Borrow, BorrowMut},
3 mem::size_of,
4};
5
6use hashbrown::HashMap;
7use itertools::Itertools;
8use p3_air::{Air, AirBuilder, BaseAir};
9use p3_field::{AbstractField, PrimeField, PrimeField32};
10use p3_matrix::{dense::RowMajorMatrix, Matrix};
11use p3_maybe_rayon::prelude::{ParallelBridge, ParallelIterator};
12use sp1_core_executor::{
13 events::{AluEvent, ByteLookupEvent, ByteRecord},
14 ExecutionRecord, Opcode, Program, DEFAULT_PC_INC,
15};
16use sp1_derive::AlignedBorrow;
17use sp1_stark::{
18 air::{MachineAir, SP1AirBuilder},
19 Word,
20};
21
22use crate::{
23 operations::AddOperation,
24 utils::{next_power_of_two, zeroed_f_vec},
25};
26
27pub const NUM_ADD_SUB_COLS: usize = size_of::<AddSubCols<u8>>();
29
/// A chip that proves the `ADD` and `SUB` ALU instructions.
///
/// Each real row of its trace corresponds to one event. A `SUB` event `a = b - c` is proved as
/// the addition `a + c = b`, so both opcodes share the same addition gadget.
#[derive(Default)]
pub struct AddSubChip;
38
/// The main-trace columns of [`AddSubChip`], one row per `ADD`/`SUB` event.
///
/// A row slice of field elements is borrowed directly as this struct (`#[repr(C)]` together with
/// `AlignedBorrow`). The field order below therefore *is* the column order and must not change.
#[derive(AlignedBorrow, Default, Clone, Copy)]
#[repr(C)]
pub struct AddSubCols<T> {
    /// The program counter of the instruction.
    pub pc: T,

    /// Addition gadget over `operand_1` and `operand_2`. Its `value` is `op_a` for `ADD` and
    /// `op_b` (the minuend) for `SUB`.
    pub add_operation: AddOperation<T>,

    /// The first addend: `op_b` for `ADD`, `op_a` (the difference) for `SUB`.
    pub operand_1: Word<T>,

    /// The second addend: `op_c` for both opcodes.
    pub operand_2: Word<T>,

    /// Populated as `!event.op_a_0`. Presumably one when the destination register is not `x0`
    /// (TODO confirm). The AIR forces it to zero on padding rows.
    pub op_a_not_0: T,

    /// Boolean selector for an `ADD` row. Zero on padding rows.
    pub is_add: T,

    /// Boolean selector for a `SUB` row. Zero on padding rows.
    pub is_sub: T,
}
65
66impl<F: PrimeField32> MachineAir<F> for AddSubChip {
67 type Record = ExecutionRecord;
68
69 type Program = Program;
70
71 fn name(&self) -> String {
72 "AddSub".to_string()
73 }
74
75 fn num_rows(&self, input: &Self::Record) -> Option<usize> {
76 let nb_rows = next_power_of_two(
77 input.add_events.len() + input.sub_events.len(),
78 input.fixed_log2_rows::<F, _>(self),
79 );
80 Some(nb_rows)
81 }
82
83 fn generate_trace(
84 &self,
85 input: &ExecutionRecord,
86 _: &mut ExecutionRecord,
87 ) -> RowMajorMatrix<F> {
88 let chunk_size =
90 std::cmp::max((input.add_events.len() + input.sub_events.len()) / num_cpus::get(), 1);
91 let merged_events =
92 input.add_events.iter().chain(input.sub_events.iter()).collect::<Vec<_>>();
93 let padded_nb_rows = <AddSubChip as MachineAir<F>>::num_rows(self, input).unwrap();
94 let mut values = zeroed_f_vec(padded_nb_rows * NUM_ADD_SUB_COLS);
95
96 values.chunks_mut(chunk_size * NUM_ADD_SUB_COLS).enumerate().par_bridge().for_each(
97 |(i, rows)| {
98 rows.chunks_mut(NUM_ADD_SUB_COLS).enumerate().for_each(|(j, row)| {
99 let idx = i * chunk_size + j;
100 let cols: &mut AddSubCols<F> = row.borrow_mut();
101
102 if idx < merged_events.len() {
103 let mut byte_lookup_events = Vec::new();
104 let event = &merged_events[idx];
105 self.event_to_row(event, cols, &mut byte_lookup_events);
106 }
107 });
108 },
109 );
110
111 RowMajorMatrix::new(values, NUM_ADD_SUB_COLS)
113 }
114
115 fn generate_dependencies(&self, input: &Self::Record, output: &mut Self::Record) {
116 let chunk_size =
117 std::cmp::max((input.add_events.len() + input.sub_events.len()) / num_cpus::get(), 1);
118
119 let event_iter =
120 input.add_events.chunks(chunk_size).chain(input.sub_events.chunks(chunk_size));
121
122 let blu_batches = event_iter
123 .par_bridge()
124 .map(|events| {
125 let mut blu: HashMap<ByteLookupEvent, usize> = HashMap::new();
126 events.iter().for_each(|event| {
127 let mut row = [F::zero(); NUM_ADD_SUB_COLS];
128 let cols: &mut AddSubCols<F> = row.as_mut_slice().borrow_mut();
129 self.event_to_row(event, cols, &mut blu);
130 });
131 blu
132 })
133 .collect::<Vec<_>>();
134
135 output.add_byte_lookup_events_from_maps(blu_batches.iter().collect_vec());
136 }
137
138 fn included(&self, shard: &Self::Record) -> bool {
139 if let Some(shape) = shard.shape.as_ref() {
140 shape.included::<F, _>(self)
141 } else {
142 !shard.add_events.is_empty()
143 }
144 }
145
146 fn local_only(&self) -> bool {
147 true
148 }
149}
150
151impl AddSubChip {
152 fn event_to_row<F: PrimeField>(
154 &self,
155 event: &AluEvent,
156 cols: &mut AddSubCols<F>,
157 blu: &mut impl ByteRecord,
158 ) {
159 cols.pc = F::from_canonical_u32(event.pc);
160
161 let is_add = event.opcode == Opcode::ADD;
162 cols.is_add = F::from_bool(is_add);
163 cols.is_sub = F::from_bool(!is_add);
164
165 let operand_1 = if is_add { event.b } else { event.a };
166 let operand_2 = event.c;
167
168 cols.add_operation.populate(blu, operand_1, operand_2);
169 cols.operand_1 = Word::from(operand_1);
170 cols.operand_2 = Word::from(operand_2);
171 cols.op_a_not_0 = F::from_bool(!event.op_a_0);
172 }
173}
174
175impl<F> BaseAir<F> for AddSubChip {
176 fn width(&self) -> usize {
177 NUM_ADD_SUB_COLS
178 }
179}
180
/// Constraints for [`AddSubChip`].
///
/// NOTE(review): the order and exact form of the constraints and interactions below are very
/// likely baked into compiled verifier programs and verifying keys. Keep them stable.
impl<AB> Air<AB> for AddSubChip
where
    AB: SP1AirBuilder,
{
    fn eval(&self, builder: &mut AB) {
        let main = builder.main();
        let local = main.row_slice(0);
        let local: &AddSubCols<AB::Var> = (*local).borrow();

        // A row is real iff exactly one selector is set. The three boolean checks together forbid
        // `is_add = is_sub = 1`. Padding rows have both selectors zero.
        let is_real = local.is_add + local.is_sub;
        builder.assert_bool(local.is_add);
        builder.assert_bool(local.is_sub);
        builder.assert_bool(is_real.clone());

        // The opcode field element, selected by whichever selector is set.
        let opcode = AB::Expr::from_f(Opcode::ADD.as_field()) * local.is_add +
            AB::Expr::from_f(Opcode::SUB.as_field()) * local.is_sub;

        // Constrains the addition gadget over `operand_1` and `operand_2` (presumably
        // `add_operation.value = operand_1 + operand_2` mod 2^32). `op_a_not_0` is passed as the
        // gadget's last argument — TODO confirm that it acts as the enable flag.
        AddOperation::<AB::F>::eval(
            builder,
            local.operand_1,
            local.operand_2,
            local.add_operation,
            local.op_a_not_0.into(),
        );

        // `op_a_not_0` may only be set on real rows.
        builder.when(local.op_a_not_0).assert_one(is_real.clone());

        // ADD rows: receive the instruction with (a, b, c) = (sum, operand_1, operand_2).
        // The arguments at `pc` and `pc + DEFAULT_PC_INC` presumably are the current and next
        // program counter. `1 - op_a_not_0` is the `op_a_0` flag. The zero arguments are fields
        // this ALU does not use, and the last argument (`is_add`) appears to be the multiplicity.
        // Confirm against `receive_instruction`'s signature.
        builder.receive_instruction(
            AB::Expr::zero(),
            AB::Expr::zero(),
            local.pc,
            local.pc + AB::Expr::from_canonical_u32(DEFAULT_PC_INC),
            AB::Expr::zero(),
            opcode.clone(),
            local.add_operation.value,
            local.operand_1,
            local.operand_2,
            AB::Expr::one() - local.op_a_not_0,
            AB::Expr::zero(),
            AB::Expr::zero(),
            AB::Expr::zero(),
            AB::Expr::zero(),
            local.is_add,
        );

        // SUB rows: the row proves `a + c = b`, so (a, b, c) = (operand_1, sum, operand_2).
        // The multiplicity here is `is_sub`.
        builder.receive_instruction(
            AB::Expr::zero(),
            AB::Expr::zero(),
            local.pc,
            local.pc + AB::Expr::from_canonical_u32(DEFAULT_PC_INC),
            AB::Expr::zero(),
            opcode,
            local.operand_1,
            local.add_operation.value,
            local.operand_2,
            AB::Expr::one() - local.op_a_not_0,
            AB::Expr::zero(),
            AB::Expr::zero(),
            AB::Expr::zero(),
            AB::Expr::zero(),
            local.is_sub,
        );
    }
}
275
#[cfg(test)]
mod tests {
    #![allow(clippy::print_stdout)]

    use p3_baby_bear::BabyBear;
    use p3_matrix::dense::RowMajorMatrix;
    use rand::{thread_rng, Rng};
    use sp1_core_executor::{
        events::{AluEvent, MemoryRecordEnum},
        ExecutionRecord, Instruction, Opcode, DEFAULT_PC_INC,
    };
    use sp1_stark::{
        air::MachineAir, baby_bear_poseidon2::BabyBearPoseidon2, chip_name, CpuProver,
        MachineProver, StarkGenericConfig, Val,
    };
    use std::sync::LazyLock;

    use super::*;
    use crate::{
        io::SP1Stdin,
        riscv::RiscvAir,
        utils::{run_malicious_test, uni_stark_prove as prove, uni_stark_verify as verify},
    };

    /// A shard with one `ADD` event and 255 random `SUB` events.
    ///
    /// The FFI comparison test uses it, so both opcodes are exercised.
    static SHARD: LazyLock<ExecutionRecord> = LazyLock::new(|| {
        let add_events = (0..1)
            .map(|i| {
                let operand_1 = 1u32;
                let operand_2 = 2u32;
                let result = operand_1.wrapping_add(operand_2);
                AluEvent::new(i % 2, Opcode::ADD, result, operand_1, operand_2, false)
            })
            .collect::<Vec<_>>();
        let sub_events = (0..255)
            .map(|i| {
                let operand_1 = thread_rng().gen_range(0..u32::MAX);
                let operand_2 = thread_rng().gen_range(0..u32::MAX);
                // SUB events carry `a = b - c`.
                let result = operand_1.wrapping_sub(operand_2);
                AluEvent::new(i % 2, Opcode::SUB, result, operand_1, operand_2, false)
            })
            .collect::<Vec<_>>();
        ExecutionRecord { add_events, sub_events, ..Default::default() }
    });

    /// Smoke test: generating a trace for a single `ADD` event does not panic.
    #[test]
    fn generate_trace() {
        let mut shard = ExecutionRecord::default();
        shard.add_events = vec![AluEvent::new(0, Opcode::ADD, 14, 8, 6, false)];
        let chip = AddSubChip::default();
        let trace: RowMajorMatrix<BabyBear> =
            chip.generate_trace(&shard, &mut ExecutionRecord::default());
        println!("{:?}", trace.values)
    }

    /// Proves and verifies a trace with one `ADD` event and 255 `SUB` events.
    #[test]
    fn prove_babybear() {
        let config = BabyBearPoseidon2::new();
        let mut challenger = config.challenger();

        let mut shard = ExecutionRecord::default();
        for i in 0..1 {
            let operand_1 = thread_rng().gen_range(0..u32::MAX);
            let operand_2 = thread_rng().gen_range(0..u32::MAX);
            let result = operand_1.wrapping_add(operand_2);
            shard.add_events.push(AluEvent::new(
                i * DEFAULT_PC_INC,
                Opcode::ADD,
                result,
                operand_1,
                operand_2,
                false,
            ));
        }
        for i in 0..255 {
            let operand_1 = thread_rng().gen_range(0..u32::MAX);
            let operand_2 = thread_rng().gen_range(0..u32::MAX);
            let result = operand_1.wrapping_sub(operand_2);
            // SUB events live in `sub_events`, matching the executor and the malicious test.
            shard.sub_events.push(AluEvent::new(
                i * DEFAULT_PC_INC,
                Opcode::SUB,
                result,
                operand_1,
                operand_2,
                false,
            ));
        }

        let chip = AddSubChip::default();
        let trace: RowMajorMatrix<BabyBear> =
            chip.generate_trace(&shard, &mut ExecutionRecord::default());
        let proof = prove::<BabyBearPoseidon2, _>(&config, &chip, &mut challenger, trace);

        let mut challenger = config.challenger();
        verify(&config, &chip, &mut challenger, &proof).unwrap();
    }

    /// Checks that the Rust trace generator and the FFI trace generator agree.
    #[cfg(feature = "sys")]
    #[test]
    fn test_generate_trace_ffi_eq_rust() {
        let shard = LazyLock::force(&SHARD);

        let chip = AddSubChip::default();
        let trace: RowMajorMatrix<BabyBear> =
            chip.generate_trace(shard, &mut ExecutionRecord::default());
        let trace_ffi = generate_trace_ffi(shard);

        assert_eq!(trace_ffi, trace);
    }

    /// Builds the `AddSub` trace through the FFI row populator.
    ///
    /// Events use the same ordering as `generate_trace` (`ADD` then `SUB`), and rows are padded
    /// with zeros.
    #[cfg(feature = "sys")]
    fn generate_trace_ffi(input: &ExecutionRecord) -> RowMajorMatrix<BabyBear> {
        use rayon::slice::ParallelSlice;

        use crate::utils::pad_rows_fixed;

        type F = BabyBear;

        let chunk_size =
            std::cmp::max((input.add_events.len() + input.sub_events.len()) / num_cpus::get(), 1);

        let events = input.add_events.iter().chain(input.sub_events.iter()).collect::<Vec<_>>();
        let row_batches = events
            .par_chunks(chunk_size)
            .map(|events| {
                events
                    .iter()
                    .map(|event| {
                        let mut row = [F::zero(); NUM_ADD_SUB_COLS];
                        let cols: &mut AddSubCols<F> = row.as_mut_slice().borrow_mut();
                        unsafe {
                            crate::sys::add_sub_event_to_row_babybear(event, cols);
                        }
                        row
                    })
                    .collect::<Vec<_>>()
            })
            .collect::<Vec<_>>();

        let mut rows: Vec<[F; NUM_ADD_SUB_COLS]> = vec![];
        for row_batch in row_batches {
            rows.extend(row_batch);
        }

        pad_rows_fixed(&mut rows, || [F::zero(); NUM_ADD_SUB_COLS], None);

        RowMajorMatrix::new(rows.into_iter().flatten().collect::<Vec<_>>(), NUM_ADD_SUB_COLS)
    }

    /// Checks that a wrong `ADD`/`SUB` result is caught by the `AddSub` chip's constraints.
    #[test]
    fn test_malicious_add_sub() {
        const NUM_TESTS: usize = 5;

        for opcode in [Opcode::ADD, Opcode::SUB] {
            for _ in 0..NUM_TESTS {
                let op_a = thread_rng().gen_range(0..u32::MAX);
                let op_b = thread_rng().gen_range(0..u32::MAX);
                let op_c = thread_rng().gen_range(0..u32::MAX);

                let correct_op_a = if opcode == Opcode::ADD {
                    op_b.wrapping_add(op_c)
                } else {
                    op_b.wrapping_sub(op_c)
                };

                assert!(op_a != correct_op_a);

                let instructions = vec![
                    Instruction::new(opcode, 5, op_b, op_c, true, true),
                    Instruction::new(Opcode::ADD, 10, 0, 0, false, false),
                ];
                let program = Program::new(instructions, 0, 0);
                let stdin = SP1Stdin::new();

                type P = CpuProver<BabyBearPoseidon2, RiscvAir<BabyBear>>;

                let malicious_trace_pv_generator = move |prover: &P,
                                                         record: &mut ExecutionRecord|
                      -> Vec<(
                    String,
                    RowMajorMatrix<Val<BabyBearPoseidon2>>,
                )> {
                    let mut malicious_record = record.clone();
                    malicious_record.cpu_events[0].a = op_a;
                    // Mutate the stored write record in place. Binding `mut write_record` by
                    // value would only modify a copy and leave the record unchanged.
                    if let Some(MemoryRecordEnum::Write(write_record)) =
                        malicious_record.cpu_events[0].a_record.as_mut()
                    {
                        write_record.value = op_a;
                    }
                    if opcode == Opcode::ADD {
                        malicious_record.add_events[0].a = op_a;
                    } else if opcode == Opcode::SUB {
                        malicious_record.sub_events[0].a = op_a;
                    } else {
                        unreachable!()
                    }

                    let mut traces = prover.generate_traces(&malicious_record);

                    let add_sub_chip_name = chip_name!(AddSubChip, BabyBear);
                    for (chip_name, trace) in traces.iter_mut() {
                        if *chip_name == add_sub_chip_name {
                            // ADD events come first in the trace. For the SUB program the single
                            // ADD (`x10`) occupies row 0, so the SUB event is in row 1.
                            let index = if opcode == Opcode::ADD { 0 } else { 1 };

                            let row = trace.row_mut(index);
                            let row: &mut AddSubCols<BabyBear> = row.borrow_mut();
                            if opcode == Opcode::ADD {
                                row.add_operation.value = op_a.into();
                            } else {
                                row.add_operation.value = op_b.into();
                            }
                        }
                    }

                    traces
                };

                let result =
                    run_malicious_test::<P>(program, stdin, Box::new(malicious_trace_pv_generator));
                println!("Result for {opcode:?}: {result:?}");
                let add_sub_chip_name = chip_name!(AddSubChip, BabyBear);
                assert!(
                    result.is_err() &&
                        result.unwrap_err().is_constraints_failing(&add_sub_chip_name)
                );
            }
        }
    }
}