1mod air;
2mod columns;
3mod trace;
4
5pub use columns::*;
6use p3_air::BaseAir;
7
/// AIR chip for the RISC-V conditional branch instructions.
///
/// The tests in this file drive it with BEQ, BNE, BLT, BGE, BLTU and BGEU programs and expect
/// proofs with a forged branch `next_pc`, or with two opcode flags set on one row, to be
/// rejected. Rows use the `BranchColumns` layout (presumably defined in the `columns`
/// submodule re-exported above); the constraints and trace generation presumably live in the
/// `air` and `trace` submodules — TODO confirm.
#[derive(Default)]
pub struct BranchChip;
10
11impl<F> BaseAir<F> for BranchChip {
12 fn width(&self) -> usize {
13 NUM_BRANCH_COLS
14 }
15}
16
#[cfg(test)]
mod tests {
    use std::borrow::BorrowMut;

    use p3_baby_bear::BabyBear;
    use p3_field::AbstractField;
    use p3_matrix::dense::RowMajorMatrix;
    use sp1_core_executor::{ExecutionRecord, Instruction, Opcode, Program};
    use sp1_stark::{
        air::MachineAir, baby_bear_poseidon2::BabyBearPoseidon2, chip_name, CpuProver,
        MachineProver, Val,
    };

    use crate::{
        control_flow::{BranchChip, BranchColumns},
        io::SP1Stdin,
        riscv::RiscvAir,
        utils::run_malicious_test,
    };

    /// Prover shared by every malicious-trace test in this module.
    type P = CpuProver<BabyBearPoseidon2, RiscvAir<BabyBear>>;

