sp1_core_machine/cpu/trace.rs

use std::borrow::BorrowMut;

use hashbrown::HashMap;
use itertools::Itertools;
use p3_field::{PrimeField, PrimeField32};
use p3_matrix::dense::RowMajorMatrix;
use p3_maybe_rayon::prelude::{ParallelBridge, ParallelIterator, ParallelSlice};
use sp1_core_executor::{
    events::{ByteLookupEvent, ByteRecord, CpuEvent, MemoryRecordEnum},
    syscalls::SyscallCode,
    ByteOpcode::{self, U16Range},
    ExecutionRecord, Instruction, Program,
};
use sp1_stark::air::MachineAir;
use tracing::instrument;

use super::{columns::NUM_CPU_COLS, CpuChip};
use crate::{cpu::columns::CpuCols, memory::MemoryCols, utils::zeroed_f_vec};

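// Trace generation for the CPU chip: one trace row per executed CPU event, padded
// to the height fixed by the shard's shape, or to the next power of two (minimum 16).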
impl<F: PrimeField32> MachineAir<F> for CpuChip {
    type Record = ExecutionRecord;

    type Program = Program;

    fn name(&self) -> String {
        self.id().to_string()
    }

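    /// Generates the CPU trace from the shard's CPU events: one row per event, plus
    /// padding rows up to the target height.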
    fn generate_trace(
        &self,
        input: &ExecutionRecord,
        _: &mut ExecutionRecord,
    ) -> RowMajorMatrix<F> {
        let n_real_rows = input.cpu_events.len();
        let padded_nb_rows = if let Some(shape) = &input.shape {
            shape.height(&self.id()).unwrap()
        } else if n_real_rows < 16 {
            16
        } else {
            n_real_rows.next_power_of_two()
        };
        let mut values = zeroed_f_vec(padded_nb_rows * NUM_CPU_COLS);

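        // Split the trace into roughly one chunk of rows per core and fill the chunks
        // in parallel.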
        let chunk_size = std::cmp::max(input.cpu_events.len() / num_cpus::get(), 1);
        values.chunks_mut(chunk_size * NUM_CPU_COLS).enumerate().par_bridge().for_each(
            |(i, rows)| {
                rows.chunks_mut(NUM_CPU_COLS).enumerate().for_each(|(j, row)| {
                    let idx = i * chunk_size + j;
                    let cols: &mut CpuCols<F> = row.borrow_mut();

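                    // Rows at or past the number of real events are padding; they are
                    // flagged as immediate-operand syscall rows so that the constraints
                    // hold on padding.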
                    if idx >= input.cpu_events.len() {
                        cols.instruction.imm_b = F::one();
                        cols.instruction.imm_c = F::one();
                        cols.is_syscall = F::one();
                    } else {
                        let mut byte_lookup_events = Vec::new();
                        let event = &input.cpu_events[idx];
                        let instruction = &input.program.fetch(event.pc);
                        self.event_to_row(
                            event,
                            cols,
                            &mut byte_lookup_events,
                            input.public_values.execution_shard,
                            instruction,
                        );
                    }
                });
            },
        );

        RowMajorMatrix::new(values, NUM_CPU_COLS)
    }

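    /// Generates only the byte-lookup events implied by the CPU events; the scratch
    /// trace rows used to compute them are thrown away.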
    #[instrument(name = "generate cpu dependencies", level = "debug", skip_all)]
    fn generate_dependencies(&self, input: &ExecutionRecord, output: &mut ExecutionRecord) {
        let chunk_size = std::cmp::max(input.cpu_events.len() / num_cpus::get(), 1);

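        // Each parallel chunk collects its byte lookups into its own multiplicity map;
        // the maps are merged into `output` below.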
80 let blu_events: Vec<_> = input
81 .cpu_events
82 .par_chunks(chunk_size)
83 .map(|ops: &[CpuEvent]| {
84 let mut blu: HashMap<ByteLookupEvent, usize> = HashMap::new();
86 ops.iter().for_each(|op| {
87 let mut row = [F::zero(); NUM_CPU_COLS];
88 let cols: &mut CpuCols<F> = row.as_mut_slice().borrow_mut();
89 let instruction = &input.program.fetch(op.pc);
90 self.event_to_row::<F>(
91 op,
92 cols,
93 &mut blu,
94 input.public_values.execution_shard,
95 instruction,
96 );
97 });
98 blu
99 })
100 .collect::<Vec<_>>();
101
102 output.add_byte_lookup_events_from_maps(blu_events.iter().collect_vec());
103 }
104
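    /// The CPU chip is included in a shard if the shard's shape says so, or, absent a
    /// shape, whenever the shard contains CPU events.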
    fn included(&self, shard: &Self::Record) -> bool {
        if let Some(shape) = shard.shape.as_ref() {
            shape.included::<F, _>(self)
        } else {
            shard.contains_cpu()
        }
    }
}

impl CpuChip {
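    /// Populates a single trace row, and the byte-lookup events it requires, from one
    /// CPU event and its fetched instruction.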
    fn event_to_row<F: PrimeField32>(
        &self,
        event: &CpuEvent,
        cols: &mut CpuCols<F>,
        blu_events: &mut impl ByteRecord,
        shard: u32,
        instruction: &Instruction,
    ) {
        self.populate_shard_clk(cols, event, blu_events, shard);

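        // Populate the pc, next_pc, and the instruction-derived selector columns.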
        cols.pc = F::from_canonical_u32(event.pc);
        cols.next_pc = F::from_canonical_u32(event.next_pc);
        cols.instruction.populate(instruction);
        cols.op_a_immutable = F::from_bool(
            instruction.is_memory_store_instruction() || instruction.is_branch_instruction(),
        );
        cols.is_memory = F::from_bool(
            instruction.is_memory_load_instruction() || instruction.is_memory_store_instruction(),
        );
        cols.is_syscall = F::from_bool(instruction.is_ecall_instruction());
        *cols.op_a_access.value_mut() = event.a.into();
        *cols.op_b_access.value_mut() = event.b.into();
        *cols.op_c_access.value_mut() = event.c.into();

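        // Only memory and syscall instructions send their shard and clock onward to
        // other chips.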
        cols.shard_to_send = if instruction.is_memory_load_instruction() ||
            instruction.is_memory_store_instruction() ||
            instruction.is_ecall_instruction()
        {
            cols.shard
        } else {
            F::zero()
        };
        cols.clk_to_send = if instruction.is_memory_load_instruction() ||
            instruction.is_memory_store_instruction() ||
            instruction.is_ecall_instruction()
        {
            F::from_canonical_u32(event.clk)
        } else {
            F::zero()
        };

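        // Populate the memory access columns from the operands' memory records.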
        if let Some(record) = event.a_record {
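            // For ecalls, op_a's byte lookups are discarded into a throwaway vector.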
            if instruction.is_ecall_instruction() {
                cols.op_a_access.populate(record, &mut Vec::new());
            } else {
                cols.op_a_access.populate(record, blu_events);
            }
        }
        if let Some(MemoryRecordEnum::Read(record)) = event.b_record {
            cols.op_b_access.populate(record, blu_events);
        }
        if let Some(MemoryRecordEnum::Read(record)) = event.c_record {
            cols.op_c_access.populate(record, blu_events);
        }

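        // For ecalls, the syscall id and the number of extra cycles are read out of
        // op_a's previous value.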
        if instruction.is_ecall_instruction() {
            let syscall_id = cols.op_a_access.prev_value[0];
            let num_extra_cycles = cols.op_a_access.prev_value[2];
            cols.is_halt =
                F::from_bool(syscall_id == F::from_canonical_u32(SyscallCode::HALT.syscall_id()));
            cols.num_extra_cycles = num_extra_cycles;
        }

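        // Range check each byte of op_a's value (two bytes per U8Range lookup).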
        let a_bytes = cols
            .op_a_access
            .access
            .value
            .0
            .iter()
            .map(|x| x.as_canonical_u32())
            .collect::<Vec<_>>();
        blu_events.add_byte_lookup_event(ByteLookupEvent {
            opcode: ByteOpcode::U8Range,
            a1: 0,
            a2: 0,
            b: a_bytes[0] as u8,
            c: a_bytes[1] as u8,
        });
        blu_events.add_byte_lookup_event(ByteLookupEvent {
            opcode: ByteOpcode::U8Range,
            a1: 0,
            a2: 0,
            b: a_bytes[2] as u8,
            c: a_bytes[3] as u8,
        });

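        // Mark the row as a real (non-padding) row.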
        cols.is_real = F::one();
    }

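    /// Populates the shard and clk columns, plus the byte lookups that range check them.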
    fn populate_shard_clk<F: PrimeField>(
        &self,
        cols: &mut CpuCols<F>,
        event: &CpuEvent,
        blu_events: &mut impl ByteRecord,
        shard: u32,
    ) {
        cols.shard = F::from_canonical_u32(shard);

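        // Decompose the clock into its low 16 bits and the next 8 bits.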
        let clk_16bit_limb = (event.clk & 0xffff) as u16;
        let clk_8bit_limb = ((event.clk >> 16) & 0xff) as u8;
        cols.clk_16bit_limb = F::from_canonical_u16(clk_16bit_limb);
        cols.clk_8bit_limb = F::from_canonical_u8(clk_8bit_limb);

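        // Range check that the shard fits in 16 bits and that both clock limbs are in
        // range.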
        blu_events.add_byte_lookup_event(ByteLookupEvent::new(U16Range, shard as u16, 0, 0, 0));
        blu_events.add_byte_lookup_event(ByteLookupEvent::new(U16Range, clk_16bit_limb, 0, 0, 0));
        blu_events.add_byte_lookup_event(ByteLookupEvent::new(
            ByteOpcode::U8Range,
            0,
            0,
            0,
            clk_8bit_limb,
        ));
    }
}