// sp1_core_executor/record.rs

1use enum_map::EnumMap;
2use hashbrown::HashMap;
3use itertools::{EitherOrBoth, Itertools};
4use p3_field::{AbstractField, PrimeField};
5use sp1_stark::{
6    air::{MachineAir, PublicValues},
7    shape::Shape,
8    MachineRecord, SP1CoreOpts, SplitOpts,
9};
10use std::{mem::take, str::FromStr, sync::Arc};
11
12use serde::{Deserialize, Serialize};
13
14use crate::{
15    events::{
16        AUIPCEvent, AluEvent, BranchEvent, ByteLookupEvent, ByteRecord, CpuEvent,
17        GlobalInteractionEvent, JumpEvent, MemInstrEvent, MemoryInitializeFinalizeEvent,
18        MemoryLocalEvent, MemoryRecordEnum, PrecompileEvent, PrecompileEvents, SyscallEvent,
19    },
20    program::Program,
21    syscalls::SyscallCode,
22    RiscvAirId,
23};
24
/// A record of the execution of a program.
///
/// The trace of the execution is represented as a list of "events" that occur every cycle.
/// Events are bucketed per chip (ALU, memory, branch, ...) so each chip can generate its trace
/// from its own event list.
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
pub struct ExecutionRecord {
    /// The program being executed (shared, reference-counted).
    pub program: Arc<Program>,
    /// A trace of the CPU events which get emitted during execution.
    pub cpu_events: Vec<CpuEvent>,
    /// A trace of the ADD, and ADDI events.
    pub add_events: Vec<AluEvent>,
    /// A trace of the MUL events.
    pub mul_events: Vec<AluEvent>,
    /// A trace of the SUB events.
    pub sub_events: Vec<AluEvent>,
    /// A trace of the XOR, XORI, OR, ORI, AND, and ANDI events.
    pub bitwise_events: Vec<AluEvent>,
    /// A trace of the SLL and SLLI events.
    pub shift_left_events: Vec<AluEvent>,
    /// A trace of the SRL, SRLI, SRA, and SRAI events.
    pub shift_right_events: Vec<AluEvent>,
    /// A trace of the DIV, DIVU, REM, and REMU events.
    pub divrem_events: Vec<AluEvent>,
    /// A trace of the SLT, SLTI, SLTU, and SLTIU events.
    pub lt_events: Vec<AluEvent>,
    /// A trace of the memory instructions.
    pub memory_instr_events: Vec<MemInstrEvent>,
    /// A trace of the AUIPC events.
    pub auipc_events: Vec<AUIPCEvent>,
    /// A trace of the branch events.
    pub branch_events: Vec<BranchEvent>,
    /// A trace of the jump events.
    pub jump_events: Vec<JumpEvent>,
    /// A trace of the byte lookups that are needed, mapped to their multiplicities (how many
    /// times each lookup occurred).
    pub byte_lookups: HashMap<ByteLookupEvent, usize>,
    /// A trace of the precompile events, keyed by syscall code.
    pub precompile_events: PrecompileEvents,
    /// A trace of the global memory initialize events.
    pub global_memory_initialize_events: Vec<MemoryInitializeFinalizeEvent>,
    /// A trace of the global memory finalize events.
    pub global_memory_finalize_events: Vec<MemoryInitializeFinalizeEvent>,
    /// A trace of all the shard's local memory events recorded by the CPU (precompile-local
    /// memory events live inside `precompile_events`).
    pub cpu_local_memory_access: Vec<MemoryLocalEvent>,
    /// A trace of all the syscall events.
    pub syscall_events: Vec<SyscallEvent>,
    /// A trace of all the global interaction events.
    pub global_interaction_events: Vec<GlobalInteractionEvent>,
    /// The public values.
    pub public_values: PublicValues<u32, u32>,
    /// The next nonce to use for a new lookup.
    pub next_nonce: u64,
    /// The shape of the proof; `None` when no shape has been fixed for this record.
    pub shape: Option<Shape<RiscvAirId>>,
    /// The predicted counts of the proof.
    pub counts: Option<EnumMap<RiscvAirId, u64>>,
}
81
impl ExecutionRecord {
    /// Create a new [`ExecutionRecord`] for the given program, with all event traces empty.
    #[must_use]
    pub fn new(program: Arc<Program>) -> Self {
        Self { program, ..Default::default() }
    }

