sp1_core_machine/alu/bitwise/
mod.rs1use core::{
2 borrow::{Borrow, BorrowMut},
3 mem::size_of,
4};
5
6use hashbrown::HashMap;
7use itertools::Itertools;
8use p3_air::{Air, AirBuilder, BaseAir};
9use p3_field::{AbstractField, PrimeField, PrimeField32};
10use p3_matrix::{dense::RowMajorMatrix, Matrix};
11use p3_maybe_rayon::prelude::{IntoParallelRefIterator, ParallelIterator, ParallelSlice};
12use sp1_core_executor::{
13 events::{AluEvent, ByteLookupEvent, ByteRecord},
14 ByteOpcode, ExecutionRecord, Opcode, Program, DEFAULT_PC_INC,
15};
16use sp1_derive::AlignedBorrow;
17use sp1_stark::{
18 air::{MachineAir, SP1AirBuilder},
19 Word,
20};
21
22use crate::utils::pad_rows_fixed;
23
24pub const NUM_BITWISE_COLS: usize = size_of::<BitwiseCols<u8>>();
26
/// Chip that proves the RISC-V bitwise ALU instructions `XOR`, `OR` and `AND`.
///
/// The chip itself carries no state; all per-row data lives in [`BitwiseCols`]. Being a
/// zero-sized marker, it is cheap to copy and is debuggable like any other public type.
#[derive(Debug, Default, Clone, Copy)]
pub struct BitwiseChip;
30
31#[derive(AlignedBorrow, Default, Clone, Copy)]
33#[repr(C)]
34pub struct BitwiseCols<T> {
35 pub pc: T,
37
38 pub a: Word<T>,
40
41 pub b: Word<T>,
43
44 pub c: Word<T>,
46
47 pub op_a_not_0: T,
49
50 pub is_xor: T,
52
53 pub is_or: T,
55
56 pub is_and: T,
58}
59
60impl<F: PrimeField32> MachineAir<F> for BitwiseChip {
61 type Record = ExecutionRecord;
62
63 type Program = Program;
64
65 fn name(&self) -> String {
66 "Bitwise".to_string()
67 }
68
69 fn generate_trace(
70 &self,
71 input: &ExecutionRecord,
72 _: &mut ExecutionRecord,
73 ) -> RowMajorMatrix<F> {
74 let mut rows = input
75 .bitwise_events
76 .par_iter()
77 .map(|event| {
78 let mut row = [F::zero(); NUM_BITWISE_COLS];
79 let cols: &mut BitwiseCols<F> = row.as_mut_slice().borrow_mut();
80 let mut blu = Vec::new();
81 self.event_to_row(event, cols, &mut blu);
82 row
83 })
84 .collect::<Vec<_>>();
85
86 pad_rows_fixed(
88 &mut rows,
89 || [F::zero(); NUM_BITWISE_COLS],
90 input.fixed_log2_rows::<F, _>(self),
91 );
92
93 RowMajorMatrix::new(rows.into_iter().flatten().collect::<Vec<_>>(), NUM_BITWISE_COLS)
95 }
96
97 fn generate_dependencies(&self, input: &Self::Record, output: &mut Self::Record) {
98 let chunk_size = std::cmp::max(input.bitwise_events.len() / num_cpus::get(), 1);
99
100 let blu_batches = input
101 .bitwise_events
102 .par_chunks(chunk_size)
103 .map(|events| {
104 let mut blu: HashMap<ByteLookupEvent, usize> = HashMap::new();
105 events.iter().for_each(|event| {
106 let mut row = [F::zero(); NUM_BITWISE_COLS];
107 let cols: &mut BitwiseCols<F> = row.as_mut_slice().borrow_mut();
108 self.event_to_row(event, cols, &mut blu);
109 });
110 blu
111 })
112 .collect::<Vec<_>>();
113
114 output.add_byte_lookup_events_from_maps(blu_batches.iter().collect_vec());
115 }
116
117 fn included(&self, shard: &Self::Record) -> bool {
118 if let Some(shape) = shard.shape.as_ref() {
119 shape.included::<F, _>(self)
120 } else {
121 !shard.bitwise_events.is_empty()
122 }
123 }
124
125 fn local_only(&self) -> bool {
126 true
127 }
128}
129
130impl BitwiseChip {
131 fn event_to_row<F: PrimeField>(
133 &self,
134 event: &AluEvent,
135 cols: &mut BitwiseCols<F>,
136 blu: &mut impl ByteRecord,
137 ) {
138 cols.pc = F::from_canonical_u32(event.pc);
139
140 let a = event.a.to_le_bytes();
141 let b = event.b.to_le_bytes();
142 let c = event.c.to_le_bytes();
143
144 cols.a = Word::from(event.a);
145 cols.b = Word::from(event.b);
146 cols.c = Word::from(event.c);
147 cols.op_a_not_0 = F::from_bool(!event.op_a_0);
148
149 cols.is_xor = F::from_bool(event.opcode == Opcode::XOR);
150 cols.is_or = F::from_bool(event.opcode == Opcode::OR);
151 cols.is_and = F::from_bool(event.opcode == Opcode::AND);
152
153 if !event.op_a_0 {
154 for ((b_a, b_b), b_c) in a.into_iter().zip(b).zip(c) {
155 let byte_event = ByteLookupEvent {
156 opcode: ByteOpcode::from(event.opcode),
157 a1: b_a as u16,
158 a2: 0,
159 b: b_b,
160 c: b_c,
161 };
162 blu.add_byte_lookup_event(byte_event);
163 }
164 }
165 }
166}
167
168impl<F> BaseAir<F> for BitwiseChip {
169 fn width(&self) -> usize {
170 NUM_BITWISE_COLS
171 }
172}
173
174impl<AB> Air<AB> for BitwiseChip
175where
176 AB: SP1AirBuilder,
177{
178 fn eval(&self, builder: &mut AB) {
179 let main = builder.main();
180 let local = main.row_slice(0);
181 let local: &BitwiseCols<AB::Var> = (*local).borrow();
182
183 let opcode = local.is_xor * ByteOpcode::XOR.as_field::<AB::F>() +
185 local.is_or * ByteOpcode::OR.as_field::<AB::F>() +
186 local.is_and * ByteOpcode::AND.as_field::<AB::F>();
187
188 let mult = local.is_xor + local.is_or + local.is_and;
190 for ((a, b), c) in local.a.into_iter().zip(local.b).zip(local.c) {
191 builder.send_byte(opcode.clone(), a, b, c, local.op_a_not_0);
192 }
193
194 builder.when(local.op_a_not_0).assert_one(mult.clone());
197
198 let cpu_opcode = local.is_xor * Opcode::XOR.as_field::<AB::F>() +
200 local.is_or * Opcode::OR.as_field::<AB::F>() +
201 local.is_and * Opcode::AND.as_field::<AB::F>();
202
203 builder.receive_instruction(
215 AB::Expr::zero(),
216 AB::Expr::zero(),
217 local.pc,
218 local.pc + AB::Expr::from_canonical_u32(DEFAULT_PC_INC),
219 AB::Expr::zero(),
220 cpu_opcode,
221 local.a,
222 local.b,
223 local.c,
224 AB::Expr::one() - local.op_a_not_0,
225 AB::Expr::zero(),
226 AB::Expr::zero(),
227 AB::Expr::zero(),
228 AB::Expr::zero(),
229 local.is_xor + local.is_or + local.is_and,
230 );
231
232 let is_real = local.is_xor + local.is_or + local.is_and;
237 builder.assert_bool(local.is_xor);
238 builder.assert_bool(local.is_or);
239 builder.assert_bool(local.is_and);
240 builder.assert_bool(is_real);
241 }
242}
243
#[cfg(test)]
mod tests {
    #![allow(clippy::print_stdout)]

