sp1_core_machine/memory/instructions/trace.rs

use std::borrow::BorrowMut;

use hashbrown::HashMap;
use itertools::Itertools;
use p3_field::PrimeField32;
use p3_matrix::dense::RowMajorMatrix;
use rayon::iter::{ParallelBridge, ParallelIterator};
use sp1_core_executor::{
    events::{ByteLookupEvent, ByteRecord, MemInstrEvent},
    ByteOpcode, ExecutionRecord, Opcode, Program,
};
use sp1_primitives::consts::WORD_SIZE;
use sp1_stark::air::MachineAir;

use crate::utils::{next_power_of_two, zeroed_f_vec};

use super::{
    columns::{MemoryInstructionsColumns, NUM_MEMORY_INSTRUCTIONS_COLUMNS},
    MemoryInstructionsChip,
};

impl<F: PrimeField32> MachineAir<F> for MemoryInstructionsChip {
    type Record = ExecutionRecord;

    type Program = Program;

    fn name(&self) -> String {
        "MemoryInstrs".to_string()
    }

    fn generate_trace(
        &self,
        input: &ExecutionRecord,
        output: &mut ExecutionRecord,
    ) -> RowMajorMatrix<F> {
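        // Generate the trace rows in parallel: the flat value buffer is split into
        // row-aligned chunks, one per worker, and each worker records the byte lookup
        // events it produces in a local map that is merged into `output` afterwards.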
        let chunk_size = std::cmp::max(input.memory_instr_events.len() / num_cpus::get(), 1);
        let nb_rows = input.memory_instr_events.len();
        let size_log2 = input.fixed_log2_rows::<F, _>(self);
        let padded_nb_rows = next_power_of_two(nb_rows, size_log2);
        let mut values = zeroed_f_vec(padded_nb_rows * NUM_MEMORY_INSTRUCTIONS_COLUMNS);

        let blu_events = values
            .chunks_mut(chunk_size * NUM_MEMORY_INSTRUCTIONS_COLUMNS)
            .enumerate()
            .par_bridge()
            .map(|(i, rows)| {
                let mut blu: HashMap<ByteLookupEvent, usize> = HashMap::new();
                rows.chunks_mut(NUM_MEMORY_INSTRUCTIONS_COLUMNS).enumerate().for_each(
                    |(j, row)| {
                        let idx = i * chunk_size + j;
                        let cols: &mut MemoryInstructionsColumns<F> = row.borrow_mut();

                        if idx < input.memory_instr_events.len() {
                            let event = &input.memory_instr_events[idx];
                            self.event_to_row(event, cols, &mut blu);
                        }
                    },
                );
                blu
            })
            .collect::<Vec<_>>();

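        // Merge the per-chunk byte lookup event maps into the output record.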
        output.add_byte_lookup_events_from_maps(blu_events.iter().collect_vec());

        // Convert the trace to a row major matrix.
        RowMajorMatrix::new(values, NUM_MEMORY_INSTRUCTIONS_COLUMNS)
    }

    fn included(&self, shard: &Self::Record) -> bool {
        if let Some(shape) = shard.shape.as_ref() {
            shape.included::<F, _>(self)
        } else {
            !shard.memory_instr_events.is_empty()
        }
    }

    fn local_only(&self) -> bool {
        true
    }
}

impl MemoryInstructionsChip {
    fn event_to_row<F: PrimeField32>(
        &self,
        event: &MemInstrEvent,
        cols: &mut MemoryInstructionsColumns<F>,
        blu: &mut HashMap<ByteLookupEvent, usize>,
    ) {
        cols.shard = F::from_canonical_u32(event.shard);
        assert!(cols.shard != F::zero());
        cols.clk = F::from_canonical_u32(event.clk);
        cols.pc = F::from_canonical_u32(event.pc);
        cols.op_a_value = event.a.into();
        cols.op_b_value = event.b.into();
        cols.op_c_value = event.c.into();
        cols.op_a_0 = F::from_bool(event.op_a_0);

        // Populate memory accesses for reading from memory.
        cols.memory_access.populate(event.mem_access, blu);

        // Populate addr_word and addr_aligned columns.
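        // The effective address is the base register value (`b`) plus the offset (`c`),
        // with wrapping semantics matching the executor.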
        let memory_addr = event.b.wrapping_add(event.c);
        let aligned_addr = memory_addr - memory_addr % WORD_SIZE as u32;
        cols.addr_word = memory_addr.into();
        cols.addr_word_range_checker.populate(cols.addr_word, blu);
        cols.addr_aligned = F::from_canonical_u32(aligned_addr);

        // Sanity check that the computed aligned address is word aligned.
        assert!(aligned_addr.is_multiple_of(4));
        // Populate the memory offset: the two least significant bits of the address.
        let addr_ls_two_bits = (memory_addr % WORD_SIZE as u32) as u8;
        cols.addr_ls_two_bits = F::from_canonical_u8(addr_ls_two_bits);
        cols.ls_bits_is_one = F::from_bool(addr_ls_two_bits == 1);
        cols.ls_bits_is_two = F::from_bool(addr_ls_two_bits == 2);
        cols.ls_bits_is_three = F::from_bool(addr_ls_two_bits == 3);

        // Add a byte lookup event verifying addr_ls_two_bits == addr_word[0] & 0b11,
        // which also range checks byte 0 of the address against the byte table.
        blu.add_byte_lookup_event(ByteLookupEvent {
            opcode: ByteOpcode::AND,
            a1: addr_ls_two_bits as u16,
            a2: 0,
            b: cols.addr_word[0].as_canonical_u32() as u8,
            c: 0b11,
        });

        // If it is a load instruction, set the unsigned_mem_val column.
        let mem_value = event.mem_access.value();
        if matches!(event.opcode, Opcode::LB | Opcode::LBU | Opcode::LH | Opcode::LHU | Opcode::LW)
        {
            match event.opcode {
                Opcode::LB | Opcode::LBU => {
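                    // Byte loads select the addressed byte out of the loaded word.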
                    cols.unsigned_mem_val =
                        (mem_value.to_le_bytes()[addr_ls_two_bits as usize] as u32).into();
                }
                Opcode::LH | Opcode::LHU => {
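                    // Halfword loads select the lower or upper half of the word,
                    // based on bit 1 of the address offset.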
                    let value = match (addr_ls_two_bits >> 1) % 2 {
                        0 => mem_value & 0x0000FFFF,
                        1 => (mem_value & 0xFFFF0000) >> 16,
                        _ => unreachable!(),
                    };
                    cols.unsigned_mem_val = value.into();
                }
                Opcode::LW => {
                    cols.unsigned_mem_val = mem_value.into();
                }
                _ => unreachable!(),
            }

            // For the signed load instructions, we need to check if the loaded value is negative.
            if matches!(event.opcode, Opcode::LB | Opcode::LH) {
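                // The sign bit lives in byte 0 of the unsigned value for LB, and in
                // byte 1 (the upper byte of the halfword) for LH.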
                let most_sig_mem_value_byte = if matches!(event.opcode, Opcode::LB) {
                    cols.unsigned_mem_val.to_u32().to_le_bytes()[0]
                } else {
                    cols.unsigned_mem_val.to_u32().to_le_bytes()[1]
                };

                let most_sig_mem_value_bit = most_sig_mem_value_byte >> 7;
                if most_sig_mem_value_bit == 1 {
                    cols.mem_value_is_neg_not_x0 = F::from_bool(!event.op_a_0);
                }

                cols.most_sig_byte = F::from_canonical_u8(most_sig_mem_value_byte);
                cols.most_sig_bit = F::from_canonical_u8(most_sig_mem_value_bit);

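                // Add a byte lookup event to prove that most_sig_bit is the MSB of
                // most_sig_byte.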
                blu.add_byte_lookup_event(ByteLookupEvent {
                    opcode: ByteOpcode::MSB,
                    a1: most_sig_mem_value_bit as u16,
                    a2: 0,
                    b: most_sig_mem_value_byte,
                    c: 0,
                });
            }

            // Set the `mem_value_is_pos_not_x0` composite flag: the loaded value needs
            // no sign extension and the destination register is not x0.
            cols.mem_value_is_pos_not_x0 = F::from_bool(
                ((matches!(event.opcode, Opcode::LB | Opcode::LH) &&
                    (cols.most_sig_bit == F::zero())) ||
                    matches!(event.opcode, Opcode::LBU | Opcode::LHU | Opcode::LW)) &&
                    !event.op_a_0,
            );
        }

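        // Set the opcode flag corresponding to the instruction; exactly one is set.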
        cols.is_lb = F::from_bool(matches!(event.opcode, Opcode::LB));
        cols.is_lbu = F::from_bool(matches!(event.opcode, Opcode::LBU));
        cols.is_lh = F::from_bool(matches!(event.opcode, Opcode::LH));
        cols.is_lhu = F::from_bool(matches!(event.opcode, Opcode::LHU));
        cols.is_lw = F::from_bool(matches!(event.opcode, Opcode::LW));
        cols.is_sb = F::from_bool(matches!(event.opcode, Opcode::SB));
        cols.is_sh = F::from_bool(matches!(event.opcode, Opcode::SH));
        cols.is_sw = F::from_bool(matches!(event.opcode, Opcode::SW));

        // Add a byte lookup event to range check bytes 1 and 2 of the memory address.
        let addr_bytes = memory_addr.to_le_bytes();
        blu.add_byte_lookup_event(ByteLookupEvent {
            opcode: ByteOpcode::U8Range,
            a1: 0,
            a2: 0,
            b: addr_bytes[1],
            c: addr_bytes[2],
        });

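        // Flag whether the three most significant bytes of the address are all zero,
        // i.e. whether the address fits in a single byte.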
        cols.most_sig_bytes_zero
            .populate_from_field_element(cols.addr_word[1] + cols.addr_word[2] + cols.addr_word[3]);

        if cols.most_sig_bytes_zero.result == F::one() {
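            // When the address fits in a single byte, additionally check that
            // 31 < addr_word[0], keeping accesses out of the lowest 32 addresses
            // (presumably reserved for the registers).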
            blu.add_byte_lookup_event(ByteLookupEvent {
                opcode: ByteOpcode::LTU,
                a1: 1,
                a2: 0,
                b: 31,
                c: cols.addr_word[0].as_canonical_u32() as u8,
            });
        }
    }
}