    /// Take out events from the [`ExecutionRecord`] that should be deferred to a separate shard.
    ///
    /// Note: we usually defer events that would increase the recursion cost significantly if
    /// included in every shard.
    ///
    /// Moves the precompile events and the global memory initialize/finalize events into the
    /// returned record, leaving the corresponding fields of `self` empty.
    #[must_use]
    pub fn defer(&mut self) -> ExecutionRecord {
        let mut execution_record = ExecutionRecord::new(self.program.clone());
        execution_record.precompile_events = std::mem::take(&mut self.precompile_events);
        execution_record.global_memory_initialize_events =
            std::mem::take(&mut self.global_memory_initialize_events);
        execution_record.global_memory_finalize_events =
            std::mem::take(&mut self.global_memory_finalize_events);
        execution_record
    }

    /// Splits the deferred [`ExecutionRecord`] into multiple [`ExecutionRecord`]s, each which
    /// contain a "reasonable" number of deferred events.
    ///
    /// The optional `last_record` will be provided if there are few enough deferred events that
    /// they can all be packed into the already existing last record.
    ///
    /// When `last` is false, only full chunks of precompile events are split off and the partial
    /// remainder is kept in `self` for a later call. When `last` is true, the remainder is
    /// flushed too and the global memory initialize/finalize events are chunked into shards
    /// (or packed into `last_record` when possible).
    pub fn split(
        &mut self,
        last: bool,
        last_record: Option<&mut ExecutionRecord>,
        opts: SplitOpts,
    ) -> Vec<ExecutionRecord> {
        let mut shards = Vec::new();

        let precompile_events = take(&mut self.precompile_events);

        for (syscall_code, events) in precompile_events.into_iter() {
            // Each precompile kind has its own per-shard event budget in `SplitOpts`.
            let threshold = match syscall_code {
                SyscallCode::KECCAK_PERMUTE => opts.keccak,
                SyscallCode::SHA_EXTEND => opts.sha_extend,
                SyscallCode::SHA_COMPRESS => opts.sha_compress,
                _ => opts.deferred,
            };

            let chunks = events.chunks_exact(threshold);
            if last {
                // Final split: flush the leftover (partial) chunk as its own shard.
                let remainder = chunks.remainder().to_vec();
                if !remainder.is_empty() {
                    let mut execution_record = ExecutionRecord::new(self.program.clone());
                    execution_record.precompile_events.insert(syscall_code, remainder);
                    shards.push(execution_record);
                }
            } else {
                // Not final yet: put the partial chunk back into `self` for a future call.
                self.precompile_events.insert(syscall_code, chunks.remainder().to_vec());
            }
            // Every full chunk becomes its own shard.
            let mut event_shards = chunks
                .map(|chunk| {
                    let mut execution_record = ExecutionRecord::new(self.program.clone());
                    execution_record.precompile_events.insert(syscall_code, chunk.to_vec());
                    execution_record
                })
                .collect::<Vec<_>>();
            shards.append(&mut event_shards);
        }

        if last {
            // Sort by address so the addr-bits public values below chain consecutive shards in
            // increasing address order.
            self.global_memory_initialize_events.sort_by_key(|event| event.addr);
            self.global_memory_finalize_events.sort_by_key(|event| event.addr);

            // If there are no precompile shards, and `last_record` is Some, pack the memory events
            // into the last record.
            let pack_memory_events_into_last_record = last_record.is_some() && shards.is_empty();
            let mut blank_record = ExecutionRecord::new(self.program.clone());

            // If `last_record` is None, use a blank record to store the memory events.
            let last_record_ref = if pack_memory_events_into_last_record {
                last_record.unwrap()
            } else {
                &mut blank_record
            };

            // Bit decomposition of the last address seen so far, carried across chunk iterations
            // so each shard's `previous_*` public values equal the prior shard's `last_*` values.
            let mut init_addr_bits = [0; 32];
            let mut finalize_addr_bits = [0; 32];
            for mem_chunks in self
                .global_memory_initialize_events
                .chunks(opts.memory)
                .zip_longest(self.global_memory_finalize_events.chunks(opts.memory))
            {
                // The initialize and finalize event lists may differ in length; pair their chunks
                // up, substituting an empty slice for the exhausted side.
                let (mem_init_chunk, mem_finalize_chunk) = match mem_chunks {
                    EitherOrBoth::Both(mem_init_chunk, mem_finalize_chunk) => {
                        (mem_init_chunk, mem_finalize_chunk)
                    }
                    EitherOrBoth::Left(mem_init_chunk) => (mem_init_chunk, [].as_slice()),
                    EitherOrBoth::Right(mem_finalize_chunk) => ([].as_slice(), mem_finalize_chunk),
                };
                last_record_ref.global_memory_initialize_events.extend_from_slice(mem_init_chunk);
                last_record_ref.public_values.previous_init_addr_bits = init_addr_bits;
                if let Some(last_event) = mem_init_chunk.last() {
                    // Little-endian bit decomposition of the chunk's highest initialized address.
                    let last_init_addr_bits = core::array::from_fn(|i| (last_event.addr >> i) & 1);
                    init_addr_bits = last_init_addr_bits;
                }
                last_record_ref.public_values.last_init_addr_bits = init_addr_bits;

                last_record_ref.global_memory_finalize_events.extend_from_slice(mem_finalize_chunk);
                last_record_ref.public_values.previous_finalize_addr_bits = finalize_addr_bits;
                if let Some(last_event) = mem_finalize_chunk.last() {
                    let last_finalize_addr_bits =
                        core::array::from_fn(|i| (last_event.addr >> i) & 1);
                    finalize_addr_bits = last_finalize_addr_bits;
                }
                last_record_ref.public_values.last_finalize_addr_bits = finalize_addr_bits;

                if !pack_memory_events_into_last_record {
                    // If not packing memory events into the last record, add 'last_record_ref'
                    // to the returned records. `take` replaces `blank_program` with the default.
                    shards.push(take(last_record_ref));

                    // Reset the last record so its program is the correct one. (The default program
                    // provided by `take` contains no instructions.)
                    last_record_ref.program = self.program.clone();
                }
            }
        }
        shards
    }

