Skip to main content

sp1_core_machine/control_flow/branch/
air.rs

1use std::borrow::Borrow;
2
3use slop_air::{Air, AirBuilder};
4use slop_algebra::{AbstractField, Field};
5use slop_matrix::Matrix;
6
7use crate::{
8    adapter::{
9        register::i_type::{ITypeReaderImmutable, ITypeReaderImmutableInput},
10        state::{CPUState, CPUStateInput},
11    },
12    air::{SP1CoreAirBuilder, SP1Operation},
13    eval_untrusted_program,
14    operations::{LtOperationSigned, LtOperationSignedInput},
15    TrustMode, UserMode,
16};
17use sp1_core_executor::{ByteOpcode, Opcode, CLK_INC, PC_INC};
18
19use super::{BranchChip, BranchColumns};
20
21/// AIR constraints for the branch instructions `BEQ`, `BNE`, `BLT`, `BGE`, `BLTU`, and `BGEU`.
22///
23/// The evaluation proceeds in a few parts: opcode-selector checks, CPU state and
24/// program/register-read constraints, a (signed or unsigned) comparison of `prev_a` with `b`,
25/// a check that the boolean `is_branching` column agrees with the opcode and comparison, and
26/// a check that `next_pc` is `pc + c` when branching, `pc + PC_INC` otherwise, and 4-aligned.
27impl<AB, M> Air<AB> for BranchChip<M>
28where
29    AB: SP1CoreAirBuilder,
30    AB::Var: Sized,
31    M: TrustMode,
32{
33    #[inline(never)]
34    fn eval(&self, builder: &mut AB) {
35        let main = builder.main();
36        let local = main.row_slice(0);
37        let local: &BranchColumns<AB::Var, M> = (*local).borrow();
38
39        // SAFETY: All selectors `is_beq`, `is_bne`, `is_blt`, `is_bge`, `is_bltu`, `is_bgeu` are
40        // checked to be boolean. Each "real" row has exactly one selector turned on, as
41        // `is_real`, the sum of the six selectors, is boolean. Therefore, the `opcode`
42        // equals the opcode of the single active selector.
43        builder.assert_bool(local.is_beq);
44        builder.assert_bool(local.is_bne);
45        builder.assert_bool(local.is_blt);
46        builder.assert_bool(local.is_bge);
47        builder.assert_bool(local.is_bltu);
48        builder.assert_bool(local.is_bgeu);
49        let is_real = local.is_beq
50            + local.is_bne
51            + local.is_blt
52            + local.is_bge
53            + local.is_bltu
54            + local.is_bgeu;
55        builder.assert_bool(is_real.clone());
56        // Field encoding of the active opcode; zero when no selector is set.
57        let opcode = local.is_beq * Opcode::BEQ.as_field::<AB::F>()
58            + local.is_bne * Opcode::BNE.as_field::<AB::F>()
59            + local.is_blt * Opcode::BLT.as_field::<AB::F>()
60            + local.is_bge * Opcode::BGE.as_field::<AB::F>()
61            + local.is_bltu * Opcode::BLTU.as_field::<AB::F>()
62            + local.is_bgeu * Opcode::BGEU.as_field::<AB::F>();
63
64        // Instruction fields of the active opcode; only used by the untrusted-program check.
65        let funct3 = local.is_beq * AB::Expr::from_canonical_u8(Opcode::BEQ.funct3().unwrap())
66            + local.is_bne * AB::Expr::from_canonical_u8(Opcode::BNE.funct3().unwrap())
67            + local.is_blt * AB::Expr::from_canonical_u8(Opcode::BLT.funct3().unwrap())
68            + local.is_bge * AB::Expr::from_canonical_u8(Opcode::BGE.funct3().unwrap())
69            + local.is_bltu * AB::Expr::from_canonical_u8(Opcode::BLTU.funct3().unwrap())
70            + local.is_bgeu * AB::Expr::from_canonical_u8(Opcode::BGEU.funct3().unwrap());
71        let funct7 = local.is_beq * AB::Expr::from_canonical_u8(Opcode::BEQ.funct7().unwrap_or(0))
72            + local.is_bne * AB::Expr::from_canonical_u8(Opcode::BNE.funct7().unwrap_or(0))
73            + local.is_blt * AB::Expr::from_canonical_u8(Opcode::BLT.funct7().unwrap_or(0))
74            + local.is_bge * AB::Expr::from_canonical_u8(Opcode::BGE.funct7().unwrap_or(0))
75            + local.is_bltu * AB::Expr::from_canonical_u8(Opcode::BLTU.funct7().unwrap_or(0))
76            + local.is_bgeu * AB::Expr::from_canonical_u8(Opcode::BGEU.funct7().unwrap_or(0));
77        let base_opcode = local.is_beq * AB::Expr::from_canonical_u32(Opcode::BEQ.base_opcode().0)
78            + local.is_bne * AB::Expr::from_canonical_u32(Opcode::BNE.base_opcode().0)
79            + local.is_blt * AB::Expr::from_canonical_u32(Opcode::BLT.base_opcode().0)
80            + local.is_bge * AB::Expr::from_canonical_u32(Opcode::BGE.base_opcode().0)
81            + local.is_bltu * AB::Expr::from_canonical_u32(Opcode::BLTU.base_opcode().0)
82            + local.is_bgeu * AB::Expr::from_canonical_u32(Opcode::BGEU.base_opcode().0);
83        let instr_type = local.is_beq
84            * AB::Expr::from_canonical_u32(Opcode::BEQ.instruction_type().0 as u32)
85            + local.is_bne * AB::Expr::from_canonical_u32(Opcode::BNE.instruction_type().0 as u32)
86            + local.is_blt * AB::Expr::from_canonical_u32(Opcode::BLT.instruction_type().0 as u32)
87            + local.is_bge * AB::Expr::from_canonical_u32(Opcode::BGE.instruction_type().0 as u32)
88            + local.is_bltu
89                * AB::Expr::from_canonical_u32(Opcode::BLTU.instruction_type().0 as u32)
90            + local.is_bgeu
91                * AB::Expr::from_canonical_u32(Opcode::BGEU.instruction_type().0 as u32);
92
93        // Constrain the state of the CPU on real rows (gated by `is_real`).
94        // The `next_pc` columns are passed in here and constrained by this AIR further below.
95        // The clock is incremented by `CLK_INC` (expected to be `8`; confirm in sp1_core_executor).
96        <CPUState<AB::F> as SP1Operation<AB>>::eval(
97            builder,
98            CPUStateInput::new(
99                local.state,
100                local.next_pc.map(Into::into),
101                AB::Expr::from_canonical_u32(CLK_INC),
102                is_real.clone(),
103            ),
104        );
105
106        let mut is_trusted: AB::Expr = is_real.clone(); // Replaced in the untrusted path below.
107        // With `mprotect`: public `is_untrusted_programs_enabled` must equal `!M::IS_TRUSTED`.
108        #[cfg(feature = "mprotect")]
109        builder.assert_eq(
110            builder.extract_public_values().is_untrusted_programs_enabled,
111            AB::Expr::from_bool(!M::IS_TRUSTED),
112        );
113
114        if !M::IS_TRUSTED { // User-mode chip only.
115            let local = main.row_slice(0); // Re-view the row with `UserMode` columns.
116            let local: &BranchColumns<AB::Var, UserMode> = (*local).borrow();
117
118            let instruction = local.adapter.instruction::<AB>(opcode.clone());
119            // Without `mprotect`, this user-mode chip may not contain any real rows.
120            #[cfg(not(feature = "mprotect"))]
121            builder.assert_zero(is_real.clone());
122            // Delegate the untrusted-program checks to `eval_untrusted_program`.
123            eval_untrusted_program(
124                builder,
125                local.state.pc,
126                instruction,
127                [instr_type, base_opcode, funct3, funct7],
128                [local.state.clk_high::<AB>(), local.state.clk_low::<AB>()],
129                is_real.clone(),
130                local.adapter_cols,
131            );
132
133            is_trusted = local.adapter_cols.is_trusted.into(); // Trust flag from the adapter.
134        }
135
136        // Constrain the program and register reads; `is_trusted` is forwarded to the reader.
137        <ITypeReaderImmutable as SP1Operation<AB>>::eval(
138            builder,
139            ITypeReaderImmutableInput::new(
140                local.state.clk_high::<AB>(),
141                local.state.clk_low::<AB>(),
142                local.state.pc,
143                opcode,
144                local.adapter,
145                is_real.clone(),
146                is_trusted,
147            ),
148        );
149
150        // SAFETY: `use_signed_comparison` is boolean, since at most one selector is turned on.
151        let use_signed_comparison = local.is_blt + local.is_bge; // Only BLT/BGE compare signed.
152        <LtOperationSigned<AB::F> as SP1Operation<AB>>::eval(
153            builder,
154            LtOperationSignedInput::<AB>::new(
155                local.adapter.prev_a().map(Into::into),
156                local.adapter.b().map(Into::into),
157                local.compare_operation,
158                use_signed_comparison.clone(),
159                is_real.clone(),
160            ),
161        );
162
163        // From `LtOperationSigned`, derive `is_eq` (a == b) and `is_less_than` (a < b).
164        let is_eq = AB::Expr::one() // assumes at most one u16 flag is set; TODO confirm
165            - (local.compare_operation.result.u16_flags[0]
166                + local.compare_operation.result.u16_flags[1]
167                + local.compare_operation.result.u16_flags[2]
168                + local.compare_operation.result.u16_flags[3]);
169        let is_less_than = local.compare_operation.result.u16_compare_operation.bit;
170
171        // Constrain the branching column with the comparison results and opcode flags.
172        let mut branching: AB::Expr = AB::Expr::zero(); // BEQ=eq BNE=!eq BGE(U)=!lt BLT(U)=lt
173        branching = branching.clone() + local.is_beq * is_eq.clone();
174        branching = branching.clone() + local.is_bne * (AB::Expr::one() - is_eq);
175        branching =
176            branching.clone() + (local.is_bge + local.is_bgeu) * (AB::Expr::one() - is_less_than);
177        branching = branching.clone() + (local.is_blt + local.is_bltu) * is_less_than;
178
179        builder.assert_bool(local.is_branching);
180        builder.when(is_real.clone()).assert_eq(local.is_branching, branching.clone());
181
182        // Constrain `next_pc` with the branching column (16-bit limbs, boolean carries).
183        // If `is_branching` is true:  next_pc == pc + op_c, with op_c = `local.adapter.c()`.
184        // If `is_branching` is false (on a real row): next_pc == pc + PC_INC.
185        let base_inverse = AB::F::from_canonical_u32(1 << 16).inverse();
186        let mut carry = AB::Expr::zero();
187        for i in 0..4 { // `pc`/`next_pc` have 3 limbs (4th taken as 0); `c` has 4.
188            let pc = if i < 3 { local.state.pc[i].into() } else { AB::Expr::zero() };
189            let next_pc = if i < 3 { local.next_pc[i].into() } else { AB::Expr::zero() };
190            carry = (carry.clone() + pc + local.adapter.c()[i] - next_pc) * base_inverse;
191            builder.when(local.is_branching).assert_bool(carry.clone());
192        }
193        // Non-branching case: same carry chain, with `PC_INC` added to the lowest limb.
194        let mut carry = AB::Expr::zero();
195        for i in 0..4 {
196            let pc = if i < 3 { local.state.pc[i].into() } else { AB::Expr::zero() };
197            let next_pc = if i < 3 { local.next_pc[i].into() } else { AB::Expr::zero() };
198            let increment =
199                if i == 0 { AB::Expr::from_canonical_u32(PC_INC) } else { AB::Expr::zero() };
200            carry = (carry.clone() + pc + increment - next_pc) * base_inverse;
201            builder.when(is_real.clone() - local.is_branching).assert_bool(carry.clone());
202        }
203
204        // `next_pc` is a multiple of 4: `next_pc[0] / 4` is range-checked to 14 bits.
205        builder.send_byte(
206            AB::Expr::from_canonical_u32(ByteOpcode::Range as u32),
207            local.next_pc[0].into() * AB::F::from_canonical_u32(4).inverse(),
208            AB::Expr::from_canonical_u32(14),
209            AB::Expr::zero(),
210            is_real.clone(),
211        );
212        builder.slice_range_check_u16(&local.next_pc[1..3], is_real); // Upper limbs: u16.
213    }
214}