// sp1_core_machine/program/instruction_fetch.rs
use core::{
2 borrow::{Borrow, BorrowMut},
3 mem::{size_of, MaybeUninit},
4};
5
6use crate::{
7 air::SP1CoreAirBuilder, memory::MemoryAccessCols, program::InstructionCols,
8 utils::next_multiple_of_32,
9};
10use hashbrown::HashMap;
11use itertools::Itertools;
12use slop_air::{Air, AirBuilder, BaseAir};
13use slop_algebra::{AbstractField, Field, PrimeField, PrimeField32};
14use slop_matrix::Matrix;
15use sp1_core_executor::{
16 events::{ByteLookupEvent, ByteRecord, InstructionFetchEvent, MemoryAccessPosition},
17 ByteOpcode, ExecutionRecord, MemoryAccessRecord, Opcode, Program,
18};
19use sp1_derive::AlignedBorrow;
20use sp1_hypercube::air::MachineAir;
21use sp1_primitives::consts::{u64_to_u16_limbs, PROT_EXEC};
22
23pub const NUM_INSTRUCTION_FETCH_COLS: usize = size_of::<InstructionFetchCols<u8>>();
25
26#[derive(AlignedBorrow, Clone, Copy, Default)]
28#[repr(C)]
29pub struct InstructionFetchCols<T> {
30 pub clk_high: T,
31 pub clk_low: T,
32 pub pc: [T; 3],
33 pub top_two_limb_inv: T,
35
36 pub instruction: InstructionCols<T>,
37 pub instr_type: T,
38 pub base_opcode: T,
39 pub funct3: T,
40 pub funct7: T,
41
42 pub memory_access: MemoryAccessCols<T>,
43 pub selected_word: [T; 2],
45 pub offset: T,
46 pub is_real: T,
47}
48
49#[derive(Default)]
51pub struct InstructionFetchChip;
52
53impl InstructionFetchChip {
54 pub const fn new() -> Self {
55 Self {}
56 }
57
58 fn event_to_row<F: PrimeField>(
59 &self,
60 event: &InstructionFetchEvent,
61 memory_access: &MemoryAccessRecord,
62 cols: &mut InstructionFetchCols<F>,
63 ) {
64 let instruction = event.instruction;
65 let (mem_access, encoded) = memory_access.untrusted_instruction.unwrap();
66 assert_eq!(encoded, event.encoded_instruction);
67
68 let pc = event.pc; cols.pc = [
70 F::from_canonical_u16((pc & 0xFFFF) as u16),
71 F::from_canonical_u16(((pc >> 16) & 0xFFFF) as u16),
72 F::from_canonical_u16(((pc >> 32) & 0xFFFF) as u16),
73 ];
74
75 let sum_top_two_limb = cols.pc[1] + cols.pc[2];
76 cols.top_two_limb_inv = sum_top_two_limb.inverse();
77
78 let clk_high = (event.clk >> 24) as u32;
79 let clk_low = (event.clk & 0xFFFFFF) as u32;
80 cols.clk_high = F::from_canonical_u32(clk_high);
81 cols.clk_low = F::from_canonical_u32(clk_low);
82
83 if instruction.opcode != Opcode::UNIMP {
84 let (instr_type, instr_type_imm) = instruction.opcode.instruction_type();
85 cols.instr_type = if instr_type_imm.is_some() && instruction.imm_c {
86 F::from_canonical_u32(instr_type_imm.unwrap() as u32)
87 } else {
88 F::from_canonical_u32(instr_type as u32)
89 };
90 assert!(cols.instr_type != F::zero());
91
92 let (base_opcode, base_imm_opcode) = instruction.opcode.base_opcode();
93 cols.base_opcode = if base_imm_opcode.is_some() && instruction.imm_c {
94 F::from_canonical_u32(base_imm_opcode.unwrap())
95 } else {
96 F::from_canonical_u32(base_opcode)
97 };
98 let funct3 = instruction.opcode.funct3().unwrap_or(0);
99 let funct7 = instruction.opcode.funct7().unwrap_or(0);
100 cols.funct3 = F::from_canonical_u8(funct3);
101 cols.funct7 = F::from_canonical_u8(funct7);
102 }
103
104 let offset = (pc / 4) % 2;
106 cols.offset = F::from_canonical_u8(offset as u8);
107
108 let limbs = u64_to_u16_limbs(mem_access.value());
110
111 let limb_selector = 2 * offset;
114
115 cols.selected_word = [
117 F::from_canonical_u16(limbs[limb_selector as usize]),
118 F::from_canonical_u16(limbs[limb_selector as usize + 1]),
119 ];
120
121 let instruction = event.instruction;
122 cols.instruction.populate(&instruction);
123
124 let encoding_check = instruction.encode();
126 assert_eq!(event.encoded_instruction, encoding_check);
127
128 cols.is_real = F::one();
129 }
130}
131
132impl<F: PrimeField32> MachineAir<F> for InstructionFetchChip {
133 type Record = ExecutionRecord;
134
135 type Program = Program;
136
137 fn name(&self) -> &'static str {
138 "InstructionFetch"
139 }
140
141 fn generate_dependencies(&self, input: &ExecutionRecord, output: &mut ExecutionRecord) {
142 let mut blu_batches = Vec::new();
143 for full_event in input.instruction_fetch_events.iter() {
144 let mut blu: HashMap<ByteLookupEvent, usize> = HashMap::new();
145 let mut row = [F::zero(); NUM_INSTRUCTION_FETCH_COLS];
146 let cols: &mut InstructionFetchCols<F> = row.as_mut_slice().borrow_mut();
147 let (event, memory_access) = full_event;
148 let (mem_access, encoded) = memory_access.untrusted_instruction.unwrap();
149 assert_eq!(encoded, event.encoded_instruction);
150 cols.memory_access.populate(mem_access, &mut blu);
151 let pc = event.pc;
152
153 let pc_0 = (pc & 0xFFFF) as u16;
154 let pc_1 = ((pc >> 16) & 0xFFFF) as u16;
155 let pc_2 = ((pc >> 32) & 0xFFFF) as u16;
156 blu.add_u16_range_checks(&[pc_0, pc_1, pc_2]);
157
158 self.event_to_row(event, memory_access, cols);
159
160 let offset: u16 = cols.offset.as_canonical_u32().try_into().unwrap();
161
162 blu.add_bit_range_check((pc_0 - 4 * offset) / 8, 13);
163
164 blu_batches.push(blu);
165 }
166
167 output.add_byte_lookup_events_from_maps(blu_batches.iter().collect_vec());
168 }
169
170 fn generate_trace_into(
171 &self,
172 input: &ExecutionRecord,
173 _output: &mut ExecutionRecord,
174 buffer: &mut [MaybeUninit<F>],
175 ) {
176 let padded_nb_rows =
177 <InstructionFetchChip as MachineAir<F>>::num_rows(self, input).unwrap();
178 let num_event_rows = input.instruction_fetch_events.len();
179
180 unsafe {
181 let padding_start = num_event_rows * NUM_INSTRUCTION_FETCH_COLS;
182 let padding_size = (padded_nb_rows - num_event_rows) * NUM_INSTRUCTION_FETCH_COLS;
183 if padding_size > 0 {
184 core::ptr::write_bytes(buffer[padding_start..].as_mut_ptr(), 0, padding_size);
185 }
186 }
187
188 let buffer_ptr = buffer.as_mut_ptr() as *mut F;
189 let values = unsafe {
190 core::slice::from_raw_parts_mut(buffer_ptr, num_event_rows * NUM_INSTRUCTION_FETCH_COLS)
191 };
192
193 let chunk_size = std::cmp::max(input.instruction_fetch_events.len() / num_cpus::get(), 1);
194
195 values.chunks_mut(chunk_size * NUM_INSTRUCTION_FETCH_COLS).enumerate().for_each(
196 |(i, rows)| {
197 rows.chunks_mut(NUM_INSTRUCTION_FETCH_COLS).enumerate().for_each(|(j, row)| {
198 let idx = i * chunk_size + j;
199 let cols: &mut InstructionFetchCols<F> = row.borrow_mut();
200 let (event, memory_access) = &input.instruction_fetch_events[idx];
201
202 let mut blu: HashMap<ByteLookupEvent, usize> = HashMap::new();
203 let (mem_access, encoded) = memory_access.untrusted_instruction.unwrap();
204 assert_eq!(encoded, event.encoded_instruction);
205 assert!(mem_access.current_record().timestamp == event.clk);
206
207 cols.memory_access.populate(mem_access, &mut blu);
208 self.event_to_row(event, memory_access, cols);
209 });
210 },
211 );
212 }
213
214 fn included(&self, shard: &Self::Record) -> bool {
215 if let Some(shape) = shard.shape.as_ref() {
216 shape.included::<F, _>(self)
217 } else {
218 !shard.instruction_fetch_events.is_empty()
219 }
220 }
221
222 fn num_rows(&self, input: &Self::Record) -> Option<usize> {
223 let nb_rows = next_multiple_of_32(
224 input.instruction_fetch_events.len(),
225 input.fixed_log2_rows::<F, _>(self),
226 );
227
228 Some(nb_rows)
229 }
230}
231
232impl<F> BaseAir<F> for InstructionFetchChip {
233 fn width(&self) -> usize {
234 NUM_INSTRUCTION_FETCH_COLS
235 }
236}
237
238impl<AB> Air<AB> for InstructionFetchChip
239where
240 AB: SP1CoreAirBuilder,
241{
242 fn eval(&self, builder: &mut AB) {
243 let main = builder.main();
244 let local = main.row_slice(0);
245 let local: &InstructionFetchCols<AB::Var> = (*local).borrow();
246
247 let clk_high = local.clk_high.into();
248 let clk_low = local.clk_low.into();
249
250 builder.assert_bool(local.is_real.into());
251
252 builder.slice_range_check_u16(&local.pc, local.is_real);
254 builder.assert_bool(local.offset);
255 builder.send_byte(
256 AB::Expr::from_canonical_u32(ByteOpcode::Range as u32),
257 (local.pc[0] - AB::Expr::from_canonical_u32(4) * local.offset)
258 * AB::F::from_canonical_u32(8).inverse(),
259 AB::Expr::from_canonical_u32(13),
260 AB::Expr::zero(),
261 local.is_real.into(),
262 );
263 let sum_top_two_limb = local.pc[1] + local.pc[2];
264
265 builder.assert_eq(local.top_two_limb_inv * sum_top_two_limb, local.is_real);
269
270 let aligned_addr: [AB::Expr; 3] = [
271 local.pc[0] - AB::Expr::from_canonical_u32(4) * local.offset,
272 local.pc[1].into(),
273 local.pc[2].into(),
274 ];
275
276 builder.eval_memory_access_read(
280 clk_high.clone(),
281 clk_low.clone()
282 + AB::Expr::from_canonical_u32(MemoryAccessPosition::UntrustedInstruction as u32),
283 &aligned_addr,
284 local.memory_access,
285 local.is_real.into(),
286 );
287
288 builder
289 .when_not(local.offset)
290 .assert_eq(local.selected_word[0], local.memory_access.prev_value[0]);
291 builder
292 .when_not(local.offset)
293 .assert_eq(local.selected_word[1], local.memory_access.prev_value[1]);
294 builder
295 .when(local.offset)
296 .assert_eq(local.selected_word[0], local.memory_access.prev_value[2]);
297 builder
298 .when(local.offset)
299 .assert_eq(local.selected_word[1], local.memory_access.prev_value[3]);
300
301 let untrusted_instruction_const_fields = [
303 local.instr_type.into(),
304 local.base_opcode.into(),
305 local.funct3.into(),
306 local.funct7.into(),
307 ];
308
309 builder.receive_instruction_fetch(
310 local.pc,
311 local.instruction,
312 untrusted_instruction_const_fields.clone(),
313 [clk_high.clone(), clk_low.clone()],
314 local.is_real.into(),
315 );
316
317 builder.send_instruction_decode(
318 local.selected_word,
319 local.instruction,
320 untrusted_instruction_const_fields,
321 local.is_real.into(),
322 );
323
324 builder.send_page_prot(
325 clk_high,
326 clk_low,
327 &aligned_addr.map(Into::into),
328 AB::Expr::from_canonical_u8(PROT_EXEC),
329 local.is_real.into(),
330 );
331 }
332}
333
#[cfg(test)]
mod tests {
    // The smoke test below prints the generated trace for manual inspection.
    #![allow(clippy::print_stdout)]

    use std::sync::Arc;

    use sp1_primitives::SP1Field;

    use slop_matrix::dense::RowMajorMatrix;
    use sp1_core_executor::{ExecutionRecord, Instruction, Opcode, Program};
    use sp1_hypercube::air::MachineAir;

    use crate::program::InstructionFetchChip;

    /// Smoke test: trace generation runs on a record that carries a small program
    /// and otherwise default contents, and the resulting values are printed.
    #[test]
    fn generate_trace() {
        let program = Program::new(
            vec![
                Instruction::new(Opcode::ADDI, 29, 0, 5, false, true),
                Instruction::new(Opcode::ADDI, 30, 0, 37, false, true),
                Instruction::new(Opcode::ADD, 31, 30, 29, false, false),
            ],
            0,
            0,
        );
        let record = ExecutionRecord { program: Arc::new(program), ..Default::default() };
        let mut scratch = ExecutionRecord::default();

        let trace: RowMajorMatrix<SP1Field> =
            InstructionFetchChip::new().generate_trace(&record, &mut scratch);
        println!("{:?}", trace.values);
    }
}