// sp1_core_machine/control_flow/branch/mod.rs

1mod air;
2mod columns;
3mod trace;
4
5pub use columns::*;
6use p3_air::BaseAir;
7
8#[derive(Default)]
9pub struct BranchChip;
10
11impl<F> BaseAir<F> for BranchChip {
12    fn width(&self) -> usize {
13        NUM_BRANCH_COLS
14    }
15}
16
#[cfg(test)]
mod tests {
    use std::borrow::BorrowMut;

    use p3_baby_bear::BabyBear;
    use p3_field::AbstractField;
    use p3_matrix::dense::RowMajorMatrix;
    use sp1_core_executor::{ExecutionRecord, Instruction, Opcode, Program};
    use sp1_stark::{
        air::MachineAir, baby_bear_poseidon2::BabyBearPoseidon2, chip_name, CpuProver,
        MachineProver, Val,
    };

    use crate::{
        control_flow::{BranchChip, BranchColumns},
        io::SP1Stdin,
        riscv::RiscvAir,
        utils::run_malicious_test,
    };

    /// Forges the `next_pc` of a branch event and checks that verification fails, for every
    /// branch opcode in both the taken and the not-taken direction.
    #[test]
    fn test_malicious_branches() {
        /// Which verifier check is expected to reject the forged trace.
        enum ErrorType {
            LocalCumulativeSumFailing,
            ConstraintsFailing,
        }

        struct BranchTestCase {
            branch_opcode: Opcode,
            branch_operand_b_value: u32,
            branch_operand_c_value: u32,
            incorrect_next_pc: u32,
            error_type: ErrorType,
        }

        type P = CpuProver<BabyBearPoseidon2, RiscvAir<BabyBear>>;

        // The branch instruction is the third instruction (PC 8) with an offset of 8, so it
        // jumps to PC 16 when the condition holds and falls through to PC 12 otherwise.
        let branch_test_cases = [
            BranchTestCase {
                branch_opcode: Opcode::BEQ,
                branch_operand_b_value: 5,
                branch_operand_c_value: 5,
                incorrect_next_pc: 12, // Correct next PC is 16.
                error_type: ErrorType::LocalCumulativeSumFailing,
            },
            BranchTestCase {
                branch_opcode: Opcode::BEQ,
                branch_operand_b_value: 5,
                branch_operand_c_value: 3,
                incorrect_next_pc: 16, // Correct next PC is 12.
                error_type: ErrorType::ConstraintsFailing,
            },
            BranchTestCase {
                branch_opcode: Opcode::BNE,
                branch_operand_b_value: 5,
                branch_operand_c_value: 5,
                incorrect_next_pc: 16, // Correct next PC is 12.
                error_type: ErrorType::ConstraintsFailing,
            },
            BranchTestCase {
                branch_opcode: Opcode::BNE,
                branch_operand_b_value: 5,
                branch_operand_c_value: 3,
                incorrect_next_pc: 12, // Correct next PC is 16.
                error_type: ErrorType::LocalCumulativeSumFailing,
            },
            BranchTestCase {
                branch_opcode: Opcode::BLTU,
                branch_operand_b_value: 5,
                branch_operand_c_value: 3,
                incorrect_next_pc: 16, // Correct next PC is 12.
                error_type: ErrorType::ConstraintsFailing,
            },
            BranchTestCase {
                branch_opcode: Opcode::BLTU,
                branch_operand_b_value: 3,
                branch_operand_c_value: 5,
                incorrect_next_pc: 12, // Correct next PC is 16.
                error_type: ErrorType::LocalCumulativeSumFailing,
            },
            BranchTestCase {
                branch_opcode: Opcode::BLT,
                branch_operand_b_value: 0xFFFF_FFFF, // This is -1.
                branch_operand_c_value: 3,
                incorrect_next_pc: 12, // Correct next PC is 16.
                error_type: ErrorType::LocalCumulativeSumFailing,
            },
            BranchTestCase {
                branch_opcode: Opcode::BLT,
                branch_operand_b_value: 3,
                branch_operand_c_value: 0xFFFF_FFFF, // This is -1.
                incorrect_next_pc: 16,               // Correct next PC is 12.
                error_type: ErrorType::ConstraintsFailing,
            },
            BranchTestCase {
                branch_opcode: Opcode::BGEU,
                branch_operand_b_value: 3,
                branch_operand_c_value: 5,
                incorrect_next_pc: 16, // Correct next PC is 12.
                error_type: ErrorType::ConstraintsFailing,
            },
            BranchTestCase {
                branch_opcode: Opcode::BGEU,
                branch_operand_b_value: 5,
                branch_operand_c_value: 5,
                incorrect_next_pc: 12, // Correct next PC is 16.
                error_type: ErrorType::LocalCumulativeSumFailing,
            },
            BranchTestCase {
                branch_opcode: Opcode::BGEU,
                branch_operand_b_value: 5,
                branch_operand_c_value: 3,
                incorrect_next_pc: 12, // Correct next PC is 16.
                error_type: ErrorType::LocalCumulativeSumFailing,
            },
            BranchTestCase {
                branch_opcode: Opcode::BGE,
                branch_operand_b_value: 0xFFFF_FFFF, // This is -1.
                branch_operand_c_value: 5,
                incorrect_next_pc: 16, // Correct next PC is 12.
                error_type: ErrorType::ConstraintsFailing,
            },
            BranchTestCase {
                branch_opcode: Opcode::BGE,
                branch_operand_b_value: 5,
                branch_operand_c_value: 5,
                incorrect_next_pc: 12, // Correct next PC is 16.
                error_type: ErrorType::LocalCumulativeSumFailing,
            },
            BranchTestCase {
                branch_opcode: Opcode::BGE,
                branch_operand_b_value: 3,
                branch_operand_c_value: 0xFFFF_FFFF, // This is -1.
                incorrect_next_pc: 12,               // Correct next PC is 16.
                error_type: ErrorType::LocalCumulativeSumFailing,
            },
        ];

        let branch_chip_name = chip_name!(BranchChip, BabyBear);

        for test_case in branch_test_cases {
            // Destructure up front so the closure below only moves in a `Copy` u32.
            let BranchTestCase {
                branch_opcode,
                branch_operand_b_value,
                branch_operand_c_value,
                incorrect_next_pc,
                error_type,
            } = test_case;

            // Identifies the failing case in assertion messages.
            let case = format!(
                "{:?} with b = {:#x}, c = {:#x}, forged next_pc = {}",
                branch_opcode, branch_operand_b_value, branch_operand_c_value, incorrect_next_pc
            );

            // x29 <- b, x30 <- c, then branch on (x29, x30) with offset 8.
            let instructions = vec![
                Instruction::new(Opcode::ADD, 29, 0, branch_operand_b_value, false, true),
                Instruction::new(Opcode::ADD, 30, 0, branch_operand_c_value, false, true),
                Instruction::new(branch_opcode, 29, 30, 8, false, true),
                Instruction::new(Opcode::ADD, 28, 0, 5, false, true),
                Instruction::new(Opcode::ADD, 28, 0, 5, false, true),
            ];
            let program = Program::new(instructions, 0, 0);
            let stdin = SP1Stdin::new();

            let malicious_trace_pv_generator =
                move |prover: &P,
                      record: &mut ExecutionRecord|
                      -> Vec<(String, RowMajorMatrix<Val<BabyBearPoseidon2>>)> {
                    // Forge the recorded next PC of the (single) branch event under test.
                    let mut malicious_record = record.clone();
                    malicious_record.branch_events[0].next_pc = incorrect_next_pc;
                    prover.generate_traces(&malicious_record)
                };

            let result =
                run_malicious_test::<P>(program, stdin, Box::new(malicious_trace_pv_generator));

            let err = match result {
                Ok(_) => panic!("{}: malicious trace was accepted by the verifier", case),
                Err(err) => err,
            };

            match error_type {
                ErrorType::LocalCumulativeSumFailing => {
                    assert!(
                        err.is_local_cumulative_sum_failing(),
                        "{}: expected a local cumulative sum failure",
                        case
                    );
                }
                ErrorType::ConstraintsFailing => {
                    assert!(
                        err.is_constraints_failing(&branch_chip_name),
                        "{}: expected the branch chip's constraints to fail",
                        case
                    );
                }
            }
        }
    }

