// sp1_core_machine/memory/instructions/trace.rs

use std::borrow::BorrowMut;
2
3use hashbrown::HashMap;
4use itertools::Itertools;
5use p3_field::PrimeField32;
6use p3_matrix::dense::RowMajorMatrix;
7use rayon::iter::{ParallelBridge, ParallelIterator};
8use sp1_core_executor::{
9 events::{ByteLookupEvent, ByteRecord, MemInstrEvent},
10 ByteOpcode, ExecutionRecord, Opcode, Program,
11};
12use sp1_primitives::consts::WORD_SIZE;
13use sp1_stark::air::MachineAir;
14
15use crate::utils::{next_power_of_two, zeroed_f_vec};
16
17use super::{
18 columns::{MemoryInstructionsColumns, NUM_MEMORY_INSTRUCTIONS_COLUMNS},
19 MemoryInstructionsChip,
20};
21
22impl<F: PrimeField32> MachineAir<F> for MemoryInstructionsChip {
23 type Record = ExecutionRecord;
24
25 type Program = Program;
26
27 fn name(&self) -> String {
28 "MemoryInstrs".to_string()
29 }
30
31 fn generate_trace(
32 &self,
33 input: &ExecutionRecord,
34 output: &mut ExecutionRecord,
35 ) -> RowMajorMatrix<F> {
36 let chunk_size = std::cmp::max((input.memory_instr_events.len()) / num_cpus::get(), 1);
37 let nb_rows = input.memory_instr_events.len();
38 let size_log2 = input.fixed_log2_rows::<F, _>(self);
39 let padded_nb_rows = next_power_of_two(nb_rows, size_log2);
40 let mut values = zeroed_f_vec(padded_nb_rows * NUM_MEMORY_INSTRUCTIONS_COLUMNS);
41
42 let blu_events = values
43 .chunks_mut(chunk_size * NUM_MEMORY_INSTRUCTIONS_COLUMNS)
44 .enumerate()
45 .par_bridge()
46 .map(|(i, rows)| {
47 let mut blu: HashMap<ByteLookupEvent, usize> = HashMap::new();
48 rows.chunks_mut(NUM_MEMORY_INSTRUCTIONS_COLUMNS).enumerate().for_each(
49 |(j, row)| {
50 let idx = i * chunk_size + j;
51 let cols: &mut MemoryInstructionsColumns<F> = row.borrow_mut();
52
53 if idx < input.memory_instr_events.len() {
54 let event = &input.memory_instr_events[idx];
55 self.event_to_row(event, cols, &mut blu);
56 }
57 },
58 );
59 blu
60 })
61 .collect::<Vec<_>>();
62
63 output.add_byte_lookup_events_from_maps(blu_events.iter().collect_vec());
64
65 RowMajorMatrix::new(values, NUM_MEMORY_INSTRUCTIONS_COLUMNS)
67 }
68
69 fn included(&self, shard: &Self::Record) -> bool {
70 if let Some(shape) = shard.shape.as_ref() {
71 shape.included::<F, _>(self)
72 } else {
73 !shard.memory_instr_events.is_empty()
74 }
75 }
76
77 fn local_only(&self) -> bool {
78 true
79 }
80}
81
impl MemoryInstructionsChip {
    /// Populates one trace row (`cols`) from a single memory-instruction event,
    /// recording every byte-table lookup the row's constraints rely on into `blu`.
    fn event_to_row<F: PrimeField32>(
        &self,
        event: &MemInstrEvent,
        cols: &mut MemoryInstructionsColumns<F>,
        blu: &mut HashMap<ByteLookupEvent, usize>,
    ) {
        // Basic execution context: shard, clock, program counter, operand values.
        cols.shard = F::from_canonical_u32(event.shard);
        // Shard numbering starts at 1; a zero shard would indicate a malformed event.
        assert!(cols.shard != F::zero());
        cols.clk = F::from_canonical_u32(event.clk);
        cols.pc = F::from_canonical_u32(event.pc);
        cols.op_a_value = event.a.into();
        cols.op_b_value = event.b.into();
        cols.op_c_value = event.c.into();
        // Whether operand `a` is register x0 (writes to x0 must be discarded).
        cols.op_a_0 = F::from_bool(event.op_a_0);

        // Memory read/write access columns (value, previous value, timestamps, ...).
        cols.memory_access.populate(event.mem_access, blu);

        // Effective address is b + c with wraparound (RISC-V base + offset).
        let memory_addr = event.b.wrapping_add(event.c);
        // Word-aligned address: clear the low bits selecting the byte within the word.
        let aligned_addr = memory_addr - memory_addr % WORD_SIZE as u32;
        cols.addr_word = memory_addr.into();
        // Range-check the address word so its limbs are genuine bytes.
        cols.addr_word_range_checker.populate(cols.addr_word, blu);
        cols.addr_aligned = F::from_canonical_u32(aligned_addr);

        // Sanity check: subtracting (addr mod WORD_SIZE) must yield a 4-byte-aligned value.
        assert!(aligned_addr.is_multiple_of(4));
        // Low two bits of the address select the byte/halfword within the word;
        // expose them both as a value and as one-hot indicator flags.
        let addr_ls_two_bits = (memory_addr % WORD_SIZE as u32) as u8;
        cols.addr_ls_two_bits = F::from_canonical_u8(addr_ls_two_bits);
        cols.ls_bits_is_one = F::from_bool(addr_ls_two_bits == 1);
        cols.ls_bits_is_two = F::from_bool(addr_ls_two_bits == 2);
        cols.ls_bits_is_three = F::from_bool(addr_ls_two_bits == 3);

        // Prove addr_ls_two_bits == addr_word[0] & 0b11 via the byte AND table.
        blu.add_byte_lookup_event(ByteLookupEvent {
            opcode: ByteOpcode::AND,
            a1: addr_ls_two_bits as u16,
            a2: 0,
            b: cols.addr_word[0].as_canonical_u32() as u8,
            c: 0b11,
        });

        let mem_value = event.mem_access.value();
        // Load instructions: extract the (zero-extended) loaded value from the word.
        if matches!(event.opcode, Opcode::LB | Opcode::LBU | Opcode::LH | Opcode::LHU | Opcode::LW)
        {
            match event.opcode {
                Opcode::LB | Opcode::LBU => {
                    // Byte loads: pick the byte selected by the low two address bits.
                    cols.unsigned_mem_val =
                        (mem_value.to_le_bytes()[addr_ls_two_bits as usize] as u32).into();
                }
                Opcode::LH | Opcode::LHU => {
                    // Halfword loads: bit 1 of the address selects low vs. high half.
                    let value = match (addr_ls_two_bits >> 1) % 2 {
                        0 => mem_value & 0x0000FFFF,
                        1 => (mem_value & 0xFFFF0000) >> 16,
                        _ => unreachable!(),
                    };
                    cols.unsigned_mem_val = value.into();
                }
                Opcode::LW => {
                    // Word load: the whole word is the value.
                    cols.unsigned_mem_val = mem_value.into();
                }
                _ => unreachable!(),
            }

            // Signed loads (LB/LH) need the sign bit of the loaded value for
            // sign extension.
            if matches!(event.opcode, Opcode::LB | Opcode::LH) {
                // Most significant byte of the (unsigned) loaded value:
                // byte 0 for LB, byte 1 for LH.
                let most_sig_mem_value_byte = if matches!(event.opcode, Opcode::LB) {
                    cols.unsigned_mem_val.to_u32().to_le_bytes()[0]
                } else {
                    cols.unsigned_mem_val.to_u32().to_le_bytes()[1]
                };

                let most_sig_mem_value_bit = most_sig_mem_value_byte >> 7;
                // Negative loaded value that will actually be written back
                // (i.e. destination is not x0).
                if most_sig_mem_value_bit == 1 {
                    cols.mem_value_is_neg_not_x0 = F::from_bool(!event.op_a_0);
                }

                cols.most_sig_byte = F::from_canonical_u8(most_sig_mem_value_byte);
                cols.most_sig_bit = F::from_canonical_u8(most_sig_mem_value_bit);

                // Prove the sign bit via the MSB byte table.
                blu.add_byte_lookup_event(ByteLookupEvent {
                    opcode: ByteOpcode::MSB,
                    a1: most_sig_mem_value_bit as u16,
                    a2: 0,
                    b: most_sig_mem_value_byte,
                    c: 0,
                });
            }

            // Complement flag: the load result is non-negative (or inherently
            // unsigned: LBU/LHU/LW) and the destination is not x0. Note this
            // reads cols.most_sig_bit, which was set just above for LB/LH.
            cols.mem_value_is_pos_not_x0 = F::from_bool(
                ((matches!(event.opcode, Opcode::LB | Opcode::LH) &&
                    (cols.most_sig_bit == F::zero())) ||
                    matches!(event.opcode, Opcode::LBU | Opcode::LHU | Opcode::LW)) &&
                    !event.op_a_0,
            )
        }

        // One-hot opcode selector flags.
        cols.is_lb = F::from_bool(matches!(event.opcode, Opcode::LB));
        cols.is_lbu = F::from_bool(matches!(event.opcode, Opcode::LBU));
        cols.is_lh = F::from_bool(matches!(event.opcode, Opcode::LH));
        cols.is_lhu = F::from_bool(matches!(event.opcode, Opcode::LHU));
        cols.is_lw = F::from_bool(matches!(event.opcode, Opcode::LW));
        cols.is_sb = F::from_bool(matches!(event.opcode, Opcode::SB));
        cols.is_sh = F::from_bool(matches!(event.opcode, Opcode::SH));
        cols.is_sw = F::from_bool(matches!(event.opcode, Opcode::SW));

        // Range-check the middle address bytes (byte 0 is covered by the AND
        // lookup above; byte 3 presumably by addr_word_range_checker — TODO confirm).
        let addr_bytes = memory_addr.to_le_bytes();
        blu.add_byte_lookup_event(ByteLookupEvent {
            opcode: ByteOpcode::U8Range,
            a1: 0,
            a2: 0,
            b: addr_bytes[1],
            c: addr_bytes[2],
        });

        // Flag whether all upper address bytes are zero (address fits in one byte).
        cols.most_sig_bytes_zero
            .populate_from_field_element(cols.addr_word[1] + cols.addr_word[2] + cols.addr_word[3]);

        // When the address fits in a single byte, prove 31 < addr_word[0] via the
        // LTU table — NOTE(review): this looks like a guard that small addresses
        // stay above a reserved low region (addr >= 32); confirm against the
        // chip's AIR constraints.
        if cols.most_sig_bytes_zero.result == F::one() {
            blu.add_byte_lookup_event(ByteLookupEvent {
                opcode: ByteOpcode::LTU,
                a1: 1,
                a2: 0,
                b: 31,
                c: cols.addr_word[0].as_canonical_u32() as u8,
            });
        }
    }
}