// sp1_core_machine/program/instruction_fetch.rs

1use core::{
2    borrow::{Borrow, BorrowMut},
3    mem::{size_of, MaybeUninit},
4};
5
6use crate::{
7    air::SP1CoreAirBuilder, memory::MemoryAccessCols, program::InstructionCols,
8    utils::next_multiple_of_32,
9};
10use hashbrown::HashMap;
11use itertools::Itertools;
12use slop_air::{Air, AirBuilder, BaseAir};
13use slop_algebra::{AbstractField, Field, PrimeField, PrimeField32};
14use slop_matrix::Matrix;
15use sp1_core_executor::{
16    events::{ByteLookupEvent, ByteRecord, InstructionFetchEvent, MemoryAccessPosition},
17    ByteOpcode, ExecutionRecord, MemoryAccessRecord, Opcode, Program,
18};
19use sp1_derive::AlignedBorrow;
20use sp1_hypercube::air::MachineAir;
21use sp1_primitives::consts::{u64_to_u16_limbs, PROT_EXEC};
22
/// The number of instruction fetch columns, i.e. the width of [`InstructionFetchCols`] in
/// field elements.
pub const NUM_INSTRUCTION_FETCH_COLS: usize = size_of::<InstructionFetchCols<u8>>();
25
/// The column layout for the chip.
///
/// Field order is part of the layout contract (`AlignedBorrow` + `#[repr(C)]`); do not
/// reorder fields without updating every consumer of this layout.
#[derive(AlignedBorrow, Clone, Copy, Default)]
#[repr(C)]
pub struct InstructionFetchCols<T> {
    /// High bits of the clock (`clk >> 24`).
    pub clk_high: T,
    /// Low 24 bits of the clock (`clk & 0xFFFFFF`).
    pub clk_low: T,
    /// The program counter, decomposed into three 16-bit limbs (least-significant first).
    pub pc: [T; 3],
    /// This is used to check if the top two limbs of the `pc` is not both zero.
    pub top_two_limb_inv: T,

    /// The instruction columns, populated from the fetched instruction.
    pub instruction: InstructionCols<T>,
    /// The instruction type; the immediate variant when operand `c` is an immediate.
    pub instr_type: T,
    /// The base opcode; the immediate variant when operand `c` is an immediate.
    pub base_opcode: T,
    /// The instruction's `funct3` value, or 0 when the opcode has none.
    pub funct3: T,
    /// The instruction's `funct7` value, or 0 when the opcode has none.
    pub funct7: T,

