sp1_core_machine/control_flow/branch/air.rs
1use std::borrow::Borrow;
2
3use p3_air::{Air, AirBuilder};
4use p3_field::AbstractField;
5use p3_matrix::Matrix;
6use sp1_core_executor::{Opcode, DEFAULT_PC_INC, UNUSED_PC};
7use sp1_stark::{
8 air::{BaseAirBuilder, SP1AirBuilder},
9 Word,
10};
11
12use crate::{air::WordAirBuilder, operations::BabyBearWordRangeChecker};
13
14use super::{BranchChip, BranchColumns};
15
16/// Verifies all the branching related columns.
17///
18/// It does this in few parts:
19/// 1. It verifies that the next pc is correct based on the branching column. That column is a
20/// boolean that indicates whether the branch condition is true.
21/// 2. It verifies the correct value of branching based on the helper bool columns (a_eq_b, a_gt_b,
22/// a_lt_b).
23/// 3. It verifier the correct values of the helper bool columns based on op_a and op_b.
24impl<AB> Air<AB> for BranchChip
25where
26 AB: SP1AirBuilder,
27 AB::Var: Sized,
28{
29 #[inline(never)]
30 fn eval(&self, builder: &mut AB) {
31 let main = builder.main();
32 let local = main.row_slice(0);
33 let local: &BranchColumns<AB::Var> = (*local).borrow();
34
35 // SAFETY: All selectors `is_beq`, `is_bne`, `is_blt`, `is_bge`, `is_bltu`, `is_bgeu` are
36 // checked to be boolean. Each "real" row has exactly one selector turned on, as
37 // `is_real`, the sum of the six selectors, is boolean. Therefore, the `opcode`
38 // matches the corresponding opcode.
39 builder.assert_bool(local.is_beq);
40 builder.assert_bool(local.is_bne);
41 builder.assert_bool(local.is_blt);
42 builder.assert_bool(local.is_bge);
43 builder.assert_bool(local.is_bltu);
44 builder.assert_bool(local.is_bgeu);
45 let is_real = local.is_beq +
46 local.is_bne +
47 local.is_blt +
48 local.is_bge +
49 local.is_bltu +
50 local.is_bgeu;
51 builder.assert_bool(is_real.clone());
52
53 let opcode = local.is_beq * Opcode::BEQ.as_field::<AB::F>() +
54 local.is_bne * Opcode::BNE.as_field::<AB::F>() +
55 local.is_blt * Opcode::BLT.as_field::<AB::F>() +
56 local.is_bge * Opcode::BGE.as_field::<AB::F>() +
57 local.is_bltu * Opcode::BLTU.as_field::<AB::F>() +
58 local.is_bgeu * Opcode::BGEU.as_field::<AB::F>();
59
60 // SAFETY: This checks the following.
61 // - `num_extra_cycles = 0`
62 // - `op_a_val` will be constrained in the CpuChip as `op_a_immutable = 1`
63 // - `op_a_immutable = 1`, as this is a branch instruction
64 // - `is_memory = 0`
65 // - `is_syscall = 0`
66 // - `is_halt = 0`
67 // `next_pc` still has to be constrained, and this is done below.
68 builder.receive_instruction(
69 AB::Expr::zero(),
70 AB::Expr::zero(),
71 local.pc.reduce::<AB>(),
72 local.next_pc.reduce::<AB>(),
73 AB::Expr::zero(),
74 opcode,
75 local.op_a_value,
76 local.op_b_value,
77 local.op_c_value,
78 local.op_a_0,
79 AB::Expr::one(),
80 AB::Expr::zero(),
81 AB::Expr::zero(),
82 AB::Expr::zero(),
83 is_real.clone(),
84 );
85
86 // Evaluate program counter constraints.
87 {
88 // Range check branch_cols.pc and branch_cols.next_pc.
89 // SAFETY: `is_real` is already checked to be boolean.
90 // The `BabyBearWordRangeChecker` assumes that the value is checked to be a valid word.
91 // This is done when the word form is relevant, i.e. when `pc` and `next_pc` are sent to
92 // the ADD ALU table. The ADD ALU table checks the inputs are valid words,
93 // when it invokes `AddOperation`.
94 BabyBearWordRangeChecker::<AB::F>::range_check(
95 builder,
96 local.pc,
97 local.pc_range_checker,
98 is_real.clone(),
99 );
100 BabyBearWordRangeChecker::<AB::F>::range_check(
101 builder,
102 local.next_pc,
103 local.next_pc_range_checker,
104 is_real.clone(),
105 );
106
107 // When we are branching, assert that local.next_pc <==> local.pc + c.
108 builder.send_instruction(
109 AB::Expr::zero(),
110 AB::Expr::zero(),
111 AB::Expr::from_canonical_u32(UNUSED_PC),
112 AB::Expr::from_canonical_u32(UNUSED_PC + DEFAULT_PC_INC),
113 AB::Expr::zero(),
114 Opcode::ADD.as_field::<AB::F>(),
115 local.next_pc,
116 local.pc,
117 local.op_c_value,
118 AB::Expr::zero(),
119 AB::Expr::zero(),
120 AB::Expr::zero(),
121 AB::Expr::zero(),
122 AB::Expr::zero(),
123 local.is_branching,
124 );
125
126 // When we are not branching, assert that local.pc + 4 <==> next.pc.
127 builder.when(is_real.clone()).when(local.not_branching).assert_eq(
128 local.pc.reduce::<AB>() + AB::Expr::from_canonical_u32(DEFAULT_PC_INC),
129 local.next_pc.reduce::<AB>(),
130 );
131
132 // When local.not_branching is true, assert that local.is_real is true.
133 builder.when(local.not_branching).assert_one(is_real.clone());
134
135 // To prevent the ALU send above to be non-zero when the row is a padding row.
136 builder.when_not(is_real.clone()).assert_zero(local.is_branching);
137
138 // Assert that either we are branching or not branching when the instruction is a
139 // branch.
140 // The `next_pc` is constrained in both branching and not branching cases, so it is
141 // fully constrained.
142 builder.when(is_real.clone()).assert_one(local.is_branching + local.not_branching);
143 builder.when(is_real.clone()).assert_bool(local.is_branching);
144 builder.when(is_real.clone()).assert_bool(local.not_branching);
145 }
146
147 // Evaluate branching value constraints.
148 {
149 // When the opcode is BEQ and we are branching, assert that a_eq_b is true.
150 builder.when(local.is_beq * local.is_branching).assert_one(local.a_eq_b);
151
152 // When the opcode is BEQ and we are not branching, assert that either a_gt_b or a_lt_b
153 // is true.
154 builder
155 .when(local.is_beq)
156 .when_not(local.is_branching)
157 .assert_one(local.a_gt_b + local.a_lt_b);
158
159 // When the opcode is BNE and we are branching, assert that either a_gt_b or a_lt_b is
160 // true.
161 builder.when(local.is_bne * local.is_branching).assert_one(local.a_gt_b + local.a_lt_b);
162
163 // When the opcode is BNE and we are not branching, assert that a_eq_b is true.
164 builder.when(local.is_bne).when_not(local.is_branching).assert_one(local.a_eq_b);
165
166 // When the opcode is BLT or BLTU and we are branching, assert that a_lt_b is true.
167 builder
168 .when((local.is_blt + local.is_bltu) * local.is_branching)
169 .assert_one(local.a_lt_b);
170
171 // When the opcode is BLT or BLTU and we are not branching, assert that either a_eq_b
172 // or a_gt_b is true.
173 builder
174 .when(local.is_blt + local.is_bltu)
175 .when_not(local.is_branching)
176 .assert_one(local.a_eq_b + local.a_gt_b);
177
178 // When the opcode is BGE or BGEU and we are branching, assert that a_gt_b is true.
179 builder
180 .when((local.is_bge + local.is_bgeu) * local.is_branching)
181 .assert_one(local.a_gt_b + local.a_eq_b);
182
183 // When the opcode is BGE or BGEU and we are not branching, assert that either a_eq_b
184 // or a_lt_b is true.
185 builder
186 .when(local.is_bge + local.is_bgeu)
187 .when_not(local.is_branching)
188 .assert_one(local.a_lt_b);
189 }
190
191 // When it's a branch instruction and a_eq_b, assert that a == b.
192 builder
193 .when(is_real.clone() * local.a_eq_b)
194 .assert_word_eq(local.op_a_value, local.op_b_value);
195
196 // Calculate a_lt_b <==> a < b (using appropriate signedness).
197 // SAFETY: `use_signed_comparison` is boolean, since at most one selector is turned on.
198 let use_signed_comparison = local.is_blt + local.is_bge;
199 builder.send_instruction(
200 AB::Expr::zero(),
201 AB::Expr::zero(),
202 AB::Expr::from_canonical_u32(UNUSED_PC),
203 AB::Expr::from_canonical_u32(UNUSED_PC + DEFAULT_PC_INC),
204 AB::Expr::zero(),
205 use_signed_comparison.clone() * Opcode::SLT.as_field::<AB::F>() +
206 (AB::Expr::one() - use_signed_comparison.clone()) *
207 Opcode::SLTU.as_field::<AB::F>(),
208 Word::extend_var::<AB>(local.a_lt_b),
209 local.op_a_value,
210 local.op_b_value,
211 AB::Expr::zero(),
212 AB::Expr::zero(),
213 AB::Expr::zero(),
214 AB::Expr::zero(),
215 AB::Expr::zero(),
216 is_real.clone(),
217 );
218
219 // Calculate a_gt_b <==> a > b (using appropriate signedness).
220 builder.send_instruction(
221 AB::Expr::zero(),
222 AB::Expr::zero(),
223 AB::Expr::from_canonical_u32(UNUSED_PC),
224 AB::Expr::from_canonical_u32(UNUSED_PC + DEFAULT_PC_INC),
225 AB::Expr::zero(),
226 use_signed_comparison.clone() * Opcode::SLT.as_field::<AB::F>() +
227 (AB::Expr::one() - use_signed_comparison) * Opcode::SLTU.as_field::<AB::F>(),
228 Word::extend_var::<AB>(local.a_gt_b),
229 local.op_b_value,
230 local.op_a_value,
231 AB::Expr::zero(),
232 AB::Expr::zero(),
233 AB::Expr::zero(),
234 AB::Expr::zero(),
235 AB::Expr::zero(),
236 is_real.clone(),
237 );
238 }
239}