    /// Sets a second opcode flag on a branch-chip row and checks that the branch chip's
    /// constraints reject it.
    #[test]
    fn test_malicious_multiple_opcode_flags() {
        type P = CpuProver<BabyBearPoseidon2, RiscvAir<BabyBear>>;

        // A single BEQ on x29 == x30 == 5, so the first branch-chip row is a BEQ row.
        let instructions = vec![
            Instruction::new(Opcode::ADD, 29, 0, 5, false, true),
            Instruction::new(Opcode::ADD, 30, 0, 5, false, true),
            Instruction::new(Opcode::BEQ, 29, 30, 8, false, true),
            Instruction::new(Opcode::ADD, 28, 0, 5, false, true),
            Instruction::new(Opcode::ADD, 28, 0, 5, false, true),
        ];
        let program = Program::new(instructions, 0, 0);
        let stdin = SP1Stdin::new();

        let malicious_trace_pv_generator =
            |prover: &P,
             record: &mut ExecutionRecord|
             -> Vec<(String, RowMajorMatrix<Val<BabyBearPoseidon2>>)> {
                // Modify the branch chip to have a row that has multiple opcode flags set.
                let mut traces = prover.generate_traces(record);
                let branch_chip_name = chip_name!(BranchChip, BabyBear);
                // Fail loudly if the chip is missing rather than silently skipping the tamper.
                let (_, trace) = traces
                    .iter_mut()
                    .find(|(chip_name, _)| *chip_name == branch_chip_name)
                    .expect("the branch chip trace should be generated");
                let first_row = trace.row_mut(0);
                let first_row: &mut BranchColumns<BabyBear> = first_row.borrow_mut();
                assert_eq!(first_row.is_beq, BabyBear::one());
                first_row.is_bne = BabyBear::one();
                traces
            };

        let result =
            run_malicious_test::<P>(program, stdin, Box::new(malicious_trace_pv_generator));
        let branch_chip_name = chip_name!(BranchChip, BabyBear);
        let err = match result {
            Ok(_) => panic!("a branch row with both is_beq and is_bne set was accepted"),
            Err(err) => err,
        };
        assert!(
            err.is_constraints_failing(&branch_chip_name),
            "expected the branch chip's constraints to fail"
        );
    }
}