    /// The read of the 8-byte memory word containing the instruction.
    pub memory_access: MemoryAccessCols<T>,
    /// The selected 32 bits of read memory, in this case the 32 bit encoded instruction.
    pub selected_word: [T; 2],
    /// Boolean selecting the lower (0) or upper (1) 32 bits of the fetched 64-bit word.
    pub offset: T,
    /// Whether this row corresponds to a real fetch event (1) or padding (0).
    pub is_real: T,
}
48
/// A chip that implements instruction fetching from memory.
///
/// The chip itself is stateless; all trace data comes from the execution record.
#[derive(Default)]
pub struct InstructionFetchChip;
52
53impl InstructionFetchChip {
54    pub const fn new() -> Self {
55        Self {}
56    }
57
58    fn event_to_row<F: PrimeField>(
59        &self,
60        event: &InstructionFetchEvent,
61        memory_access: &MemoryAccessRecord,
62        cols: &mut InstructionFetchCols<F>,
63    ) {
64        let instruction = event.instruction;
65        let (mem_access, encoded) = memory_access.untrusted_instruction.unwrap();
66        assert_eq!(encoded, event.encoded_instruction);
67
68        let pc = event.pc; // input.program.pc_base + event.pc as u64 * 4;
69        cols.pc = [
70            F::from_canonical_u16((pc & 0xFFFF) as u16),
71            F::from_canonical_u16(((pc >> 16) & 0xFFFF) as u16),
72            F::from_canonical_u16(((pc >> 32) & 0xFFFF) as u16),
73        ];
74
75        let sum_top_two_limb = cols.pc[1] + cols.pc[2];
76        cols.top_two_limb_inv = sum_top_two_limb.inverse();
77
78        let clk_high = (event.clk >> 24) as u32;
79        let clk_low = (event.clk & 0xFFFFFF) as u32;
80        cols.clk_high = F::from_canonical_u32(clk_high);
81        cols.clk_low = F::from_canonical_u32(clk_low);
82
83        if instruction.opcode != Opcode::UNIMP {
84            let (instr_type, instr_type_imm) = instruction.opcode.instruction_type();
85            cols.instr_type = if instr_type_imm.is_some() && instruction.imm_c {
86                F::from_canonical_u32(instr_type_imm.unwrap() as u32)
87            } else {
88                F::from_canonical_u32(instr_type as u32)
89            };
90            assert!(cols.instr_type != F::zero());
91
92            let (base_opcode, base_imm_opcode) = instruction.opcode.base_opcode();
93            cols.base_opcode = if base_imm_opcode.is_some() && instruction.imm_c {
94                F::from_canonical_u32(base_imm_opcode.unwrap())
95            } else {
96                F::from_canonical_u32(base_opcode)
97            };
98            let funct3 = instruction.opcode.funct3().unwrap_or(0);
99            let funct7 = instruction.opcode.funct7().unwrap_or(0);
100            cols.funct3 = F::from_canonical_u8(funct3);
101            cols.funct7 = F::from_canonical_u8(funct7);
102        }
103
104        // Offset indicates whether we want lower or upper 32 bits of the instruction
105        let offset = (pc / 4) % 2;
106        cols.offset = F::from_canonical_u8(offset as u8);
107
108        // Turn into 4 16 bit limbs
109        let limbs = u64_to_u16_limbs(mem_access.value());
110
111        // Select based on the offset either the first two or last two limbs
112        // Either 0 or 2
113        let limb_selector = 2 * offset;
114
115        // Note selected word is equivalent to the 32 bit encoded instruction
116        cols.selected_word = [
117            F::from_canonical_u16(limbs[limb_selector as usize]),
118            F::from_canonical_u16(limbs[limb_selector as usize + 1]),
119        ];
120
121        let instruction = event.instruction;
122        cols.instruction.populate(&instruction);
123
124        // Check that the encoded instruction is correct
125        let encoding_check = instruction.encode();
126        assert_eq!(event.encoded_instruction, encoding_check);
127
128        cols.is_real = F::one();
129    }
130}
131
impl<F: PrimeField32> MachineAir<F> for InstructionFetchChip {
    type Record = ExecutionRecord;

    type Program = Program;