    use p3_baby_bear::BabyBear;
    use p3_matrix::dense::RowMajorMatrix;
    use rand::{thread_rng, Rng};
    use sp1_core_executor::{
        events::{AluEvent, MemoryRecordEnum},
        ExecutionRecord, Instruction, Opcode, Program,
    };
    use sp1_stark::{
        air::MachineAir, baby_bear_poseidon2::BabyBearPoseidon2, CpuProver, MachineProver,
        StarkGenericConfig, Val,
    };

    use crate::{
        io::SP1Stdin,
        riscv::RiscvAir,
        utils::{run_malicious_test, uni_stark_prove, uni_stark_verify},
    };

    use super::BitwiseChip;

    /// Smoke test: generates (and prints) the trace for a single XOR event.
    #[test]
    fn generate_trace() {
        let mut shard = ExecutionRecord::default();
        shard.bitwise_events = vec![AluEvent::new(0, Opcode::XOR, 25, 10, 19, false)];
        let chip = BitwiseChip::default();
        let trace: RowMajorMatrix<BabyBear> =
            chip.generate_trace(&shard, &mut ExecutionRecord::default());
        println!("{:?}", trace.values)
    }

    /// Proves and verifies a trace made of 1000 copies each of an XOR, OR and AND event.
    #[test]
    fn prove_babybear() {
        let config = BabyBearPoseidon2::new();
        let mut challenger = config.challenger();

        let mut shard = ExecutionRecord::default();
        shard.bitwise_events = [
            AluEvent::new(0, Opcode::XOR, 25, 10, 19, false),
            AluEvent::new(0, Opcode::OR, 27, 10, 19, false),
            AluEvent::new(0, Opcode::AND, 2, 10, 19, false),
        ]
        .repeat(1000);
        let chip = BitwiseChip::default();
        let trace: RowMajorMatrix<BabyBear> =
            chip.generate_trace(&shard, &mut ExecutionRecord::default());
        let proof = uni_stark_prove::<BabyBearPoseidon2, _>(&config, &chip, &mut challenger, trace);

        let mut challenger = config.challenger();
        uni_stark_verify(&config, &chip, &mut challenger, &proof).unwrap();
    }

