Skip to main content

sp1_core_machine/alu/add_sub/
sub.rs

1use core::{
2    borrow::{Borrow, BorrowMut},
3    mem::{size_of, MaybeUninit},
4};
5use std::marker::PhantomData;
6
7use hashbrown::HashMap;
8use itertools::Itertools;
9use slop_air::{Air, BaseAir};
10use slop_algebra::{AbstractField, PrimeField, PrimeField32};
11use slop_matrix::Matrix;
12use slop_maybe_rayon::prelude::{ParallelBridge, ParallelIterator};
13use sp1_core_executor::{
14    events::{AluEvent, ByteLookupEvent, ByteRecord},
15    ExecutionRecord, Opcode, Program, CLK_INC, PC_INC,
16};
17use sp1_derive::AlignedBorrow;
18use sp1_hypercube::air::MachineAir;
19use struct_reflection::{StructReflection, StructReflectionHelper};
20
21use crate::{
22    adapter::{
23        register::r_type::{RTypeReader, RTypeReaderInput},
24        state::{CPUState, CPUStateInput},
25    },
26    air::{SP1CoreAirBuilder, SP1Operation},
27    eval_untrusted_program,
28    operations::{SubOperation, SubOperationInput},
29    utils::next_multiple_of_32,
30    SupervisorMode, TrustMode, UserMode,
31};
32
33/// The number of main trace columns for `SubChip` in Supervisor mode.
34pub const NUM_SUB_COLS_SUPERVISOR: usize = size_of::<SubCols<u8, SupervisorMode>>();
35/// The number of main trace columns for `SubChip` in User mode.
36pub const NUM_SUB_COLS_USER: usize = size_of::<SubCols<u8, UserMode>>();
37
38/// A chip that implements subtraction for the opcode SUB.
39#[derive(Default)]
40pub struct SubChip<M: TrustMode> {
41    pub _phantom: PhantomData<M>,
42}
43
/// The column layout for the chip.
///
/// `#[repr(C)]` pins the field order, and that order *is* the column order of the trace: rows
/// are reinterpreted as this struct from `&[T]` / `&mut [T]` slices via `AlignedBorrow`, and the
/// `NUM_SUB_COLS_*` widths are computed as `size_of::<SubCols<u8, _>>()`. Reordering, adding,
/// or retyping fields therefore changes the trace layout.
#[derive(AlignedBorrow, StructReflection, Default, Clone, Copy)]
#[repr(C)]
pub struct SubCols<T, M: TrustMode> {
    /// The current shard, timestamp, program counter of the CPU.
    /// (`eval` reads `pc`, `clk_high()` and `clk_low()` from here.)
    pub state: CPUState<T>,

    /// The adapter to read program and register information (R-type: operands `b`, `c`,
    /// destination `a`).
    pub adapter: RTypeReader<T>,

    /// Instance of `SubOperation` to handle subtraction logic in `SubChip`'s ALU operations.
    /// Its `value` is what gets written to `op_a` in `eval`.
    pub sub_operation: SubOperation<T>,

    /// Boolean to indicate whether the row is not a padding row (constrained boolean in `eval`;
    /// set to one for every event row during trace generation).
    pub is_real: T,