    /// Forges the `next_pc` of the first branch event and checks that the proof is rejected.
    ///
    /// Each program loads two operands into x29 and x30 and then executes `branch x29, x30, 8`
    /// as its third instruction. With `pc_start = 0` and 4-byte instructions (assumed from
    /// `Program::new(instructions, 0, 0)`), a taken branch lands on pc 16 and a not-taken
    /// branch falls through to pc 12. Every case overwrites the recorded `next_pc` with the
    /// *other* target and states which verifier check is expected to catch it.
    #[test]
    fn test_malicious_branches() {
        /// How the verifier is expected to reject the forged proof.
        enum ErrorType {
            /// Rejected by the local cumulative sum check.
            LocalCumulativeSumFailing,
            /// Rejected by the `BranchChip` constraints.
            ConstraintsFailing,
        }

        /// One forged-branch scenario.
        struct BranchTestCase {
            branch_opcode: Opcode,
            /// Value loaded into x29, the branch's first comparison register.
            branch_operand_b_value: u32,
            /// Value loaded into x30, the branch's second comparison register.
            branch_operand_c_value: u32,
            /// Forged `next_pc` written into the first branch event.
            incorrect_next_pc: u32,
            error_type: ErrorType,
        }

        let branch_test_cases = [
            // BEQ taken (5 == 5); forge the fall-through pc.
            BranchTestCase {
                branch_opcode: Opcode::BEQ,
                branch_operand_b_value: 5,
                branch_operand_c_value: 5,
                incorrect_next_pc: 12,
                error_type: ErrorType::LocalCumulativeSumFailing,
            },
            // BEQ not taken (5 != 3); forge the branch target.
            BranchTestCase {
                branch_opcode: Opcode::BEQ,
                branch_operand_b_value: 5,
                branch_operand_c_value: 3,
                incorrect_next_pc: 16,
                error_type: ErrorType::ConstraintsFailing,
            },
            // BNE not taken (5 == 5); forge the branch target.
            BranchTestCase {
                branch_opcode: Opcode::BNE,
                branch_operand_b_value: 5,
                branch_operand_c_value: 5,
                incorrect_next_pc: 16,
                error_type: ErrorType::ConstraintsFailing,
            },
            // BNE taken (5 != 3); forge the fall-through pc.
            BranchTestCase {
                branch_opcode: Opcode::BNE,
                branch_operand_b_value: 5,
                branch_operand_c_value: 3,
                incorrect_next_pc: 12,
                error_type: ErrorType::LocalCumulativeSumFailing,
            },
            // BLTU not taken (5 >= 3 unsigned); forge the branch target.
            BranchTestCase {
                branch_opcode: Opcode::BLTU,
                branch_operand_b_value: 5,
                branch_operand_c_value: 3,
                incorrect_next_pc: 16,
                error_type: ErrorType::ConstraintsFailing,
            },
            // BLTU taken (3 < 5 unsigned); forge the fall-through pc.
            BranchTestCase {
                branch_opcode: Opcode::BLTU,
                branch_operand_b_value: 3,
                branch_operand_c_value: 5,
                incorrect_next_pc: 12,
                error_type: ErrorType::LocalCumulativeSumFailing,
            },
            // BLT taken (-1 < 3 signed); forge the fall-through pc.
            BranchTestCase {
                branch_opcode: Opcode::BLT,
                branch_operand_b_value: 0xFFFF_FFFF,
                branch_operand_c_value: 3,
                incorrect_next_pc: 12,
                error_type: ErrorType::LocalCumulativeSumFailing,
            },
            // BLT not taken (3 >= -1 signed); forge the branch target.
            BranchTestCase {
                branch_opcode: Opcode::BLT,
                branch_operand_b_value: 3,
                branch_operand_c_value: 0xFFFF_FFFF,
                incorrect_next_pc: 16,
                error_type: ErrorType::ConstraintsFailing,
            },
            // BGEU not taken (3 < 5 unsigned); forge the branch target.
            BranchTestCase {
                branch_opcode: Opcode::BGEU,
                branch_operand_b_value: 3,
                branch_operand_c_value: 5,
                incorrect_next_pc: 16,
                error_type: ErrorType::ConstraintsFailing,
            },
            // BGEU taken (5 >= 5); forge the fall-through pc.
            BranchTestCase {
                branch_opcode: Opcode::BGEU,
                branch_operand_b_value: 5,
                branch_operand_c_value: 5,
                incorrect_next_pc: 12,
                error_type: ErrorType::LocalCumulativeSumFailing,
            },
            // BGEU taken (5 >= 3 unsigned); forge the fall-through pc.
            BranchTestCase {
                branch_opcode: Opcode::BGEU,
                branch_operand_b_value: 5,
                branch_operand_c_value: 3,
                incorrect_next_pc: 12,
                error_type: ErrorType::LocalCumulativeSumFailing,
            },
            // BGE not taken (-1 < 5 signed); forge the branch target.
            BranchTestCase {
                branch_opcode: Opcode::BGE,
                branch_operand_b_value: 0xFFFF_FFFF,
                branch_operand_c_value: 5,
                incorrect_next_pc: 16,
                error_type: ErrorType::ConstraintsFailing,
            },
            // BGE taken (5 >= 5); forge the fall-through pc.
            BranchTestCase {
                branch_opcode: Opcode::BGE,
                branch_operand_b_value: 5,
                branch_operand_c_value: 5,
                incorrect_next_pc: 12,
                error_type: ErrorType::LocalCumulativeSumFailing,
            },
            // BGE taken (3 >= -1 signed); forge the fall-through pc.
            BranchTestCase {
                branch_opcode: Opcode::BGE,
                branch_operand_b_value: 3,
                branch_operand_c_value: 0xFFFF_FFFF,
                incorrect_next_pc: 12,
                error_type: ErrorType::LocalCumulativeSumFailing,
            },
        ];

        for (case_index, test_case) in branch_test_cases.into_iter().enumerate() {
            let opcode = test_case.branch_opcode;
            let lhs = test_case.branch_operand_b_value;
            let rhs = test_case.branch_operand_c_value;
            let instructions = vec![
                Instruction::new(Opcode::ADD, 29, 0, lhs, false, true),
                Instruction::new(Opcode::ADD, 30, 0, rhs, false, true),
                Instruction::new(opcode, 29, 30, 8, false, true),
                Instruction::new(Opcode::ADD, 28, 0, 5, false, true),
                Instruction::new(Opcode::ADD, 28, 0, 5, false, true),
            ];
            let program = Program::new(instructions, 0, 0);
            let stdin = SP1Stdin::new();

            // Copy the forged pc out so the `move` closure captures only a plain `u32`.
            let incorrect_next_pc = test_case.incorrect_next_pc;
            let malicious_trace_pv_generator =
                move |prover: &P,
                      record: &mut ExecutionRecord|
                      -> Vec<(String, RowMajorMatrix<Val<BabyBearPoseidon2>>)> {
                    // Tamper with a copy so the record passed in is left unmodified.
                    let mut malicious_record = record.clone();
                    malicious_record.branch_events[0].next_pc = incorrect_next_pc;
                    prover.generate_traces(&malicious_record)
                };

            let result =
                run_malicious_test::<P>(program, stdin, Box::new(malicious_trace_pv_generator));

            match test_case.error_type {
                ErrorType::LocalCumulativeSumFailing => {
                    assert!(
                        result.is_err() && result.unwrap_err().is_local_cumulative_sum_failing(),
                        "case {case_index} ({opcode:?} {lhs:#x}, {rhs:#x}; next_pc forged to \
                         {incorrect_next_pc}): expected a local cumulative sum failure"
                    );
                }
                ErrorType::ConstraintsFailing => {
                    let branch_chip_name = chip_name!(BranchChip, BabyBear);
                    assert!(
                        result.is_err() &&
                            result.unwrap_err().is_constraints_failing(&branch_chip_name),
                        "case {case_index} ({opcode:?} {lhs:#x}, {rhs:#x}; next_pc forged to \
                         {incorrect_next_pc}): expected `{branch_chip_name}` constraints to fail"
                    );
                }
            }
        }
    }

