Skip to main content

sp1_core_machine/program/
trusted.rs

1use core::{
2    borrow::{Borrow, BorrowMut},
3    mem::{size_of, MaybeUninit},
4};
5use std::collections::HashMap;
6
7use crate::{air::ProgramAirBuilder, program::InstructionCols, utils::next_multiple_of_32};
8use slop_air::{Air, BaseAir, PairBuilder};
9use slop_algebra::PrimeField32;
10use slop_matrix::Matrix;
11use slop_maybe_rayon::prelude::{ParallelBridge, ParallelIterator};
12use sp1_core_executor::{ExecutionRecord, Program};
13use sp1_derive::AlignedBorrow;
14use sp1_hypercube::air::{MachineAir, SP1AirBuilder};
15
16/// The number of preprocessed program columns.
17pub const NUM_PROGRAM_PREPROCESSED_COLS: usize = size_of::<ProgramPreprocessedCols<u8>>();
18
19/// The number of columns for the program multiplicities.
20pub const NUM_PROGRAM_MULT_COLS: usize = size_of::<ProgramMultiplicityCols<u8>>();
21
22/// The column layout for the chip.
23#[derive(AlignedBorrow, Clone, Copy, Default)]
24#[repr(C)]
25pub struct ProgramPreprocessedCols<T> {
26    pub pc: [T; 3],
27    pub instruction: InstructionCols<T>,
28}
29
30/// The column layout for the chip.
31#[derive(AlignedBorrow, Clone, Copy, Default)]
32#[repr(C)]
33pub struct ProgramMultiplicityCols<T> {
34    pub multiplicity: T,
35}
36
/// The program table chip.
///
/// Its preprocessed trace lists every `(pc, instruction)` pair of the program, and its
/// main trace records how many times each of those instructions was executed. The AIR
/// receives one program interaction per row, weighted by that multiplicity.
#[derive(Default)]
pub struct ProgramChip;
40
41impl ProgramChip {
42    pub const fn new() -> Self {
43        Self {}
44    }
45}
46
47impl<F: PrimeField32> MachineAir<F> for ProgramChip {
48    type Record = ExecutionRecord;
49
50    type Program = Program;
51
52    fn name(&self) -> &'static str {
53        "Program"
54    }
55
56    fn preprocessed_width(&self) -> usize {
57        NUM_PROGRAM_PREPROCESSED_COLS
58    }
59
60    fn num_rows(&self, input: &Self::Record) -> Option<usize> {
61        let nb_rows = input.program.instructions.len();
62        let size_log2 = input.fixed_log2_rows::<F, _>(self);
63        let padded_nb_rows = next_multiple_of_32(nb_rows, size_log2);
64        Some(padded_nb_rows)
65    }
66
67    fn preprocessed_num_rows(&self, program: &Self::Program) -> Option<usize> {
68        let instrs_len = program.instructions.len();
69        Some(next_multiple_of_32(instrs_len, None))
70    }
71
72    fn preprocessed_num_rows_with_instrs_len(
73        &self,
74        _program: &Self::Program,
75        instrs_len: usize,
76    ) -> Option<usize> {
77        Some(next_multiple_of_32(instrs_len, None))
78    }
79
80    fn generate_preprocessed_trace_into(
81        &self,
82        program: &Self::Program,
83        buffer: &mut [MaybeUninit<F>],
84    ) {
85        debug_assert!(
86            !program.instructions.is_empty() || program.preprocessed_shape.is_some(),
87            "empty program"
88        );
89        // Generate the trace rows for each event.
90        let nb_rows = program.instructions.len();
91        let size_log2 = program.fixed_log2_rows::<F, _>(self);
92        let padded_nb_rows = next_multiple_of_32(nb_rows, size_log2);
93        assert!(matches!(
94            padded_nb_rows.checked_mul(4),
95            Some(last_idx) if last_idx < F::ORDER_U64 as usize,
96        ));
97        let buffer_ptr = buffer.as_mut_ptr() as *mut F;
98        let values = unsafe {
99            core::slice::from_raw_parts_mut(
100                buffer_ptr,
101                padded_nb_rows * NUM_PROGRAM_PREPROCESSED_COLS,
102            )
103        };
104        let chunk_size = std::cmp::max((nb_rows + 1) / num_cpus::get(), 1);
105
106        values
107            .chunks_mut(chunk_size * NUM_PROGRAM_PREPROCESSED_COLS)
108            .enumerate()
109            .par_bridge()
110            .for_each(|(i, rows)| {
111                rows.chunks_mut(NUM_PROGRAM_PREPROCESSED_COLS).enumerate().for_each(|(j, row)| {
112                    let mut idx = i * chunk_size + j;
113                    if idx >= nb_rows {
114                        idx = 0;
115                    }
116                    let cols: &mut ProgramPreprocessedCols<F> = row.borrow_mut();
117                    let pc = program.pc_base + idx as u64 * 4;
118                    assert!(pc < (1 << 48));
119                    cols.pc = [
120                        F::from_canonical_u16((pc & 0xFFFF) as u16),
121                        F::from_canonical_u16(((pc >> 16) & 0xFFFF) as u16),
122                        F::from_canonical_u16(((pc >> 32) & 0xFFFF) as u16),
123                    ];
124                    let instruction = program.instructions[idx];
125                    cols.instruction.populate(&instruction);
126                });
127            });
128    }
129
130    fn generate_dependencies(&self, _input: &ExecutionRecord, _output: &mut ExecutionRecord) {
131        // Do nothing since this chip has no dependencies.
132    }
133
134    fn generate_trace_into(
135        &self,
136        input: &ExecutionRecord,
137        _output: &mut ExecutionRecord,
138        buffer: &mut [MaybeUninit<F>],
139    ) {
140        // Generate the trace rows for each event.
141
142        // Collect the number of times each instruction is called from the cpu events.
143        // Store it as a map of PC -> count.
144        let mut instruction_counts = HashMap::new();
145        input.add_events.iter().for_each(|event| {
146            let pc = event.0.pc;
147            instruction_counts.entry(pc).and_modify(|count| *count += 1).or_insert(1);
148        });
149        input.addw_events.iter().for_each(|event| {
150            let pc = event.0.pc;
151            instruction_counts.entry(pc).and_modify(|count| *count += 1).or_insert(1);
152        });
153        input.addi_events.iter().for_each(|event| {
154            let pc = event.0.pc;
155            instruction_counts.entry(pc).and_modify(|count| *count += 1).or_insert(1);
156        });
157        input.sub_events.iter().for_each(|event| {
158            let pc = event.0.pc;
159            instruction_counts.entry(pc).and_modify(|count| *count += 1).or_insert(1);
160        });
161        input.subw_events.iter().for_each(|event| {
162            let pc = event.0.pc;
163            instruction_counts.entry(pc).and_modify(|count| *count += 1).or_insert(1);
164        });
165        input.bitwise_events.iter().for_each(|event| {
166            let pc = event.0.pc;
167            instruction_counts.entry(pc).and_modify(|count| *count += 1).or_insert(1);
168        });
169        input.mul_events.iter().for_each(|event| {
170            let pc = event.0.pc;
171            instruction_counts.entry(pc).and_modify(|count| *count += 1).or_insert(1);
172        });
173        input.divrem_events.iter().for_each(|event| {
174            let pc = event.0.pc;
175            instruction_counts.entry(pc).and_modify(|count| *count += 1).or_insert(1);
176        });
177        input.lt_events.iter().for_each(|event| {
178            let pc = event.0.pc;
179            instruction_counts.entry(pc).and_modify(|count| *count += 1).or_insert(1);
180        });
181        input.shift_left_events.iter().for_each(|event| {
182            let pc = event.0.pc;
183            instruction_counts.entry(pc).and_modify(|count| *count += 1).or_insert(1);
184        });
185        input.shift_right_events.iter().for_each(|event| {
186            let pc = event.0.pc;
187            instruction_counts.entry(pc).and_modify(|count| *count += 1).or_insert(1);
188        });
189        input.branch_events.iter().for_each(|event| {
190            let pc = event.0.pc;
191            instruction_counts.entry(pc).and_modify(|count| *count += 1).or_insert(1);
192        });
193        input.memory_load_byte_events.iter().for_each(|event| {
194            let pc = event.0.pc;
195            instruction_counts.entry(pc).and_modify(|count| *count += 1).or_insert(1);
196        });
197        input.memory_load_half_events.iter().for_each(|event| {
198            let pc = event.0.pc;
199            instruction_counts.entry(pc).and_modify(|count| *count += 1).or_insert(1);
200        });
201        input.memory_load_word_events.iter().for_each(|event| {
202            let pc = event.0.pc;
203            instruction_counts.entry(pc).and_modify(|count| *count += 1).or_insert(1);
204        });
205        input.memory_load_x0_events.iter().for_each(|event| {
206            let pc = event.0.pc;
207            instruction_counts.entry(pc).and_modify(|count| *count += 1).or_insert(1);
208        });
209        input.memory_load_double_events.iter().for_each(|event| {
210            let pc = event.0.pc;
211            instruction_counts.entry(pc).and_modify(|count| *count += 1).or_insert(1);
212        });
213        input.memory_store_byte_events.iter().for_each(|event| {
214            let pc = event.0.pc;
215            instruction_counts.entry(pc).and_modify(|count| *count += 1).or_insert(1);
216        });
217        input.memory_store_half_events.iter().for_each(|event| {
218            let pc = event.0.pc;
219            instruction_counts.entry(pc).and_modify(|count| *count += 1).or_insert(1);
220        });
221        input.memory_store_word_events.iter().for_each(|event| {
222            let pc = event.0.pc;
223            instruction_counts.entry(pc).and_modify(|count| *count += 1).or_insert(1);
224        });
225        input.memory_store_double_events.iter().for_each(|event| {
226            let pc = event.0.pc;
227            instruction_counts.entry(pc).and_modify(|count| *count += 1).or_insert(1);
228        });
229        input.jal_events.iter().for_each(|event| {
230            let pc = event.0.pc;
231            instruction_counts.entry(pc).and_modify(|count| *count += 1).or_insert(1);
232        });
233        input.jalr_events.iter().for_each(|event| {
234            let pc = event.0.pc;
235            instruction_counts.entry(pc).and_modify(|count| *count += 1).or_insert(1);
236        });
237        input.utype_events.iter().for_each(|event| {
238            let pc = event.0.pc;
239            instruction_counts.entry(pc).and_modify(|count| *count += 1).or_insert(1);
240        });
241        input.syscall_events.iter().for_each(|event| {
242            let pc = event.0.pc;
243            instruction_counts.entry(pc).and_modify(|count| *count += 1).or_insert(1);
244        });
245
246        // Note: The program table should only count trusted (i.e. not untrusted instructions.)
247        // However, untrusted instructions are also included in the events vectors.
248        // Intuitively this would cause a mismatch where the program table tries to receive
249        // additional interactions due to thes untrusted instruction events. In reality, there is no
250        // issue because rows are created over the program instructions which do not include
251        // untrusted instructions, and the address space for program instructions are
252        // protected and will never intersect with the address space for untrusted
253        // instructions.
254
255        let padded_nb_rows = <ProgramChip as MachineAir<F>>::num_rows(self, input).unwrap();
256        let nb_instructions = input.program.instructions.len();
257
258        unsafe {
259            let padding_start = nb_instructions * NUM_PROGRAM_MULT_COLS;
260            let padding_size = (padded_nb_rows - nb_instructions) * NUM_PROGRAM_MULT_COLS;
261            if padding_size > 0 {
262                core::ptr::write_bytes(buffer[padding_start..].as_mut_ptr(), 0, padding_size);
263            }
264        }
265
266        let buffer_ptr = buffer.as_mut_ptr() as *mut F;
267        let values = unsafe {
268            core::slice::from_raw_parts_mut(buffer_ptr, nb_instructions * NUM_PROGRAM_MULT_COLS)
269        };
270
271        let chunk_size = std::cmp::max(nb_instructions / num_cpus::get(), 1);
272
273        values.chunks_mut(chunk_size * NUM_PROGRAM_MULT_COLS).enumerate().par_bridge().for_each(
274            |(i, rows)| {
275                rows.chunks_mut(NUM_PROGRAM_MULT_COLS).enumerate().for_each(|(j, row)| {
276                    let idx = i * chunk_size + j;
277                    if idx < nb_instructions {
278                        let pc = input.program.pc_base + idx as u64 * 4;
279                        let cols: &mut ProgramMultiplicityCols<F> = row.borrow_mut();
280                        cols.multiplicity =
281                            F::from_canonical_usize(*instruction_counts.get(&pc).unwrap_or(&0));
282                    }
283                });
284            },
285        );
286    }
287
288    fn included(&self, _: &Self::Record) -> bool {
289        true
290    }
291}
292
293impl<F> BaseAir<F> for ProgramChip {
294    fn width(&self) -> usize {
295        NUM_PROGRAM_MULT_COLS
296    }
297}
298
299impl<AB> Air<AB> for ProgramChip
300where
301    AB: SP1AirBuilder + PairBuilder,
302{
303    fn eval(&self, builder: &mut AB) {
304        let main = builder.main();
305        let preprocessed = builder.preprocessed();
306
307        let prep_local = preprocessed.row_slice(0);
308        let prep_local: &ProgramPreprocessedCols<AB::Var> = (*prep_local).borrow();
309        let mult_local = main.row_slice(0);
310        let mult_local: &ProgramMultiplicityCols<AB::Var> = (*mult_local).borrow();
311
312        // Constrain the interaction with CPU table
313        builder.receive_program(prep_local.pc, prep_local.instruction, mult_local.multiplicity);
314    }
315}
316
#[cfg(test)]
mod tests {
    #![allow(clippy::print_stdout)]

