sp1_core_machine/alu/bitwise/
mod.rs

1use core::{
2    borrow::{Borrow, BorrowMut},
3    mem::size_of,
4};
5
6use hashbrown::HashMap;
7use itertools::Itertools;
8use p3_air::{Air, AirBuilder, BaseAir};
9use p3_field::{AbstractField, PrimeField, PrimeField32};
10use p3_matrix::{dense::RowMajorMatrix, Matrix};
11use p3_maybe_rayon::prelude::{IntoParallelRefIterator, ParallelIterator, ParallelSlice};
12use sp1_core_executor::{
13    events::{AluEvent, ByteLookupEvent, ByteRecord},
14    ByteOpcode, ExecutionRecord, Opcode, Program, DEFAULT_PC_INC,
15};
16use sp1_derive::AlignedBorrow;
17use sp1_stark::{
18    air::{MachineAir, SP1AirBuilder},
19    Word,
20};
21
22use crate::utils::pad_rows_fixed;
23
/// The number of main trace columns for `BitwiseChip`.
///
/// Computed as the size in bytes of [`BitwiseCols`] instantiated with `u8`, so that each
/// one-byte column contributes exactly one to the total.
pub const NUM_BITWISE_COLS: usize = size_of::<BitwiseCols<u8>>();
26
/// A chip that implements bitwise operations for the opcodes XOR, OR, and AND.
///
/// The chip itself carries no state; the layout of one trace row is described by
/// [`BitwiseCols`].
#[derive(Default)]
pub struct BitwiseChip;
30
/// The column layout for the chip.
///
/// A trace row (a slice of `NUM_BITWISE_COLS` values) is reinterpreted as this struct through
/// the `Borrow`/`BorrowMut` implementations generated by `AlignedBorrow`, so the field order
/// below is the column order of the trace.
#[derive(AlignedBorrow, Default, Clone, Copy)]
#[repr(C)]
pub struct BitwiseCols<T> {
    /// The program counter.
    pub pc: T,

    /// The output operand.
    pub a: Word<T>,

    /// The first input operand.
    pub b: Word<T>,

    /// The second input operand.
    pub c: Word<T>,

    /// Whether the first operand is not register 0.
    pub op_a_not_0: T,

    /// If the opcode is XOR.
    pub is_xor: T,

    /// If the opcode is OR.
    pub is_or: T,