    fn name(&self) -> &'static str {
        "InstructionFetch"
    }

    /// Emits the byte-lookup events (memory-access lookups and pc range checks) that the
    /// instruction fetch rows of this shard depend on.
    fn generate_dependencies(&self, input: &ExecutionRecord, output: &mut ExecutionRecord) {
        let mut blu_batches = Vec::new();
        for full_event in input.instruction_fetch_events.iter() {
            let mut blu: HashMap<ByteLookupEvent, usize> = HashMap::new();
            // Scratch row: populated only so the memory-access lookups are recorded and
            // `cols.offset` can be read back below.
            let mut row = [F::zero(); NUM_INSTRUCTION_FETCH_COLS];
            let cols: &mut InstructionFetchCols<F> = row.as_mut_slice().borrow_mut();
            let (event, memory_access) = full_event;
            let (mem_access, encoded) = memory_access.untrusted_instruction.unwrap();
            assert_eq!(encoded, event.encoded_instruction);
            cols.memory_access.populate(mem_access, &mut blu);
            let pc = event.pc;

            // Range-check each 16-bit limb of the pc.
            let pc_0 = (pc & 0xFFFF) as u16;
            let pc_1 = ((pc >> 16) & 0xFFFF) as u16;
            let pc_2 = ((pc >> 32) & 0xFFFF) as u16;
            blu.add_u16_range_checks(&[pc_0, pc_1, pc_2]);

            self.event_to_row(event, memory_access, cols);

            let offset: u16 = cols.offset.as_canonical_u32().try_into().unwrap();

            // `pc_0 - 4 * offset` is the low limb of the 8-byte-aligned address; check
            // its quotient by 8 fits in 13 bits, mirroring the AIR's `send_byte` lookup.
            blu.add_bit_range_check((pc_0 - 4 * offset) / 8, 13);

            blu_batches.push(blu);
        }

        output.add_byte_lookup_events_from_maps(blu_batches.iter().collect_vec());
    }

    /// Writes the trace rows for all instruction fetch events into `buffer`, zeroing the
    /// padding rows beyond the real events.
    fn generate_trace_into(
        &self,
        input: &ExecutionRecord,
        _output: &mut ExecutionRecord,
        buffer: &mut [MaybeUninit<F>],
    ) {
        let padded_nb_rows =
            <InstructionFetchChip as MachineAir<F>>::num_rows(self, input).unwrap();
        let num_event_rows = input.instruction_fetch_events.len();

        // SAFETY: zeroing `MaybeUninit<F>` bytes is always allowed; this relies on the
        // all-zero byte pattern being a valid representation of `F::zero()` — presumably
        // true for the fields used here (TODO confirm for F).
        unsafe {
            let padding_start = num_event_rows * NUM_INSTRUCTION_FETCH_COLS;
            let padding_size = (padded_nb_rows - num_event_rows) * NUM_INSTRUCTION_FETCH_COLS;
            if padding_size > 0 {
                core::ptr::write_bytes(buffer[padding_start..].as_mut_ptr(), 0, padding_size);
            }
        }

        // SAFETY: only the first `num_event_rows * NUM_INSTRUCTION_FETCH_COLS` elements
        // are exposed as `F`, and every one of them is fully written by the row loop
        // below before being read.
        let buffer_ptr = buffer.as_mut_ptr() as *mut F;
        let values = unsafe {
            core::slice::from_raw_parts_mut(buffer_ptr, num_event_rows * NUM_INSTRUCTION_FETCH_COLS)
        };

        // One chunk of rows per CPU (at least one row per chunk).
        // NOTE(review): the chunks are iterated serially here; this looks like it was
        // meant to be a parallel iterator — confirm against the other chips.
        let chunk_size = std::cmp::max(input.instruction_fetch_events.len() / num_cpus::get(), 1);

        values.chunks_mut(chunk_size * NUM_INSTRUCTION_FETCH_COLS).enumerate().for_each(
            |(i, rows)| {
                rows.chunks_mut(NUM_INSTRUCTION_FETCH_COLS).enumerate().for_each(|(j, row)| {
                    let idx = i * chunk_size + j;
                    let cols: &mut InstructionFetchCols<F> = row.borrow_mut();
                    let (event, memory_access) = &input.instruction_fetch_events[idx];

                    // Byte lookups are discarded here; `generate_dependencies` emits them.
                    let mut blu: HashMap<ByteLookupEvent, usize> = HashMap::new();
                    let (mem_access, encoded) = memory_access.untrusted_instruction.unwrap();
                    assert_eq!(encoded, event.encoded_instruction);
                    assert!(mem_access.current_record().timestamp == event.clk);

                    cols.memory_access.populate(mem_access, &mut blu);
                    self.event_to_row(event, memory_access, cols);
                });
            },
        );
    }

    /// Whether this chip participates in the shard: dictated by the shard shape when one
    /// is set, otherwise by the presence of instruction fetch events.
    fn included(&self, shard: &Self::Record) -> bool {
        if let Some(shape) = shard.shape.as_ref() {
            shape.included::<F, _>(self)
        } else {
            !shard.instruction_fetch_events.is_empty()
        }
    }

    /// The padded row count: the event count rounded up via `next_multiple_of_32`,
    /// honoring a fixed log2 row count when the shard specifies one.
    fn num_rows(&self, input: &Self::Record) -> Option<usize> {
        let nb_rows = next_multiple_of_32(
            input.instruction_fetch_events.len(),
            input.fixed_log2_rows::<F, _>(self),
        );

        Some(nb_rows)
    }
}
231
impl<F> BaseAir<F> for InstructionFetchChip {
    /// The trace width in columns.
    fn width(&self) -> usize {
        NUM_INSTRUCTION_FETCH_COLS
    }
}
237
impl<AB> Air<AB> for InstructionFetchChip
where
    AB: SP1CoreAirBuilder,
{
    /// Evaluates the instruction fetch constraints on a single row.
    fn eval(&self, builder: &mut AB) {
        let main = builder.main();
        let local = main.row_slice(0);
        let local: &InstructionFetchCols<AB::Var> = (*local).borrow();

        let clk_high = local.clk_high.into();
        let clk_low = local.clk_low.into();

        builder.assert_bool(local.is_real.into());

        // Verify and calculate aligned address: each pc limb is 16 bits, `offset` is
        // boolean, and `(pc[0] - 4 * offset) / 8` fits in 13 bits — together this
        // constrains the low limb of the 8-byte-aligned address.
        builder.slice_range_check_u16(&local.pc, local.is_real);
        builder.assert_bool(local.offset);
        builder.send_byte(
            AB::Expr::from_canonical_u32(ByteOpcode::Range as u32),
            (local.pc[0] - AB::Expr::from_canonical_u32(4) * local.offset)
                * AB::F::from_canonical_u32(8).inverse(),
            AB::Expr::from_canonical_u32(13),
            AB::Expr::zero(),
            local.is_real.into(),
        );
        let sum_top_two_limb = local.pc[1] + local.pc[2];

        // Check that `pc >= 2^16`, so it doesn't touch registers.
        // This implements a stack guard of size 2^16 bytes = 64KB.
        // If `is_real = 1`, then `pc[1] + pc[2] != 0`, so `pc >= 2^16`.
        builder.assert_eq(local.top_two_limb_inv * sum_top_two_limb, local.is_real);

        // The 8-byte-aligned address of the memory word containing the instruction.
        let aligned_addr: [AB::Expr; 3] = [
            local.pc[0] - AB::Expr::from_canonical_u32(4) * local.offset,
            local.pc[1].into(),
            local.pc[2].into(),
        ];

        // Verify picked correct instruction from address.

        // Read the 8-byte memory word at the aligned address.
        builder.eval_memory_access_read(
            clk_high.clone(),
            clk_low.clone()
                + AB::Expr::from_canonical_u32(MemoryAccessPosition::UntrustedInstruction as u32),
            &aligned_addr,
            local.memory_access,
            local.is_real.into(),
        );

        // `selected_word` is the lower half of the word when `offset = 0` and the upper
        // half when `offset = 1`.
        builder
            .when_not(local.offset)
            .assert_eq(local.selected_word[0], local.memory_access.prev_value[0]);
        builder
            .when_not(local.offset)
            .assert_eq(local.selected_word[1], local.memory_access.prev_value[1]);
        builder
            .when(local.offset)
            .assert_eq(local.selected_word[0], local.memory_access.prev_value[2]);
        builder
            .when(local.offset)
            .assert_eq(local.selected_word[1], local.memory_access.prev_value[3]);

        // Constrain the interaction with untrusted program memory table.
        let untrusted_instruction_const_fields = [
            local.instr_type.into(),
            local.base_opcode.into(),
            local.funct3.into(),
            local.funct7.into(),
        ];

        builder.receive_instruction_fetch(
            local.pc,
            local.instruction,
            untrusted_instruction_const_fields.clone(),
            [clk_high.clone(), clk_low.clone()],
            local.is_real.into(),
        );

        // Tie `selected_word` to the decoded instruction via the decode table.
        builder.send_instruction_decode(
            local.selected_word,
            local.instruction,
            untrusted_instruction_const_fields,
            local.is_real.into(),
        );

        // Check the fetched page is executable.
        builder.send_page_prot(
            clk_high,
            clk_low,
            &aligned_addr.map(Into::into),
            AB::Expr::from_canonical_u8(PROT_EXEC),
            local.is_real.into(),
        );
    }
}
333
#[cfg(test)]
mod tests {
    #![allow(clippy::print_stdout)]

    use std::sync::Arc;

    use sp1_primitives::SP1Field;

    use slop_matrix::dense::RowMajorMatrix;
    use sp1_core_executor::{ExecutionRecord, Instruction, Opcode, Program};
    use sp1_hypercube::air::MachineAir;

    use crate::program::InstructionFetchChip;

    /// Smoke test: trace generation runs on a three-instruction program.
    #[test]
    fn generate_trace() {
        // Assemble a tiny program:
        //     addi x29, x0, 5
        //     addi x30, x0, 37
        //     add x31, x30, x29
        let program = Program::new(
            vec![
                Instruction::new(Opcode::ADDI, 29, 0, 5, false, true),
                Instruction::new(Opcode::ADDI, 30, 0, 37, false, true),
                Instruction::new(Opcode::ADD, 31, 30, 29, false, false),
            ],
            0,
            0,
        );
        let record =
            ExecutionRecord { program: Arc::new(program), ..Default::default() };

        let trace: RowMajorMatrix<SP1Field> = InstructionFetchChip::new()
            .generate_trace(&record, &mut ExecutionRecord::default());
        println!("{:?}", trace.values)
    }
}