    /// Sets a second opcode selector (`is_bne`) on a branch row that already has `is_beq` set
    /// and checks that the `BranchChip` constraints reject the trace.
    #[test]
    fn test_malicious_multiple_opcode_flags() {
        // x29 = x30 = 5, then `beq x29, x30, 8`, so the first branch row is a BEQ.
        let instructions = vec![
            Instruction::new(Opcode::ADD, 29, 0, 5, false, true),
            Instruction::new(Opcode::ADD, 30, 0, 5, false, true),
            Instruction::new(Opcode::BEQ, 29, 30, 8, false, true),
            Instruction::new(Opcode::ADD, 28, 0, 5, false, true),
            Instruction::new(Opcode::ADD, 28, 0, 5, false, true),
        ];
        let program = Program::new(instructions, 0, 0);
        let stdin = SP1Stdin::new();

        let malicious_trace_pv_generator =
            |prover: &P,
             record: &mut ExecutionRecord|
             -> Vec<(String, RowMajorMatrix<Val<BabyBearPoseidon2>>)> {
                let mut traces = prover.generate_traces(record);
                let branch_chip_name = chip_name!(BranchChip, BabyBear);
                for (chip_name, trace) in traces.iter_mut() {
                    if *chip_name == branch_chip_name {
                        let first_row = trace.row_mut(0);
                        let first_row: &mut BranchColumns<BabyBear> = first_row.borrow_mut();
                        // Sanity check: the honest first row must be the BEQ instruction.
                        assert_eq!(
                            first_row.is_beq,
                            BabyBear::one(),
                            "first branch row should be the BEQ"
                        );
                        // Also raise the BNE selector so two opcode flags are set at once.
                        first_row.is_bne = BabyBear::one();
                    }
                }
                traces
            };

        let result =
            run_malicious_test::<P>(program, stdin, Box::new(malicious_trace_pv_generator));
        let branch_chip_name = chip_name!(BranchChip, BabyBear);
        assert!(
            result.is_err() && result.unwrap_err().is_constraints_failing(&branch_chip_name),
            "expected `{branch_chip_name}` constraints to fail when two opcode flags are set"
        );
    }
}