    /// If the opcode is AND.
    pub is_and: T,
}
59
60impl<F: PrimeField32> MachineAir<F> for BitwiseChip {
61    type Record = ExecutionRecord;
62
63    type Program = Program;
64
65    fn name(&self) -> String {
66        "Bitwise".to_string()
67    }
68
69    fn generate_trace(
70        &self,
71        input: &ExecutionRecord,
72        _: &mut ExecutionRecord,
73    ) -> RowMajorMatrix<F> {
74        let mut rows = input
75            .bitwise_events
76            .par_iter()
77            .map(|event| {
78                let mut row = [F::zero(); NUM_BITWISE_COLS];
79                let cols: &mut BitwiseCols<F> = row.as_mut_slice().borrow_mut();
80                let mut blu = Vec::new();
81                self.event_to_row(event, cols, &mut blu);
82                row
83            })
84            .collect::<Vec<_>>();
85
86        // Pad the trace to a power of two.
87        pad_rows_fixed(
88            &mut rows,
89            || [F::zero(); NUM_BITWISE_COLS],
90            input.fixed_log2_rows::<F, _>(self),
91        );
92
93        // Convert the trace to a row major matrix.
94        RowMajorMatrix::new(rows.into_iter().flatten().collect::<Vec<_>>(), NUM_BITWISE_COLS)
95    }
96
97    fn generate_dependencies(&self, input: &Self::Record, output: &mut Self::Record) {
98        let chunk_size = std::cmp::max(input.bitwise_events.len() / num_cpus::get(), 1);
99
100        let blu_batches = input
101            .bitwise_events
102            .par_chunks(chunk_size)
103            .map(|events| {
104                let mut blu: HashMap<ByteLookupEvent, usize> = HashMap::new();
105                events.iter().for_each(|event| {
106                    let mut row = [F::zero(); NUM_BITWISE_COLS];
107                    let cols: &mut BitwiseCols<F> = row.as_mut_slice().borrow_mut();
108                    self.event_to_row(event, cols, &mut blu);
109                });
110                blu
111            })
112            .collect::<Vec<_>>();
113
114        output.add_byte_lookup_events_from_maps(blu_batches.iter().collect_vec());
115    }
116
117    fn included(&self, shard: &Self::Record) -> bool {
118        if let Some(shape) = shard.shape.as_ref() {
119            shape.included::<F, _>(self)
120        } else {
121            !shard.bitwise_events.is_empty()
122        }
123    }
124
125    fn local_only(&self) -> bool {
126        true
127    }
128}
129
130impl BitwiseChip {
131    /// Create a row from an event.
132    fn event_to_row<F: PrimeField>(
133        &self,
134        event: &AluEvent,
135        cols: &mut BitwiseCols<F>,
136        blu: &mut impl ByteRecord,
137    ) {
138        cols.pc = F::from_canonical_u32(event.pc);
139
140        let a = event.a.to_le_bytes();
141        let b = event.b.to_le_bytes();
142        let c = event.c.to_le_bytes();
143
144        cols.a = Word::from(event.a);
145        cols.b = Word::from(event.b);
146        cols.c = Word::from(event.c);
147        cols.op_a_not_0 = F::from_bool(!event.op_a_0);
148
149        cols.is_xor = F::from_bool(event.opcode == Opcode::XOR);
150        cols.is_or = F::from_bool(event.opcode == Opcode::OR);
151        cols.is_and = F::from_bool(event.opcode == Opcode::AND);
152
153        if !event.op_a_0 {
154            for ((b_a, b_b), b_c) in a.into_iter().zip(b).zip(c) {
155                let byte_event = ByteLookupEvent {
156                    opcode: ByteOpcode::from(event.opcode),
157                    a1: b_a as u16,
158                    a2: 0,
159                    b: b_b,
160                    c: b_c,
161                };
162                blu.add_byte_lookup_event(byte_event);
163            }
164        }
165    }
166}
167
impl<F> BaseAir<F> for BitwiseChip {
    /// Returns the width of the main trace, which is `NUM_BITWISE_COLS`.
    fn width(&self) -> usize {
        NUM_BITWISE_COLS
    }
}
173
174impl<AB> Air<AB> for BitwiseChip
175where
176    AB: SP1AirBuilder,
177{
178    fn eval(&self, builder: &mut AB) {
179        let main = builder.main();
180        let local = main.row_slice(0);
181        let local: &BitwiseCols<AB::Var> = (*local).borrow();
182
183        // Get the opcode for the operation.
184        let opcode = local.is_xor * ByteOpcode::XOR.as_field::<AB::F>() +
185            local.is_or * ByteOpcode::OR.as_field::<AB::F>() +
186            local.is_and * ByteOpcode::AND.as_field::<AB::F>();
187
188        // Get a multiplicity of `1` only for a true row.
189        let mult = local.is_xor + local.is_or + local.is_and;
190        for ((a, b), c) in local.a.into_iter().zip(local.b).zip(local.c) {
191            builder.send_byte(opcode.clone(), a, b, c, local.op_a_not_0);
192        }
193
194        // SAFETY: We check that a padding row has `op_a_not_0 == 0`, to prevent a padding row
195        // sending byte lookups.
196        builder.when(local.op_a_not_0).assert_one(mult.clone());
197
198        // Get the cpu opcode, which corresponds to the opcode being sent in the CPU table.
199        let cpu_opcode = local.is_xor * Opcode::XOR.as_field::<AB::F>() +
200            local.is_or * Opcode::OR.as_field::<AB::F>() +
201            local.is_and * Opcode::AND.as_field::<AB::F>();
202
203        // Receive the arguments.
204        // SAFETY: This checks the following.
205        // - `next_pc = pc + 4`
206        // - `num_extra_cycles = 0`
207        // - `op_a_val` is constrained by the byte lookups when `op_a_not_0 == 1`
208        // - `op_a_not_0` is correct, due to the sent `op_a_0` being equal to `1 - op_a_not_0`
209        // - `op_a_immutable = 0`
210        // - `is_memory = 0`
211        // - `is_syscall = 0`
212        // - `is_halt = 0`
213        // Note that `is_xor + is_or + is_and` is checked to be boolean below.
214        builder.receive_instruction(
215            AB::Expr::zero(),
216            AB::Expr::zero(),
217            local.pc,
218            local.pc + AB::Expr::from_canonical_u32(DEFAULT_PC_INC),
219            AB::Expr::zero(),
220            cpu_opcode,
221            local.a,
222            local.b,
223            local.c,
224            AB::Expr::one() - local.op_a_not_0,
225            AB::Expr::zero(),
226            AB::Expr::zero(),
227            AB::Expr::zero(),
228            AB::Expr::zero(),
229            local.is_xor + local.is_or + local.is_and,
230        );
231
232        // SAFETY: All selectors `is_xor`, `is_or`, `is_and` are checked to be boolean.
233        // Each "real" row has exactly one selector turned on, as `is_real`, the sum of the three
234        // selectors, is boolean. Therefore, the `opcode` and `cpu_opcode` matches the
235        // corresponding opcode.
236        let is_real = local.is_xor + local.is_or + local.is_and;
237        builder.assert_bool(local.is_xor);
238        builder.assert_bool(local.is_or);
239        builder.assert_bool(local.is_and);
240        builder.assert_bool(is_real);
241    }
242}
243
#[cfg(test)]
mod tests {
    // The trace-generation smoke test prints the generated trace for manual inspection.
    #![allow(clippy::print_stdout)]