    /// Return the log2 of the number of rows needed for a chip, according to the proof shape
    /// specified in the struct.
    ///
    /// Returns `None` when no shape has been fixed.
    ///
    /// # Panics
    /// Panics if a shape is fixed but the chip's name is not a valid [`RiscvAirId`] or is not
    /// present in the shape.
    pub fn fixed_log2_rows<F: PrimeField, A: MachineAir<F>>(&self, air: &A) -> Option<usize> {
        self.shape.as_ref().map(|shape| {
            shape
                .log2_height(&RiscvAirId::from_str(&air.name()).unwrap())
                .unwrap_or_else(|| panic!("Chip {} not found in specified shape", air.name()))
        })
    }

    /// Determines whether the execution record contains CPU events.
    #[must_use]
    pub fn contains_cpu(&self) -> bool {
        !self.cpu_events.is_empty()
    }

    #[inline]
    /// Add a precompile event to the execution record.
    pub fn add_precompile_event(
        &mut self,
        syscall_code: SyscallCode,
        syscall_event: SyscallEvent,
        event: PrecompileEvent,
    ) {
        self.precompile_events.add_event(syscall_code, syscall_event, event);
    }

    /// Get all the precompile events for a syscall code.
    ///
    /// # Panics
    /// Panics if no events were recorded for `syscall_code`.
    #[inline]
    #[must_use]
    pub fn get_precompile_events(
        &self,
        syscall_code: SyscallCode,
    ) -> &Vec<(SyscallEvent, PrecompileEvent)> {
        self.precompile_events.get_events(syscall_code).expect("Precompile events not found")
    }