    use std::sync::Arc;

    use sp1_primitives::SP1Field;

    use slop_matrix::dense::RowMajorMatrix;
    use sp1_core_executor::{ExecutionRecord, Instruction, Opcode, Program};
    use sp1_hypercube::air::MachineAir;

    use crate::program::ProgramChip;

    /// Generates the program main trace for a three-instruction program with no events and
    /// checks that it has room for every instruction.
    #[test]
    fn generate_trace() {
        // main:
        //     addi x29, x0, 5
        //     addi x30, x0, 37
        //     add x31, x30, x29
        let instructions = vec![
            Instruction::new(Opcode::ADDI, 29, 0, 5, false, true),
            Instruction::new(Opcode::ADDI, 30, 0, 37, false, true),
            Instruction::new(Opcode::ADD, 31, 30, 29, false, false),
        ];
        let nb_instructions = instructions.len();
        let shard = ExecutionRecord {
            program: Arc::new(Program::new(instructions, 0, 0)),
            ..Default::default()
        };
        let chip = ProgramChip::new();
        let trace: RowMajorMatrix<SP1Field> =
            chip.generate_trace(&shard, &mut ExecutionRecord::default());
        // The main trace has a single multiplicity column, so it needs at least one value per
        // instruction.
        assert!(trace.values.len() >= nb_instructions);
        println!("{:?}", trace.values)
    }
}