    use p3_baby_bear::BabyBear;
    use p3_matrix::dense::RowMajorMatrix;
    use rand::{thread_rng, Rng};
    use sp1_core_executor::{
        events::{AluEvent, MemoryRecordEnum},
        ExecutionRecord, Instruction, Opcode, Program,
    };
    use sp1_stark::{
        air::MachineAir, baby_bear_poseidon2::BabyBearPoseidon2, CpuProver, MachineProver,
        StarkGenericConfig, Val,
    };

    use crate::{
        io::SP1Stdin,
        riscv::RiscvAir,
        utils::{run_malicious_test, uni_stark_prove, uni_stark_verify},
    };

    use super::BitwiseChip;

    /// Smoke test: generating a trace for a single XOR event must not panic.
    #[test]
    fn generate_trace() {
        let mut shard = ExecutionRecord::default();
        shard.bitwise_events = vec![AluEvent::new(0, Opcode::XOR, 25, 10, 19, false)];
        let chip = BitwiseChip;
        let trace: RowMajorMatrix<BabyBear> =
            chip.generate_trace(&shard, &mut ExecutionRecord::default());
        println!("{:?}", trace.values)
    }

    /// Proves and verifies a trace made of many valid XOR / OR / AND events.
    #[test]
    fn prove_babybear() {
        let config = BabyBearPoseidon2::new();
        let mut challenger = config.challenger();

        let mut shard = ExecutionRecord::default();
        shard.bitwise_events = [
            AluEvent::new(0, Opcode::XOR, 25, 10, 19, false),
            AluEvent::new(0, Opcode::OR, 27, 10, 19, false),
            AluEvent::new(0, Opcode::AND, 2, 10, 19, false),
        ]
        .repeat(1000);
        let chip = BitwiseChip;
        let trace: RowMajorMatrix<BabyBear> =
            chip.generate_trace(&shard, &mut ExecutionRecord::default());
        let proof = uni_stark_prove::<BabyBearPoseidon2, _>(&config, &chip, &mut challenger, trace);

        let mut challenger = config.challenger();
        uni_stark_verify(&config, &chip, &mut challenger, &proof).unwrap();
    }

    /// Checks that a prover who forges the result of a bitwise instruction is rejected.
    #[test]
    fn test_malicious_bitwise() {
        const NUM_TESTS: usize = 5;

        let mut rng = thread_rng();

        for opcode in [Opcode::XOR, Opcode::OR, Opcode::AND] {
            for _ in 0..NUM_TESTS {
                let op_b = rng.gen_range(0..u32::MAX);
                let op_c = rng.gen_range(0..u32::MAX);

                let correct_op_a = match opcode {
                    Opcode::XOR => op_b ^ op_c,
                    Opcode::OR => op_b | op_c,
                    _ => op_b & op_c,
                };

                // Draw the forged result until it differs from the honest one. Asserting on a
                // single independent draw could (very rarely) fail spuriously.
                let op_a = loop {
                    let candidate = rng.gen_range(0..u32::MAX);
                    if candidate != correct_op_a {
                        break candidate;
                    }
                };

                let instructions = vec![
                    Instruction::new(opcode, 5, op_b, op_c, true, true),
                    Instruction::new(Opcode::ADD, 10, 0, 0, false, false),
                ];
                let program = Program::new(instructions, 0, 0);
                let stdin = SP1Stdin::new();

                type P = CpuProver<BabyBearPoseidon2, RiscvAir<BabyBear>>;

                let malicious_trace_pv_generator = move |prover: &P,
                                                         record: &mut ExecutionRecord|
                      -> Vec<(
                    String,
                    RowMajorMatrix<Val<BabyBearPoseidon2>>,
                )> {
                    let mut malicious_record = record.clone();
                    malicious_record.cpu_events[0].a = op_a;
                    // NOTE(review): `write_record` is bound by value here, so this assignment
                    // appears to modify a local copy only and leave
                    // `malicious_record.cpu_events[0].a_record` untouched. Confirm whether an
                    // in-place update was intended before changing it, since that would alter
                    // the forged CPU trace.
                    if let Some(MemoryRecordEnum::Write(mut write_record)) =
                        malicious_record.cpu_events[0].a_record
                    {
                        write_record.value = op_a;
                    }
                    malicious_record.bitwise_events[0].a = op_a;
                    prover.generate_traces(&malicious_record)
                };

                let result =
                    run_malicious_test::<P>(program, stdin, Box::new(malicious_trace_pv_generator));
                assert!(result.is_err() && result.unwrap_err().is_local_cumulative_sum_failing());
            }
        }
    }
}