    /// Forges the result of a bitwise instruction (in the CPU event, its register write record
    /// and the bitwise event) and checks that proving fails with a local cumulative-sum error.
    #[test]
    fn test_malicious_bitwise() {
        const NUM_TESTS: usize = 5;

        for opcode in [Opcode::XOR, Opcode::OR, Opcode::AND] {
            for _ in 0..NUM_TESTS {
                let op_b = thread_rng().gen_range(0..u32::MAX);
                let op_c = thread_rng().gen_range(0..u32::MAX);

                let correct_op_a = match opcode {
                    Opcode::XOR => op_b ^ op_c,
                    Opcode::OR => op_b | op_c,
                    Opcode::AND => op_b & op_c,
                    _ => unreachable!("only bitwise opcodes are exercised"),
                };

                // Redraw until the forged result really differs from the correct one; asserting
                // on a single random draw would make the test (rarely) fail spuriously.
                let op_a = loop {
                    let candidate = thread_rng().gen_range(0..u32::MAX);
                    if candidate != correct_op_a {
                        break candidate;
                    }
                };

                let instructions = vec![
                    Instruction::new(opcode, 5, op_b, op_c, true, true),
                    Instruction::new(Opcode::ADD, 10, 0, 0, false, false),
                ];
                let program = Program::new(instructions, 0, 0);
                let stdin = SP1Stdin::new();

                type P = CpuProver<BabyBearPoseidon2, RiscvAir<BabyBear>>;

                let malicious_trace_pv_generator = move |prover: &P,
                                                         record: &mut ExecutionRecord|
                      -> Vec<(
                    String,
                    RowMajorMatrix<Val<BabyBearPoseidon2>>,
                )> {
                    let mut malicious_record = record.clone();
                    let cpu_event = &mut malicious_record.cpu_events[0];
                    cpu_event.a = op_a;
                    // Bind through `&mut`: the record enum is `Copy`, so a by-value binding
                    // (`Write(mut write_record)`) would only modify a temporary copy.
                    if let Some(MemoryRecordEnum::Write(write_record)) = &mut cpu_event.a_record {
                        write_record.value = op_a;
                    }
                    malicious_record.bitwise_events[0].a = op_a;
                    prover.generate_traces(&malicious_record)
                };

                let result =
                    run_malicious_test::<P>(program, stdin, Box::new(malicious_trace_pv_generator));
                assert!(result.is_err() && result.unwrap_err().is_local_cumulative_sum_failing());
            }
        }
    }
}