sp1_core_machine/control_flow/branch/
air.rs1use std::borrow::Borrow;
2
3use slop_air::{Air, AirBuilder};
4use slop_algebra::{AbstractField, Field};
5use slop_matrix::Matrix;
6
7use crate::{
8 adapter::{
9 register::i_type::{ITypeReaderImmutable, ITypeReaderImmutableInput},
10 state::{CPUState, CPUStateInput},
11 },
12 air::{SP1CoreAirBuilder, SP1Operation},
13 eval_untrusted_program,
14 operations::{LtOperationSigned, LtOperationSignedInput},
15 TrustMode, UserMode,
16};
17use sp1_core_executor::{ByteOpcode, Opcode, CLK_INC, PC_INC};
18
19use super::{BranchChip, BranchColumns};
20
21impl<AB, M> Air<AB> for BranchChip<M>
28where
29 AB: SP1CoreAirBuilder,
30 AB::Var: Sized,
31 M: TrustMode,
32{
33 #[inline(never)]
34 fn eval(&self, builder: &mut AB) {
35 let main = builder.main();
36 let local = main.row_slice(0);
37 let local: &BranchColumns<AB::Var, M> = (*local).borrow();
38
39 builder.assert_bool(local.is_beq);
44 builder.assert_bool(local.is_bne);
45 builder.assert_bool(local.is_blt);
46 builder.assert_bool(local.is_bge);
47 builder.assert_bool(local.is_bltu);
48 builder.assert_bool(local.is_bgeu);
49 let is_real = local.is_beq
50 + local.is_bne
51 + local.is_blt
52 + local.is_bge
53 + local.is_bltu
54 + local.is_bgeu;
55 builder.assert_bool(is_real.clone());
56
57 let opcode = local.is_beq * Opcode::BEQ.as_field::<AB::F>()
58 + local.is_bne * Opcode::BNE.as_field::<AB::F>()
59 + local.is_blt * Opcode::BLT.as_field::<AB::F>()
60 + local.is_bge * Opcode::BGE.as_field::<AB::F>()
61 + local.is_bltu * Opcode::BLTU.as_field::<AB::F>()
62 + local.is_bgeu * Opcode::BGEU.as_field::<AB::F>();
63
64 let funct3 = local.is_beq * AB::Expr::from_canonical_u8(Opcode::BEQ.funct3().unwrap())
66 + local.is_bne * AB::Expr::from_canonical_u8(Opcode::BNE.funct3().unwrap())
67 + local.is_blt * AB::Expr::from_canonical_u8(Opcode::BLT.funct3().unwrap())
68 + local.is_bge * AB::Expr::from_canonical_u8(Opcode::BGE.funct3().unwrap())
69 + local.is_bltu * AB::Expr::from_canonical_u8(Opcode::BLTU.funct3().unwrap())
70 + local.is_bgeu * AB::Expr::from_canonical_u8(Opcode::BGEU.funct3().unwrap());
71 let funct7 = local.is_beq * AB::Expr::from_canonical_u8(Opcode::BEQ.funct7().unwrap_or(0))
72 + local.is_bne * AB::Expr::from_canonical_u8(Opcode::BNE.funct7().unwrap_or(0))
73 + local.is_blt * AB::Expr::from_canonical_u8(Opcode::BLT.funct7().unwrap_or(0))
74 + local.is_bge * AB::Expr::from_canonical_u8(Opcode::BGE.funct7().unwrap_or(0))
75 + local.is_bltu * AB::Expr::from_canonical_u8(Opcode::BLTU.funct7().unwrap_or(0))
76 + local.is_bgeu * AB::Expr::from_canonical_u8(Opcode::BGEU.funct7().unwrap_or(0));
77 let base_opcode = local.is_beq * AB::Expr::from_canonical_u32(Opcode::BEQ.base_opcode().0)
78 + local.is_bne * AB::Expr::from_canonical_u32(Opcode::BNE.base_opcode().0)
79 + local.is_blt * AB::Expr::from_canonical_u32(Opcode::BLT.base_opcode().0)
80 + local.is_bge * AB::Expr::from_canonical_u32(Opcode::BGE.base_opcode().0)
81 + local.is_bltu * AB::Expr::from_canonical_u32(Opcode::BLTU.base_opcode().0)
82 + local.is_bgeu * AB::Expr::from_canonical_u32(Opcode::BGEU.base_opcode().0);
83 let instr_type = local.is_beq
84 * AB::Expr::from_canonical_u32(Opcode::BEQ.instruction_type().0 as u32)
85 + local.is_bne * AB::Expr::from_canonical_u32(Opcode::BNE.instruction_type().0 as u32)
86 + local.is_blt * AB::Expr::from_canonical_u32(Opcode::BLT.instruction_type().0 as u32)
87 + local.is_bge * AB::Expr::from_canonical_u32(Opcode::BGE.instruction_type().0 as u32)
88 + local.is_bltu
89 * AB::Expr::from_canonical_u32(Opcode::BLTU.instruction_type().0 as u32)
90 + local.is_bgeu
91 * AB::Expr::from_canonical_u32(Opcode::BGEU.instruction_type().0 as u32);
92
93 <CPUState<AB::F> as SP1Operation<AB>>::eval(
97 builder,
98 CPUStateInput::new(
99 local.state,
100 local.next_pc.map(Into::into),
101 AB::Expr::from_canonical_u32(CLK_INC),
102 is_real.clone(),
103 ),
104 );
105
106 let mut is_trusted: AB::Expr = is_real.clone();
107
108 #[cfg(feature = "mprotect")]
109 builder.assert_eq(
110 builder.extract_public_values().is_untrusted_programs_enabled,
111 AB::Expr::from_bool(!M::IS_TRUSTED),
112 );
113
114 if !M::IS_TRUSTED {
115 let local = main.row_slice(0);
116 let local: &BranchColumns<AB::Var, UserMode> = (*local).borrow();
117
118 let instruction = local.adapter.instruction::<AB>(opcode.clone());
119
120 #[cfg(not(feature = "mprotect"))]
121 builder.assert_zero(is_real.clone());
122
123 eval_untrusted_program(
124 builder,
125 local.state.pc,
126 instruction,
127 [instr_type, base_opcode, funct3, funct7],
128 [local.state.clk_high::<AB>(), local.state.clk_low::<AB>()],
129 is_real.clone(),
130 local.adapter_cols,
131 );
132
133 is_trusted = local.adapter_cols.is_trusted.into();
134 }
135
136 <ITypeReaderImmutable as SP1Operation<AB>>::eval(
138 builder,
139 ITypeReaderImmutableInput::new(
140 local.state.clk_high::<AB>(),
141 local.state.clk_low::<AB>(),
142 local.state.pc,
143 opcode,
144 local.adapter,
145 is_real.clone(),
146 is_trusted,
147 ),
148 );
149
150 let use_signed_comparison = local.is_blt + local.is_bge;
152 <LtOperationSigned<AB::F> as SP1Operation<AB>>::eval(
153 builder,
154 LtOperationSignedInput::<AB>::new(
155 local.adapter.prev_a().map(Into::into),
156 local.adapter.b().map(Into::into),
157 local.compare_operation,
158 use_signed_comparison.clone(),
159 is_real.clone(),
160 ),
161 );
162
163 let is_eq = AB::Expr::one()
165 - (local.compare_operation.result.u16_flags[0]
166 + local.compare_operation.result.u16_flags[1]
167 + local.compare_operation.result.u16_flags[2]
168 + local.compare_operation.result.u16_flags[3]);
169 let is_less_than = local.compare_operation.result.u16_compare_operation.bit;
170
171 let mut branching: AB::Expr = AB::Expr::zero();
173 branching = branching.clone() + local.is_beq * is_eq.clone();
174 branching = branching.clone() + local.is_bne * (AB::Expr::one() - is_eq);
175 branching =
176 branching.clone() + (local.is_bge + local.is_bgeu) * (AB::Expr::one() - is_less_than);
177 branching = branching.clone() + (local.is_blt + local.is_bltu) * is_less_than;
178
179 builder.assert_bool(local.is_branching);
180 builder.when(is_real.clone()).assert_eq(local.is_branching, branching.clone());
181
182 let base_inverse = AB::F::from_canonical_u32(1 << 16).inverse();
186 let mut carry = AB::Expr::zero();
187 for i in 0..4 {
188 let pc = if i < 3 { local.state.pc[i].into() } else { AB::Expr::zero() };
189 let next_pc = if i < 3 { local.next_pc[i].into() } else { AB::Expr::zero() };
190 carry = (carry.clone() + pc + local.adapter.c()[i] - next_pc) * base_inverse;
191 builder.when(local.is_branching).assert_bool(carry.clone());
192 }
193
194 let mut carry = AB::Expr::zero();
195 for i in 0..4 {
196 let pc = if i < 3 { local.state.pc[i].into() } else { AB::Expr::zero() };
197 let next_pc = if i < 3 { local.next_pc[i].into() } else { AB::Expr::zero() };
198 let increment =
199 if i == 0 { AB::Expr::from_canonical_u32(PC_INC) } else { AB::Expr::zero() };
200 carry = (carry.clone() + pc + increment - next_pc) * base_inverse;
201 builder.when(is_real.clone() - local.is_branching).assert_bool(carry.clone());
202 }
203
204 builder.send_byte(
206 AB::Expr::from_canonical_u32(ByteOpcode::Range as u32),
207 local.next_pc[0].into() * AB::F::from_canonical_u32(4).inverse(),
208 AB::Expr::from_canonical_u32(14),
209 AB::Expr::zero(),
210 is_real.clone(),
211 );
212 builder.slice_range_check_u16(&local.next_pc[1..3], is_real);
213 }
214}