    /// Adapter columns for trust mode specific data. In user mode these include `is_trusted`,
    /// which trace generation fills from the event's `is_untrusted` flag.
    pub adapter_cols: M::AdapterCols<T>,
}
63
64impl<F: PrimeField32, M: TrustMode> MachineAir<F> for SubChip<M> {
65    type Record = ExecutionRecord;
66
67    type Program = Program;
68
69    fn name(&self) -> &'static str {
70        if M::IS_TRUSTED {
71            "Sub"
72        } else {
73            "SubUser"
74        }
75    }
76
77    fn column_names(&self) -> Vec<String> {
78        SubCols::<F, M>::struct_reflection().unwrap()
79    }
80
81    fn num_rows(&self, input: &Self::Record) -> Option<usize> {
82        if input.program.enable_untrusted_programs == M::IS_TRUSTED {
83            return Some(0);
84        }
85        let nb_rows =
86            next_multiple_of_32(input.sub_events.len(), input.fixed_log2_rows::<F, _>(self));
87        Some(nb_rows)
88    }
89
90    fn generate_trace_into(
91        &self,
92        input: &ExecutionRecord,
93        _output: &mut ExecutionRecord,
94        buffer: &mut [MaybeUninit<F>],
95    ) {
96        if input.program.enable_untrusted_programs == M::IS_TRUSTED {
97            return;
98        }
99
100        // Generate the rows for the trace.
101        let chunk_size = std::cmp::max(input.sub_events.len() / num_cpus::get(), 1);
102        let padded_nb_rows = <SubChip<M> as MachineAir<F>>::num_rows(self, input).unwrap();
103        let num_event_rows = input.sub_events.len();
104        let width = <SubChip<M> as BaseAir<F>>::width(self);
105
106        unsafe {
107            let padding_start = num_event_rows * width;
108            let padding_size = (padded_nb_rows - num_event_rows) * width;
109            if padding_size > 0 {
110                core::ptr::write_bytes(buffer[padding_start..].as_mut_ptr(), 0, padding_size);
111            }
112        }
113
114        let buffer_ptr = buffer.as_mut_ptr() as *mut F;
115        let values = unsafe { core::slice::from_raw_parts_mut(buffer_ptr, num_event_rows * width) };
116
117        values.chunks_mut(chunk_size * width).enumerate().par_bridge().for_each(|(i, rows)| {
118            rows.chunks_mut(width).enumerate().for_each(|(j, row)| {
119                let idx = i * chunk_size + j;
120                let cols: &mut SubCols<F, M> = row.borrow_mut();
121
122                if idx < input.sub_events.len() {
123                    let mut byte_lookup_events = Vec::new();
124                    let event = input.sub_events[idx];
125                    self.event_to_row(&event.0, cols, &mut byte_lookup_events);
126                    cols.state.populate(&mut byte_lookup_events, event.0.clk, event.0.pc);
127                    cols.adapter.populate(&mut byte_lookup_events, event.1);
128                    if !M::IS_TRUSTED {
129                        let cols: &mut SubCols<F, UserMode> = row.borrow_mut();
130                        cols.adapter_cols.is_trusted = F::from_bool(!event.1.is_untrusted);
131                    }
132                }
133            });
134        });
135    }
136
137    fn generate_dependencies(&self, input: &Self::Record, output: &mut Self::Record) {
138        if input.program.enable_untrusted_programs == M::IS_TRUSTED {
139            return;
140        }
141
142        let chunk_size = std::cmp::max(input.sub_events.len() / num_cpus::get(), 1);
143        let event_iter = input.sub_events.chunks(chunk_size);
144        let width = <SubChip<M> as BaseAir<F>>::width(self);
145
146        let blu_batches = event_iter
147            .par_bridge()
148            .map(|events| {
149                let mut blu: HashMap<ByteLookupEvent, usize> = HashMap::new();
150                events.iter().for_each(|event| {
151                    let mut row = vec![F::zero(); width];
152                    let cols: &mut SubCols<F, M> = row.as_mut_slice().borrow_mut();
153                    self.event_to_row(&event.0, cols, &mut blu);
154                    cols.state.populate(&mut blu, event.0.clk, event.0.pc);
155                    cols.adapter.populate(&mut blu, event.1);
156                });
157                blu
158            })
159            .collect::<Vec<_>>();
160
161        output.add_byte_lookup_events_from_maps(blu_batches.iter().collect_vec());
162    }
163
164    fn included(&self, shard: &Self::Record) -> bool {
165        if let Some(shape) = shard.shape.as_ref() {
166            shape.included::<F, _>(self)
167        } else {
168            !shard.sub_events.is_empty()
169                && (M::IS_TRUSTED != shard.program.enable_untrusted_programs)
170        }
171    }
172}
173
174impl<M: TrustMode> SubChip<M> {
175    /// Create a row from an event.
176    fn event_to_row<F: PrimeField>(
177        &self,
178        event: &AluEvent,
179        cols: &mut SubCols<F, M>,
180        blu: &mut impl ByteRecord,
181    ) {
182        cols.is_real = F::one();
183        cols.sub_operation.populate(blu, event.b, event.c);
184    }
185}
186
187impl<F, M: TrustMode> BaseAir<F> for SubChip<M> {
188    fn width(&self) -> usize {
189        if M::IS_TRUSTED {
190            NUM_SUB_COLS_SUPERVISOR
191        } else {
192            NUM_SUB_COLS_USER
193        }
194    }
195}
196
impl<AB, M> Air<AB> for SubChip<M>
where
    AB: SP1CoreAirBuilder,
    M: TrustMode,
{
    /// Constrains one row of the SUB chip.
    ///
    /// In order, this:
    /// - constrains `is_real` to be boolean;
    /// - requires `adapter.op_a_0` to be zero;
    /// - evaluates `SubOperation` over the adapter's `b` and `c` operands;
    /// - evaluates the CPU state transition;
    /// - in user mode, evaluates the untrusted-program checks;
    /// - evaluates the R-type program/register reads, with `sub_operation.value` as the value
    ///   written to `op_a`.
    fn eval(&self, builder: &mut AB) {
        let main = builder.main();
        let local = main.row_slice(0);
        let local: &SubCols<AB::Var, M> = (*local).borrow();

        builder.assert_bool(local.is_real);

        // Encoding constants of the SUB instruction. `opcode` feeds the program/register reads;
        // the remaining four are passed to the untrusted-program check in user mode.
        let opcode = AB::Expr::from_f(Opcode::SUB.as_field());
        let funct3 = AB::Expr::from_canonical_u8(Opcode::SUB.funct3().unwrap());
        let funct7 = AB::Expr::from_canonical_u8(Opcode::SUB.funct7().unwrap());
        let base_opcode = AB::Expr::from_canonical_u32(Opcode::SUB.base_opcode().0);
        let instr_type = AB::Expr::from_canonical_u32(Opcode::SUB.instruction_type().0 as u32);

        // This chip is for the case `rd != x0`.
        builder.assert_zero(local.adapter.op_a_0);

        // Constrain the sub operation over `op_b` and `op_c`.
        let op_input = SubOperationInput::<AB>::new(
            *local.adapter.b(),
            *local.adapter.c(),
            local.sub_operation,
            local.is_real.into(),
        );
        <SubOperation<AB::F> as SP1Operation<AB>>::eval(builder, op_input);

        // Constrain the state of the CPU.
        // The program counter increments by `PC_INC` and the timestamp by `CLK_INC`. Only the
        // lowest pc limb `pc[0]` is incremented; the upper limbs are carried through unchanged
        // (NOTE(review): presumably no carry out of `pc[0]` is possible here — confirm).
        <CPUState<AB::F> as SP1Operation<AB>>::eval(
            builder,
            CPUStateInput {
                cols: local.state,
                next_pc: [
                    local.state.pc[0] + AB::F::from_canonical_u32(PC_INC),
                    local.state.pc[1].into(),
                    local.state.pc[2].into(),
                ],
                clk_increment: AB::Expr::from_canonical_u32(CLK_INC),
                is_real: local.is_real.into(),
            },
        );

        // In supervisor mode the "trusted" flag passed to the register reads is just `is_real`;
        // the user-mode branch below replaces it with the `is_trusted` adapter column.
        let mut is_trusted: AB::Expr = local.is_real.into();

        // With `mprotect`, the public value `is_untrusted_programs_enabled` must equal
        // `!M::IS_TRUSTED`, tying each chip variant to the matching program setting.
        #[cfg(feature = "mprotect")]
        builder.assert_eq(
            builder.extract_public_values().is_untrusted_programs_enabled,
            AB::Expr::from_bool(!M::IS_TRUSTED),
        );

        if !M::IS_TRUSTED {
            // Re-borrow the row with the concrete `UserMode` layout to access its adapter columns.
            let local = main.row_slice(0);
            let local: &SubCols<AB::Var, UserMode> = (*local).borrow();

            let instruction = local.adapter.instruction::<AB>(opcode.clone());

            // Without `mprotect`, the user-mode chip may not contain any real rows.
            #[cfg(not(feature = "mprotect"))]
            builder.assert_zero(local.is_real);

            eval_untrusted_program(
                builder,
                local.state.pc,
                instruction,
                [instr_type, base_opcode, funct3, funct7],
                [local.state.clk_high::<AB>(), local.state.clk_low::<AB>()],
                local.is_real.into(),
                local.adapter_cols,
            );

            is_trusted = local.adapter_cols.is_trusted.into();
        }

        // Constrain the program and register reads.
        <RTypeReader<AB::F> as SP1Operation<AB>>::eval(
            builder,
            RTypeReaderInput {
                clk_high: local.state.clk_high::<AB>(),
                clk_low: local.state.clk_low::<AB>(),
                pc: local.state.pc,
                opcode,
                op_a_write_value: local.sub_operation.value.map(|x| x.into()),
                cols: local.adapter,
                is_real: local.is_real.into(),
                is_trusted,
            },
        );
    }
}