// sp1_core_machine/control_flow/branch/mod.rs

1mod air;
2mod columns;
3mod trace;
4
5pub use columns::*;
6use slop_air::BaseAir;
7use std::marker::PhantomData;
8
9use crate::TrustMode;
10
11#[derive(Default)]
12pub struct BranchChip<M: TrustMode> {
13    pub _phantom: PhantomData<M>,
14}
15
16impl<F, M: TrustMode> BaseAir<F> for BranchChip<M> {
17    fn width(&self) -> usize {
18        if M::IS_TRUSTED {
19            NUM_BRANCH_COLS_SUPERVISOR
20        } else {
21            NUM_BRANCH_COLS_USER
22        }
23    }
24}
25/*
26#[cfg(test)]
27mod tests {
28    use std::borrow::BorrowMut;
29
30    use sp1_core_executor::{ExecutionRecord, Instruction, Opcode, Program};
31    use sp1_hypercube::{
32        air::MachineAir, koala_bear_poseidon2::SP1InnerPcs, chip_name, CpuProver,
33        MachineProver, Val,
34    };
35
36//     use sp1_primitives::SP1Field;
37//     use slop_algebra::AbstractField;
38//     use slop_matrix::dense::RowMajorMatrix;
39//     use sp1_core_executor::{ExecutionRecord, Instruction, Opcode, Program};
40//     use sp1_hypercube::{
41//         air::MachineAir, koala_bear_poseidon2::SP1InnerPcs, chip_name, CpuProver,
42//         MachineProver, Val,
43//     };
44
45    #[test]
46    fn test_malicious_branches() {
47        enum ErrorType {
            // TODO: Re-enable once LogUp-GKR is working.
49            // LocalCumulativeSumFailing,
50            ConstraintsFailing,
51        }
52
53        struct BranchTestCase {
54            branch_opcode: Opcode,
55            branch_operand_b_value: u32,
56            branch_operand_c_value: u32,
57            incorrect_next_pc: u64,
58            error_type: ErrorType,
59        }
60
        // The branch instruction sits at PC 8 with offset 8: it jumps to 16 if the condition holds and falls through to 12 otherwise.
62        let branch_test_cases = vec![
            // TODO: Re-enable once LogUp-GKR is working.
64            // BranchTestCase {
65            //     branch_opcode: Opcode::BEQ,
66            //     branch_operand_b_value: 5,
67            //     branch_operand_c_value: 5,
68            //     incorrect_next_pc: 12, // Correct next PC is 16.
69            //     error_type: ErrorType::LocalCumulativeSumFailing,
70            // },
71            BranchTestCase {
72                branch_opcode: Opcode::BEQ,
73                branch_operand_b_value: 5,
74                branch_operand_c_value: 3,
75                incorrect_next_pc: 16, // Correct next PC is 12.
76                error_type: ErrorType::ConstraintsFailing,
77            },
78            BranchTestCase {
79                branch_opcode: Opcode::BNE,
80                branch_operand_b_value: 5,
81                branch_operand_c_value: 5,
82                incorrect_next_pc: 16, // Correct next PC is 12.
83                error_type: ErrorType::ConstraintsFailing,
84            },
            // TODO: Re-enable once LogUp-GKR is working.
86            // BranchTestCase {
87            //     branch_opcode: Opcode::BNE,
88            //     branch_operand_b_value: 5,
89            //     branch_operand_c_value: 3,
90            //     incorrect_next_pc: 12, // Correct next PC is 16.
91            //     error_type: ErrorType::LocalCumulativeSumFailing,
92            // },
93            BranchTestCase {
94                branch_opcode: Opcode::BLTU,
95                branch_operand_b_value: 5,
96                branch_operand_c_value: 3,
97                incorrect_next_pc: 16, // Correct next PC is 12.
98                error_type: ErrorType::ConstraintsFailing,
99            },
            // TODO: Re-enable once LogUp-GKR is working.
101            // BranchTestCase {
102            //     branch_opcode: Opcode::BLTU,
103            //     branch_operand_b_value: 3,
104            //     branch_operand_c_value: 5,
105            //     incorrect_next_pc: 12, // Correct next PC is 16.
106            //     error_type: ErrorType::LocalCumulativeSumFailing,
107            // },
            // TODO: Re-enable once LogUp-GKR is working.
109            // BranchTestCase {
110            //     branch_opcode: Opcode::BLT,
111            //     branch_operand_b_value: 0xFFFF_FFFF, // This is -1.
112            //     branch_operand_c_value: 3,
113            //     incorrect_next_pc: 12, // Correct next PC is 16.
114            //     error_type: ErrorType::LocalCumulativeSumFailing,
115            // },
116            BranchTestCase {
117                branch_opcode: Opcode::BLT,
118                branch_operand_b_value: 3,
119                branch_operand_c_value: 0xFFFF_FFFF, // This is -1.
120                incorrect_next_pc: 16,               // Correct next PC is 12.
121                error_type: ErrorType::ConstraintsFailing,
122            },
123            BranchTestCase {
124                branch_opcode: Opcode::BGEU,
125                branch_operand_b_value: 3,
126                branch_operand_c_value: 5,
127                incorrect_next_pc: 16, // Correct next PC is 12.
128                error_type: ErrorType::ConstraintsFailing,
129            },
            // TODO: Re-enable once LogUp-GKR is working.
131            // BranchTestCase {
132            //     branch_opcode: Opcode::BGEU,
133            //     branch_operand_b_value: 5,
134            //     branch_operand_c_value: 5,
135            //     incorrect_next_pc: 12, // Correct next PC is 16.
136            //     error_type: ErrorType::LocalCumulativeSumFailing,
137            // },
            // TODO: Re-enable once LogUp-GKR is working.
139            // BranchTestCase {
140            //     branch_opcode: Opcode::BGEU,
141            //     branch_operand_b_value: 5,
142            //     branch_operand_c_value: 3,
143            //     incorrect_next_pc: 12, // Correct next PC is 16.
144            //     error_type: ErrorType::LocalCumulativeSumFailing,
145            // },
146            BranchTestCase {
147                branch_opcode: Opcode::BGE,
148                branch_operand_b_value: 0xFFFF_FFFF, // This is -1.
149                branch_operand_c_value: 5,
150                incorrect_next_pc: 16, // Correct next PC is 12.
151                error_type: ErrorType::ConstraintsFailing,
152            },
            // TODO: Re-enable once LogUp-GKR is working.
154            // BranchTestCase {
155            //     branch_opcode: Opcode::BGE,
156            //     branch_operand_b_value: 5,
157            //     branch_operand_c_value: 5,
158            //     incorrect_next_pc: 12, // Correct next PC is 16.
159            //     error_type: ErrorType::LocalCumulativeSumFailing,
160            // },
            // TODO: Re-enable once LogUp-GKR is working.
162            // BranchTestCase {
163            //     branch_opcode: Opcode::BGE,
164            //     branch_operand_b_value: 3,
165            //     branch_operand_c_value: 0xFFFF_FFFF, // This is -1.
166            //     incorrect_next_pc: 12,               // Correct next PC is 16.
167            //     error_type: ErrorType::LocalCumulativeSumFailing,
168            // },
169        ];
170
171        for test_case in branch_test_cases {
172            let instructions = vec![
173                Instruction::new(Opcode::ADD, 29, 0, test_case.branch_operand_b_value, false, true),
174                Instruction::new(Opcode::ADD, 30, 0, test_case.branch_operand_c_value, false, true),
175                Instruction::new(test_case.branch_opcode, 29, 30, 8, false, true),
176                Instruction::new(Opcode::ADD, 28, 0, 5, false, true),
177                Instruction::new(Opcode::ADD, 28, 0, 5, false, true),
178            ];
179            let program = Program::new(instructions, 0, 0);
180            let stdin = SP1Stdin::new();
181
182            type P = CpuProver<SP1InnerPcs, RiscvAir<SP1Field>>;
183
184            let malicious_trace_pv_generator =
185                move |prover: &P,
186                      record: &mut ExecutionRecord|
187                      -> Vec<(String, RowMajorMatrix<Val<SP1InnerPcs>>)> {
188                    // Create a malicious record where the BEQ instruction branches incorrectly.
189                    let mut malicious_record = record.clone();
190                    malicious_record.branch_events[0].next_pc = test_case.incorrect_next_pc;
191                    prover.generate_traces(&malicious_record)
192                };
193
194            let result =
195                run_malicious_test::<P>(program, stdin, Box::new(malicious_trace_pv_generator));
196
197            match test_case.error_type {
                // TODO: Re-enable once LogUp-GKR is working.
199                // ErrorType::LocalCumulativeSumFailing => {
200                //     assert!(
201                //         result.is_err() && result.unwrap_err().is_local_cumulative_sum_failing()
202                //     );
203                // }
204                ErrorType::ConstraintsFailing => {
205                    assert!(result.is_err() && result.unwrap_err().is_constraints_failing());
206                }
207            }
208        }
209    }
210
211    #[test]
212    fn test_malicious_multiple_opcode_flags() {
213        let instructions = vec![
214            Instruction::new(Opcode::ADD, 29, 0, 5, false, true),
215            Instruction::new(Opcode::ADD, 30, 0, 5, false, true),
216            Instruction::new(Opcode::BEQ, 29, 30, 8, false, true),
217            Instruction::new(Opcode::ADD, 28, 0, 5, false, true),
218            Instruction::new(Opcode::ADD, 28, 0, 5, false, true),
219        ];
220        let program = Program::new(instructions, 0, 0);
221        let stdin = SP1Stdin::new();
222
223        type P = CpuProver<SP1InnerPcs, RiscvAir<SP1Field>>;
224
225        let malicious_trace_pv_generator =
226            |prover: &P,
227             record: &mut ExecutionRecord|
228             -> Vec<(String, RowMajorMatrix<Val<SP1InnerPcs>>)> {
229                // Modify the branch chip to have a row that has multiple opcode flags set.
230                let mut traces = prover.generate_traces(record);
231                let branch_chip_name = chip_name!(BranchChip, SP1Field);
232                for (chip_name, trace) in traces.iter_mut() {
233                    if *chip_name == branch_chip_name {
234                        let first_row = trace.row_mut(0);
235                        let first_row: &mut BranchColumns<SP1Field> = first_row.borrow_mut();
236                        assert!(first_row.is_beq == SP1Field::one());
237                        first_row.is_bne = SP1Field::one();
238                    }
239                }
240                traces
241            };
242
243        let result =
244            run_malicious_test::<P>(program, stdin, Box::new(malicious_trace_pv_generator));
245        assert!(result.is_err() && result.unwrap_err().is_constraints_failing());
246    }
247}
248 */