    /// Get all the local memory events, chaining the precompile-recorded events with the
    /// CPU-recorded ones.
    #[inline]
    pub fn get_local_mem_events(&self) -> impl Iterator<Item = &MemoryLocalEvent> {
        let precompile_local_mem_events = self.precompile_events.get_local_mem_events();
        precompile_local_mem_events.chain(self.cpu_local_memory_access.iter())
    }
}
253
/// A memory access record.
///
/// Each field is `None` when the corresponding access did not occur.
#[derive(Debug, Copy, Clone, Default)]
pub struct MemoryAccessRecord {
    /// The memory access of the `a` register.
    pub a: Option<MemoryRecordEnum>,
    /// The memory access of the `b` register.
    pub b: Option<MemoryRecordEnum>,
    /// The memory access of the `c` register.
    pub c: Option<MemoryRecordEnum>,
    /// The memory access of the memory operand (NOTE(review): presumably the load/store
    /// address rather than a register — confirm with the executor's usage).
    pub memory: Option<MemoryRecordEnum>,
}
266
267impl MachineRecord for ExecutionRecord {
268    type Config = SP1CoreOpts;
269
270    fn stats(&self) -> HashMap<String, usize> {
271        let mut stats = HashMap::new();
272        stats.insert("cpu_events".to_string(), self.cpu_events.len());
273        stats.insert("add_events".to_string(), self.add_events.len());
274        stats.insert("mul_events".to_string(), self.mul_events.len());
275        stats.insert("sub_events".to_string(), self.sub_events.len());
276        stats.insert("bitwise_events".to_string(), self.bitwise_events.len());
277        stats.insert("shift_left_events".to_string(), self.shift_left_events.len());
278        stats.insert("shift_right_events".to_string(), self.shift_right_events.len());
279        stats.insert("divrem_events".to_string(), self.divrem_events.len());
280        stats.insert("lt_events".to_string(), self.lt_events.len());
281        stats.insert("memory_instructions_events".to_string(), self.memory_instr_events.len());
282        stats.insert("branch_events".to_string(), self.branch_events.len());
283        stats.insert("jump_events".to_string(), self.jump_events.len());
284        stats.insert("auipc_events".to_string(), self.auipc_events.len());
285
286        for (syscall_code, events) in self.precompile_events.iter() {
287            stats.insert(format!("syscall {syscall_code:?}"), events.len());
288        }
289
290        stats.insert(
291            "global_memory_initialize_events".to_string(),
292            self.global_memory_initialize_events.len(),
293        );
294        stats.insert(
295            "global_memory_finalize_events".to_string(),
296            self.global_memory_finalize_events.len(),
297        );
298        stats.insert("local_memory_access_events".to_string(), self.cpu_local_memory_access.len());
299        if !self.cpu_events.is_empty() {
300            stats.insert("byte_lookups".to_string(), self.byte_lookups.len());
301        }
302        // Filter out the empty events.
303        stats.retain(|_, v| *v != 0);
304        stats
305    }
306
307    fn append(&mut self, other: &mut ExecutionRecord) {
308        self.cpu_events.append(&mut other.cpu_events);
309        self.add_events.append(&mut other.add_events);
310        self.sub_events.append(&mut other.sub_events);
311        self.mul_events.append(&mut other.mul_events);
312        self.bitwise_events.append(&mut other.bitwise_events);
313        self.shift_left_events.append(&mut other.shift_left_events);
314        self.shift_right_events.append(&mut other.shift_right_events);
315        self.divrem_events.append(&mut other.divrem_events);
316        self.lt_events.append(&mut other.lt_events);
317        self.memory_instr_events.append(&mut other.memory_instr_events);
318        self.branch_events.append(&mut other.branch_events);
319        self.jump_events.append(&mut other.jump_events);
320        self.auipc_events.append(&mut other.auipc_events);
321        self.syscall_events.append(&mut other.syscall_events);
322
323        self.precompile_events.append(&mut other.precompile_events);
324
325        if self.byte_lookups.is_empty() {
326            self.byte_lookups = std::mem::take(&mut other.byte_lookups);
327        } else {
328            self.add_byte_lookup_events_from_maps(vec![&other.byte_lookups]);
329        }
330
331        self.global_memory_initialize_events.append(&mut other.global_memory_initialize_events);
332        self.global_memory_finalize_events.append(&mut other.global_memory_finalize_events);
333        self.cpu_local_memory_access.append(&mut other.cpu_local_memory_access);
334        self.global_interaction_events.append(&mut other.global_interaction_events);
335    }
336
337    /// Retrieves the public values.  This method is needed for the `MachineRecord` trait, since
338    fn public_values<F: AbstractField>(&self) -> Vec<F> {
339        self.public_values.to_vec()
340    }
341}
342
343impl ByteRecord for ExecutionRecord {
344    fn add_byte_lookup_event(&mut self, blu_event: ByteLookupEvent) {
345        *self.byte_lookups.entry(blu_event).or_insert(0) += 1;
346    }
347
348    #[inline]
349    fn add_byte_lookup_events_from_maps(
350        &mut self,
351        new_events: Vec<&HashMap<ByteLookupEvent, usize>>,
352    ) {
353        for new_blu_map in new_events {
354            for (blu_event, count) in new_blu_map.iter() {
355                *self.byte_lookups.entry(*blu_event).or_insert(0) += count;
356            }
357        }
358    }
359}