sp1_core_machine/control_flow/branch/mod.rs
1mod air;
2mod columns;
3mod trace;
4
5pub use columns::*;
6use slop_air::BaseAir;
7use std::marker::PhantomData;
8
9use crate::TrustMode;
10
11#[derive(Default)]
12pub struct BranchChip<M: TrustMode> {
13 pub _phantom: PhantomData<M>,
14}
15
16impl<F, M: TrustMode> BaseAir<F> for BranchChip<M> {
17 fn width(&self) -> usize {
18 if M::IS_TRUSTED {
19 NUM_BRANCH_COLS_SUPERVISOR
20 } else {
21 NUM_BRANCH_COLS_USER
22 }
23 }
24}
25/*
26#[cfg(test)]
27mod tests {
28 use std::borrow::BorrowMut;
29
30 use sp1_core_executor::{ExecutionRecord, Instruction, Opcode, Program};
31 use sp1_hypercube::{
32 air::MachineAir, koala_bear_poseidon2::SP1InnerPcs, chip_name, CpuProver,
33 MachineProver, Val,
34 };
35
36// use sp1_primitives::SP1Field;
37// use slop_algebra::AbstractField;
38// use slop_matrix::dense::RowMajorMatrix;
39// use sp1_core_executor::{ExecutionRecord, Instruction, Opcode, Program};
40// use sp1_hypercube::{
41// air::MachineAir, koala_bear_poseidon2::SP1InnerPcs, chip_name, CpuProver,
42// MachineProver, Val,
43// };
44
45 #[test]
46 fn test_malicious_branches() {
47 enum ErrorType {
            // TODO: Re-enable once LogUp-GKR is working.
49 // LocalCumulativeSumFailing,
50 ConstraintsFailing,
51 }
52
53 struct BranchTestCase {
54 branch_opcode: Opcode,
55 branch_operand_b_value: u32,
56 branch_operand_c_value: u32,
57 incorrect_next_pc: u64,
58 error_type: ErrorType,
59 }
60
        // The branch instruction is at PC 8; it jumps to 16 when its condition holds and falls
        // through to 12 otherwise.
62 let branch_test_cases = vec![
            // TODO: Re-enable once LogUp-GKR is working.
64 // BranchTestCase {
65 // branch_opcode: Opcode::BEQ,
66 // branch_operand_b_value: 5,
67 // branch_operand_c_value: 5,
68 // incorrect_next_pc: 12, // Correct next PC is 16.
69 // error_type: ErrorType::LocalCumulativeSumFailing,
70 // },
71 BranchTestCase {
72 branch_opcode: Opcode::BEQ,
73 branch_operand_b_value: 5,
74 branch_operand_c_value: 3,
75 incorrect_next_pc: 16, // Correct next PC is 12.
76 error_type: ErrorType::ConstraintsFailing,
77 },
78 BranchTestCase {
79 branch_opcode: Opcode::BNE,
80 branch_operand_b_value: 5,
81 branch_operand_c_value: 5,
82 incorrect_next_pc: 16, // Correct next PC is 12.
83 error_type: ErrorType::ConstraintsFailing,
84 },
            // TODO: Re-enable once LogUp-GKR is working.
86 // BranchTestCase {
87 // branch_opcode: Opcode::BNE,
88 // branch_operand_b_value: 5,
89 // branch_operand_c_value: 3,
90 // incorrect_next_pc: 12, // Correct next PC is 16.
91 // error_type: ErrorType::LocalCumulativeSumFailing,
92 // },
93 BranchTestCase {
94 branch_opcode: Opcode::BLTU,
95 branch_operand_b_value: 5,
96 branch_operand_c_value: 3,
97 incorrect_next_pc: 16, // Correct next PC is 12.
98 error_type: ErrorType::ConstraintsFailing,
99 },
            // TODO: Re-enable once LogUp-GKR is working.
101 // BranchTestCase {
102 // branch_opcode: Opcode::BLTU,
103 // branch_operand_b_value: 3,
104 // branch_operand_c_value: 5,
105 // incorrect_next_pc: 12, // Correct next PC is 16.
106 // error_type: ErrorType::LocalCumulativeSumFailing,
107 // },
            // TODO: Re-enable once LogUp-GKR is working.
109 // BranchTestCase {
110 // branch_opcode: Opcode::BLT,
111 // branch_operand_b_value: 0xFFFF_FFFF, // This is -1.
112 // branch_operand_c_value: 3,
113 // incorrect_next_pc: 12, // Correct next PC is 16.
114 // error_type: ErrorType::LocalCumulativeSumFailing,
115 // },
116 BranchTestCase {
117 branch_opcode: Opcode::BLT,
118 branch_operand_b_value: 3,
119 branch_operand_c_value: 0xFFFF_FFFF, // This is -1.
120 incorrect_next_pc: 16, // Correct next PC is 12.
121 error_type: ErrorType::ConstraintsFailing,
122 },
123 BranchTestCase {
124 branch_opcode: Opcode::BGEU,
125 branch_operand_b_value: 3,
126 branch_operand_c_value: 5,
127 incorrect_next_pc: 16, // Correct next PC is 12.
128 error_type: ErrorType::ConstraintsFailing,
129 },
            // TODO: Re-enable once LogUp-GKR is working.
131 // BranchTestCase {
132 // branch_opcode: Opcode::BGEU,
133 // branch_operand_b_value: 5,
134 // branch_operand_c_value: 5,
135 // incorrect_next_pc: 12, // Correct next PC is 16.
136 // error_type: ErrorType::LocalCumulativeSumFailing,
137 // },
            // TODO: Re-enable once LogUp-GKR is working.
139 // BranchTestCase {
140 // branch_opcode: Opcode::BGEU,
141 // branch_operand_b_value: 5,
142 // branch_operand_c_value: 3,
143 // incorrect_next_pc: 12, // Correct next PC is 16.
144 // error_type: ErrorType::LocalCumulativeSumFailing,
145 // },
146 BranchTestCase {
147 branch_opcode: Opcode::BGE,
148 branch_operand_b_value: 0xFFFF_FFFF, // This is -1.
149 branch_operand_c_value: 5,
150 incorrect_next_pc: 16, // Correct next PC is 12.
151 error_type: ErrorType::ConstraintsFailing,
152 },
            // TODO: Re-enable once LogUp-GKR is working.
154 // BranchTestCase {
155 // branch_opcode: Opcode::BGE,
156 // branch_operand_b_value: 5,
157 // branch_operand_c_value: 5,
158 // incorrect_next_pc: 12, // Correct next PC is 16.
159 // error_type: ErrorType::LocalCumulativeSumFailing,
160 // },
            // TODO: Re-enable once LogUp-GKR is working.
162 // BranchTestCase {
163 // branch_opcode: Opcode::BGE,
164 // branch_operand_b_value: 3,
165 // branch_operand_c_value: 0xFFFF_FFFF, // This is -1.
166 // incorrect_next_pc: 12, // Correct next PC is 16.
167 // error_type: ErrorType::LocalCumulativeSumFailing,
168 // },
169 ];
170
171 for test_case in branch_test_cases {
172 let instructions = vec![
173 Instruction::new(Opcode::ADD, 29, 0, test_case.branch_operand_b_value, false, true),
174 Instruction::new(Opcode::ADD, 30, 0, test_case.branch_operand_c_value, false, true),
175 Instruction::new(test_case.branch_opcode, 29, 30, 8, false, true),
176 Instruction::new(Opcode::ADD, 28, 0, 5, false, true),
177 Instruction::new(Opcode::ADD, 28, 0, 5, false, true),
178 ];
179 let program = Program::new(instructions, 0, 0);
180 let stdin = SP1Stdin::new();
181
182 type P = CpuProver<SP1InnerPcs, RiscvAir<SP1Field>>;
183
184 let malicious_trace_pv_generator =
185 move |prover: &P,
186 record: &mut ExecutionRecord|
187 -> Vec<(String, RowMajorMatrix<Val<SP1InnerPcs>>)> {
                    // Create a malicious record in which the branch instruction under test
                    // jumps to the wrong next PC.
189 let mut malicious_record = record.clone();
190 malicious_record.branch_events[0].next_pc = test_case.incorrect_next_pc;
191 prover.generate_traces(&malicious_record)
192 };
193
194 let result =
195 run_malicious_test::<P>(program, stdin, Box::new(malicious_trace_pv_generator));
196
197 match test_case.error_type {
                // TODO: Re-enable once LogUp-GKR is working.
199 // ErrorType::LocalCumulativeSumFailing => {
200 // assert!(
201 // result.is_err() && result.unwrap_err().is_local_cumulative_sum_failing()
202 // );
203 // }
204 ErrorType::ConstraintsFailing => {
205 assert!(result.is_err() && result.unwrap_err().is_constraints_failing());
206 }
207 }
208 }
209 }
210
211 #[test]
212 fn test_malicious_multiple_opcode_flags() {
213 let instructions = vec![
214 Instruction::new(Opcode::ADD, 29, 0, 5, false, true),
215 Instruction::new(Opcode::ADD, 30, 0, 5, false, true),
216 Instruction::new(Opcode::BEQ, 29, 30, 8, false, true),
217 Instruction::new(Opcode::ADD, 28, 0, 5, false, true),
218 Instruction::new(Opcode::ADD, 28, 0, 5, false, true),
219 ];
220 let program = Program::new(instructions, 0, 0);
221 let stdin = SP1Stdin::new();
222
223 type P = CpuProver<SP1InnerPcs, RiscvAir<SP1Field>>;
224
225 let malicious_trace_pv_generator =
226 |prover: &P,
227 record: &mut ExecutionRecord|
228 -> Vec<(String, RowMajorMatrix<Val<SP1InnerPcs>>)> {
229 // Modify the branch chip to have a row that has multiple opcode flags set.
230 let mut traces = prover.generate_traces(record);
231 let branch_chip_name = chip_name!(BranchChip, SP1Field);
232 for (chip_name, trace) in traces.iter_mut() {
233 if *chip_name == branch_chip_name {
234 let first_row = trace.row_mut(0);
235 let first_row: &mut BranchColumns<SP1Field> = first_row.borrow_mut();
236 assert!(first_row.is_beq == SP1Field::one());
237 first_row.is_bne = SP1Field::one();
238 }
239 }
240 traces
241 };
242
243 let result =
244 run_malicious_test::<P>(program, stdin, Box::new(malicious_trace_pv_generator));
245 assert!(result.is_err() && result.unwrap_err().is_constraints_failing());
246 }
247}
248 */