sp1_core_machine/control_flow/branch/
air.rs

1use std::borrow::Borrow;
2
3use p3_air::{Air, AirBuilder};
4use p3_field::AbstractField;
5use p3_matrix::Matrix;
6use sp1_core_executor::{Opcode, DEFAULT_PC_INC, UNUSED_PC};
7use sp1_stark::{
8    air::{BaseAirBuilder, SP1AirBuilder},
9    Word,
10};
11
12use crate::{air::WordAirBuilder, operations::BabyBearWordRangeChecker};
13
14use super::{BranchChip, BranchColumns};
15
16/// Verifies all the branching related columns.
17///
18/// It does this in a few parts:
19/// 1. It verifies that the next pc is correct based on the branching column.  That column is a
20///    boolean that indicates whether the branch condition is true.
21/// 2. It verifies the correct value of branching based on the helper bool columns (a_eq_b, a_gt_b,
22///    a_lt_b).
23/// 3. It verifies the correct values of the helper bool columns based on op_a and op_b.
24impl<AB> Air<AB> for BranchChip
25where
26    AB: SP1AirBuilder,
27    AB::Var: Sized,
28{
29    #[inline(never)]
30    fn eval(&self, builder: &mut AB) {
          // Reinterpret row 0 of the main trace window as the typed `BranchColumns` layout.
31        let main = builder.main();
32        let local = main.row_slice(0);
33        let local: &BranchColumns<AB::Var> = (*local).borrow();
34
35        // SAFETY: All selectors `is_beq`, `is_bne`, `is_blt`, `is_bge`, `is_bltu`, `is_bgeu` are
36        // checked to be boolean. Each "real" row has exactly one selector turned on, as
37        // `is_real`, the sum of the six selectors, is boolean. Therefore, the `opcode`
38        // matches the corresponding opcode.
39        builder.assert_bool(local.is_beq);
40        builder.assert_bool(local.is_bne);
41        builder.assert_bool(local.is_blt);
42        builder.assert_bool(local.is_bge);
43        builder.assert_bool(local.is_bltu);
44        builder.assert_bool(local.is_bgeu);
45        let is_real = local.is_beq +
46            local.is_bne +
47            local.is_blt +
48            local.is_bge +
49            local.is_bltu +
50            local.is_bgeu;
51        builder.assert_bool(is_real.clone());
52
          // Selector-weighted sum of the opcode field values: with exactly one selector set this
          // is that opcode's value, and with none set (is_real = 0) it is zero.
53        let opcode = local.is_beq * Opcode::BEQ.as_field::<AB::F>() +
54            local.is_bne * Opcode::BNE.as_field::<AB::F>() +
55            local.is_blt * Opcode::BLT.as_field::<AB::F>() +
56            local.is_bge * Opcode::BGE.as_field::<AB::F>() +
57            local.is_bltu * Opcode::BLTU.as_field::<AB::F>() +
58            local.is_bgeu * Opcode::BGEU.as_field::<AB::F>();
59
60        // SAFETY: This checks the following.
61        // - `num_extra_cycles = 0`
62        // - `op_a_val` will be constrained in the CpuChip as `op_a_immutable = 1`
63        // - `op_a_immutable = 1`, as this is a branch instruction
64        // - `is_memory = 0`
65        // - `is_syscall = 0`
66        // - `is_halt = 0`
67        // `next_pc` still has to be constrained, and this is done below.
68        builder.receive_instruction(
69            AB::Expr::zero(),
70            AB::Expr::zero(),
71            local.pc.reduce::<AB>(),
72            local.next_pc.reduce::<AB>(),
73            AB::Expr::zero(),
74            opcode,
75            local.op_a_value,
76            local.op_b_value,
77            local.op_c_value,
78            local.op_a_0,
79            AB::Expr::one(),
80            AB::Expr::zero(),
81            AB::Expr::zero(),
82            AB::Expr::zero(),
83            is_real.clone(),
84        );
85
86        // Evaluate program counter constraints.
87        {
88            // Range check branch_cols.pc and branch_cols.next_pc.
89            // SAFETY: `is_real` is already checked to be boolean.
90            // The `BabyBearWordRangeChecker` assumes that the value is checked to be a valid word.
91            // This is done when the word form is relevant, i.e. when `pc` and `next_pc` are sent to
92            // the ADD ALU table. The ADD ALU table checks the inputs are valid words,
93            // when it invokes `AddOperation`.
94            BabyBearWordRangeChecker::<AB::F>::range_check(
95                builder,
96                local.pc,
97                local.pc_range_checker,
98                is_real.clone(),
99            );
100            BabyBearWordRangeChecker::<AB::F>::range_check(
101                builder,
102                local.next_pc,
103                local.next_pc_range_checker,
104                is_real.clone(),
105            );
106
107            // When we are branching, assert that local.next_pc <==> local.pc + c via an ADD ALU send.
108            builder.send_instruction(
109                AB::Expr::zero(),
110                AB::Expr::zero(),
111                AB::Expr::from_canonical_u32(UNUSED_PC),
112                AB::Expr::from_canonical_u32(UNUSED_PC + DEFAULT_PC_INC),
113                AB::Expr::zero(),
114                Opcode::ADD.as_field::<AB::F>(),
115                local.next_pc,
116                local.pc,
117                local.op_c_value,
118                AB::Expr::zero(),
119                AB::Expr::zero(),
120                AB::Expr::zero(),
121                AB::Expr::zero(),
122                AB::Expr::zero(),
123                local.is_branching,
124            );
125
126            // When we are not branching, assert that local.pc + DEFAULT_PC_INC <==> local.next_pc.
127            builder.when(is_real.clone()).when(local.not_branching).assert_eq(
128                local.pc.reduce::<AB>() + AB::Expr::from_canonical_u32(DEFAULT_PC_INC),
129                local.next_pc.reduce::<AB>(),
130            );
131
132            // When local.not_branching is true, assert that is_real is true.
133            builder.when(local.not_branching).assert_one(is_real.clone());
134
135            // Force is_branching to zero on padding rows so the ALU send above is inactive there.
136            builder.when_not(is_real.clone()).assert_zero(local.is_branching);
137
138            // Assert that either we are branching or not branching when the instruction is a
139            // branch.
140            // The `next_pc` is constrained in both branching and not branching cases, so it is
141            // fully constrained.
142            builder.when(is_real.clone()).assert_one(local.is_branching + local.not_branching);
143            builder.when(is_real.clone()).assert_bool(local.is_branching);
144            builder.when(is_real.clone()).assert_bool(local.not_branching);
145        }
146
147        // Evaluate branching value constraints.
148        {
149            // When the opcode is BEQ and we are branching, assert that a_eq_b is true.
150            builder.when(local.is_beq * local.is_branching).assert_one(local.a_eq_b);
151
152            // When the opcode is BEQ and we are not branching, assert that either a_gt_b or a_lt_b
153            // is true.
154            builder
155                .when(local.is_beq)
156                .when_not(local.is_branching)
157                .assert_one(local.a_gt_b + local.a_lt_b);
158
159            // When the opcode is BNE and we are branching, assert that either a_gt_b or a_lt_b is
160            // true.
161            builder.when(local.is_bne * local.is_branching).assert_one(local.a_gt_b + local.a_lt_b);
162
163            // When the opcode is BNE and we are not branching, assert that a_eq_b is true.
164            builder.when(local.is_bne).when_not(local.is_branching).assert_one(local.a_eq_b);
165
166            // When the opcode is BLT or BLTU and we are branching, assert that a_lt_b is true.
167            builder
168                .when((local.is_blt + local.is_bltu) * local.is_branching)
169                .assert_one(local.a_lt_b);
170
171            // When the opcode is BLT or BLTU and we are not branching, assert that either a_eq_b
172            // or a_gt_b is true.
173            builder
174                .when(local.is_blt + local.is_bltu)
175                .when_not(local.is_branching)
176                .assert_one(local.a_eq_b + local.a_gt_b);
177
178            // When the opcode is BGE or BGEU and we are branching, assert that either a_gt_b or a_eq_b is true.
179            builder
180                .when((local.is_bge + local.is_bgeu) * local.is_branching)
181                .assert_one(local.a_gt_b + local.a_eq_b);
182
183            // When the opcode is BGE or BGEU and we are not branching, assert that a_lt_b is
184            // true.
185            builder
186                .when(local.is_bge + local.is_bgeu)
187                .when_not(local.is_branching)
188                .assert_one(local.a_lt_b);
189        }
190
191        // When it's a branch instruction and a_eq_b, assert that a == b.
192        builder
193            .when(is_real.clone() * local.a_eq_b)
194            .assert_word_eq(local.op_a_value, local.op_b_value);
195
196        // Calculate a_lt_b <==> a < b (using appropriate signedness).
197        // SAFETY: `use_signed_comparison` is boolean, since at most one selector is turned on.
           // SLT is selected for BLT/BGE; every other branch opcode on this chip uses SLTU.
198        let use_signed_comparison = local.is_blt + local.is_bge;
199        builder.send_instruction(
200            AB::Expr::zero(),
201            AB::Expr::zero(),
202            AB::Expr::from_canonical_u32(UNUSED_PC),
203            AB::Expr::from_canonical_u32(UNUSED_PC + DEFAULT_PC_INC),
204            AB::Expr::zero(),
205            use_signed_comparison.clone() * Opcode::SLT.as_field::<AB::F>() +
206                (AB::Expr::one() - use_signed_comparison.clone()) *
207                    Opcode::SLTU.as_field::<AB::F>(),
208            Word::extend_var::<AB>(local.a_lt_b),
209            local.op_a_value,
210            local.op_b_value,
211            AB::Expr::zero(),
212            AB::Expr::zero(),
213            AB::Expr::zero(),
214            AB::Expr::zero(),
215            AB::Expr::zero(),
216            is_real.clone(),
217        );
218
219        // Calculate a_gt_b <==> a > b, i.e. b < a with swapped operands (same signedness as above).
220        builder.send_instruction(
221            AB::Expr::zero(),
222            AB::Expr::zero(),
223            AB::Expr::from_canonical_u32(UNUSED_PC),
224            AB::Expr::from_canonical_u32(UNUSED_PC + DEFAULT_PC_INC),
225            AB::Expr::zero(),
226            use_signed_comparison.clone() * Opcode::SLT.as_field::<AB::F>() +
227                (AB::Expr::one() - use_signed_comparison) * Opcode::SLTU.as_field::<AB::F>(),
228            Word::extend_var::<AB>(local.a_gt_b),
229            local.op_b_value,
230            local.op_a_value,
231            AB::Expr::zero(),
232            AB::Expr::zero(),
233            AB::Expr::zero(),
234            AB::Expr::zero(),
235            AB::Expr::zero(),
236            is_real.clone(),
237        );
238    }
239}