Skip to main content

sp1_core_executor/
record.rs

1use crate::events::{TrapExecEvent, TrapMemInstrEvent};
2use deepsize2::DeepSizeOf;
3use hashbrown::HashMap;
4use slop_air::AirBuilder;
5use slop_algebra::{AbstractField, Field, PrimeField, PrimeField32};
6use sp1_hypercube::{
7    air::{
8        AirInteraction, BaseAirBuilder, InteractionScope, MachineAir, PublicValues, SP1AirBuilder,
9        PROOF_NONCE_NUM_WORDS, PV_DIGEST_NUM_WORDS, SP1_PROOF_NUM_PV_ELTS,
10    },
11    septic_digest::SepticDigest,
12    shape::Shape,
13    InteractionKind, MachineRecord,
14};
15use std::{
16    borrow::Borrow,
17    iter::once,
18    mem::take,
19    sync::{Arc, Mutex},
20};
21
22use serde::{Deserialize, Serialize};
23
24use crate::{
25    events::{
26        AluEvent, BranchEvent, ByteLookupEvent, ByteRecord, GlobalInteractionEvent,
27        InstructionDecodeEvent, InstructionFetchEvent, JumpEvent, MemInstrEvent,
28        MemoryInitializeFinalizeEvent, MemoryLocalEvent, MemoryRecordEnum,
29        PageProtInitializeFinalizeEvent, PageProtLocalEvent, PrecompileEvent, PrecompileEvents,
30        SyscallEvent, UTypeEvent,
31    },
32    program::Program,
33    ByteOpcode, Instruction, RetainedEventsPreset, RiscvAirId, SplitOpts, SyscallCode,
34};
35
/// A record of the execution of a program.
///
/// The trace of the execution is represented as a list of "events" that occur every cycle.
/// Each per-instruction event vector pairs the event with the register/memory access record
/// matching the instruction's operand format (R-, I-, J- or ALU-type).
#[derive(Clone, Debug, Serialize, Deserialize, Default, DeepSizeOf)]
pub struct ExecutionRecord {
    /// The program.
    pub program: Arc<Program>,
    /// The number of CPU related events.
    pub cpu_event_count: u32,
    /// A trace of ALU events with `rd = x0`.
    pub alu_x0_events: Vec<(AluEvent, ALUTypeRecord)>,
    /// A trace of the ADD events (register-register form; ADDI is recorded in `addi_events`).
    pub add_events: Vec<(AluEvent, RTypeRecord)>,
    /// A trace of the ADDW events.
    pub addw_events: Vec<(AluEvent, ALUTypeRecord)>,
    /// A trace of the ADDI events.
    pub addi_events: Vec<(AluEvent, ITypeRecord)>,
    /// A trace of the MUL events.
    pub mul_events: Vec<(AluEvent, RTypeRecord)>,
    /// A trace of the SUB events.
    pub sub_events: Vec<(AluEvent, RTypeRecord)>,
    /// A trace of the SUBW events.
    pub subw_events: Vec<(AluEvent, RTypeRecord)>,
    /// A trace of the XOR, XORI, OR, ORI, AND, and ANDI events.
    pub bitwise_events: Vec<(AluEvent, ALUTypeRecord)>,
    /// A trace of the SLL and SLLI events.
    pub shift_left_events: Vec<(AluEvent, ALUTypeRecord)>,
    /// A trace of the SRL, SRLI, SRA, and SRAI events.
    pub shift_right_events: Vec<(AluEvent, ALUTypeRecord)>,
    /// A trace of the DIV, DIVU, REM, and REMU events.
    pub divrem_events: Vec<(AluEvent, RTypeRecord)>,
    /// A trace of the SLT, SLTI, SLTU, and SLTIU events.
    pub lt_events: Vec<(AluEvent, ALUTypeRecord)>,
    /// A trace of load byte instructions.
    pub memory_load_byte_events: Vec<(MemInstrEvent, ITypeRecord)>,
    /// A trace of load half instructions.
    pub memory_load_half_events: Vec<(MemInstrEvent, ITypeRecord)>,
    /// A trace of load word instructions.
    pub memory_load_word_events: Vec<(MemInstrEvent, ITypeRecord)>,
    /// A trace of load instructions with `op_a = x0`.
    pub memory_load_x0_events: Vec<(MemInstrEvent, ITypeRecord)>,
    /// A trace of load double instructions.
    pub memory_load_double_events: Vec<(MemInstrEvent, ITypeRecord)>,
    /// A trace of store byte instructions.
    pub memory_store_byte_events: Vec<(MemInstrEvent, ITypeRecord)>,
    /// A trace of store half instructions.
    pub memory_store_half_events: Vec<(MemInstrEvent, ITypeRecord)>,
    /// A trace of store word instructions.
    pub memory_store_word_events: Vec<(MemInstrEvent, ITypeRecord)>,
    /// A trace of store double instructions.
    pub memory_store_double_events: Vec<(MemInstrEvent, ITypeRecord)>,
    /// A trace of the AUIPC and LUI events.
    pub utype_events: Vec<(UTypeEvent, JTypeRecord)>,
    /// A trace of the branch events.
    pub branch_events: Vec<(BranchEvent, ITypeRecord)>,
    /// A trace of the JAL events.
    pub jal_events: Vec<(JumpEvent, JTypeRecord)>,
    /// A trace of the JALR events.
    pub jalr_events: Vec<(JumpEvent, ITypeRecord)>,
    /// The byte lookups that are needed, each mapped to a `usize` (presumably its
    /// multiplicity — TODO confirm against the byte chip).
    pub byte_lookups: HashMap<ByteLookupEvent, usize>,
    /// A trace of the precompile events.
    pub precompile_events: PrecompileEvents,
    /// A trace of the global memory initialize events.
    pub global_memory_initialize_events: Vec<MemoryInitializeFinalizeEvent>,
    /// A trace of the global memory finalize events.
    pub global_memory_finalize_events: Vec<MemoryInitializeFinalizeEvent>,
    /// A trace of the global page prot initialize events.
    pub global_page_prot_initialize_events: Vec<PageProtInitializeFinalizeEvent>,
    /// A trace of the global page prot finalize events.
    pub global_page_prot_finalize_events: Vec<PageProtInitializeFinalizeEvent>,
    /// A trace of all the shard's local memory events.
    pub cpu_local_memory_access: Vec<MemoryLocalEvent>,
    /// A trace of all the local page prot events.
    pub cpu_local_page_prot_access: Vec<PageProtLocalEvent>,
    /// A trace of all the syscall events.
    pub syscall_events: Vec<(SyscallEvent, RTypeRecord)>,
    /// A trace of all the global interaction events.
    pub global_interaction_events: Vec<GlobalInteractionEvent>,
    /// A trace of all instruction fetch events.
    pub instruction_fetch_events: Vec<(InstructionFetchEvent, MemoryAccessRecord)>,
    /// A trace of all instruction decode events.
    pub instruction_decode_events: Vec<InstructionDecodeEvent>,
    /// A trace of all traps raised during untrusted program execution.
    pub trap_exec_events: Vec<TrapExecEvent>,
    /// A trace of all traps raised by load and store instructions.
    pub trap_load_store_events: Vec<(TrapMemInstrEvent, ITypeRecord)>,
    /// The global cumulative sum. It sits behind an `Arc<Mutex<_>>`, so clones of this record
    /// share the same sum.
    pub global_cumulative_sum: Arc<Mutex<SepticDigest<u32>>>,
    /// The global interaction event count.
    pub global_interaction_event_count: u32,
    /// Memory records used to bump the timestamp of the register memory access.
    pub bump_memory_events: Vec<(MemoryRecordEnum, u64, bool)>,
    /// Records where `clk >> 24` or `pc >> 16` has incremented.
    pub bump_state_events: Vec<(u64, u64, bool, u64)>,
    /// The public values.
    pub public_values: PublicValues<u32, u64, u64, u32>,
    /// The next nonce to use for a new lookup.
    pub next_nonce: u64,
    /// The shape of the proof.
    pub shape: Option<Shape<RiscvAirId>>,
    /// The estimated total trace area of the proof.
    pub estimated_trace_area: u64,
    /// The initial timestamp of the shard.
    pub initial_timestamp: u64,
    /// The final timestamp of the shard.
    pub last_timestamp: u64,
    /// The start program counter.
    pub pc_start: Option<u64>,
    /// The next program counter after the shard (the final program counter).
    pub next_pc: u64,
    /// The exit code.
    pub exit_code: u32,
    /// Use optimized `generate_dependencies` for global chip.
    pub global_dependencies_opt: bool,
}
152
153impl ExecutionRecord {
154    /// Create a new [`ExecutionRecord`].
155    #[must_use]
156    pub fn new(
157        program: Arc<Program>,
158        proof_nonce: [u32; PROOF_NONCE_NUM_WORDS],
159        global_dependencies_opt: bool,
160    ) -> Self {
161        let enable_untrusted_programs = program.enable_untrusted_programs as u32;
162        #[cfg(feature = "mprotect")]
163        let trap_context = program.trap_context;
164        #[cfg(feature = "mprotect")]
165        let untrusted_memory = program.untrusted_memory;
166
167        let mut result = Self { program, ..Default::default() };
168        result.public_values.proof_nonce = proof_nonce;
169        result.public_values.is_untrusted_programs_enabled = enable_untrusted_programs;
170
171        #[cfg(feature = "mprotect")]
172        {
173            result.public_values.enable_trap_handler = trap_context.is_some() as u32;
174            result.public_values.trap_context =
175                trap_context.map_or([0, 0, 0], |addr| [addr, addr + 8, addr + 16]);
176            result.public_values.untrusted_memory =
177                untrusted_memory.map_or([0, 0], |(start, end)| [start, end]);
178        }
179        result.global_dependencies_opt = global_dependencies_opt;
180        result
181    }
182
183    /// Create a new [`ExecutionRecord`] with preallocated event vecs.
184    #[must_use]
185    pub fn new_preallocated(
186        program: Arc<Program>,
187        proof_nonce: [u32; PROOF_NONCE_NUM_WORDS],
188        global_dependencies_opt: bool,
189        reservation_size: usize,
190    ) -> Self {
191        let enable_untrusted_programs = program.enable_untrusted_programs;
192        #[cfg(feature = "mprotect")]
193        let trap_context = program.trap_context;
194        #[cfg(feature = "mprotect")]
195        let untrusted_memory = program.untrusted_memory;
196        let mut result = Self { program, ..Default::default() };
197
198        result.alu_x0_events.reserve(reservation_size);
199        result.add_events.reserve(reservation_size);
200        result.addi_events.reserve(reservation_size);
201        result.addw_events.reserve(reservation_size);
202        result.mul_events.reserve(reservation_size);
203        result.sub_events.reserve(reservation_size);
204        result.subw_events.reserve(reservation_size);
205        result.bitwise_events.reserve(reservation_size);
206        result.shift_left_events.reserve(reservation_size);
207        result.shift_right_events.reserve(reservation_size);
208        result.divrem_events.reserve(reservation_size);
209        result.lt_events.reserve(reservation_size);
210        result.branch_events.reserve(reservation_size);
211        result.jal_events.reserve(reservation_size);
212        result.jalr_events.reserve(reservation_size);
213        result.utype_events.reserve(reservation_size);
214        result.memory_load_x0_events.reserve(reservation_size);
215        result.memory_load_byte_events.reserve(reservation_size);
216        result.memory_load_half_events.reserve(reservation_size);
217        result.memory_load_word_events.reserve(reservation_size);
218        result.memory_load_double_events.reserve(reservation_size);
219        result.memory_store_byte_events.reserve(reservation_size);
220        result.memory_store_half_events.reserve(reservation_size);
221        result.memory_store_word_events.reserve(reservation_size);
222        result.memory_store_double_events.reserve(reservation_size);
223        result.global_memory_initialize_events.reserve(reservation_size);
224        result.global_memory_finalize_events.reserve(reservation_size);
225        result.global_interaction_events.reserve(reservation_size);
226        result.byte_lookups.reserve(reservation_size);
227
228        result.public_values.proof_nonce = proof_nonce;
229        result.public_values.is_untrusted_programs_enabled = enable_untrusted_programs as u32;
230        #[cfg(feature = "mprotect")]
231        {
232            result.public_values.enable_trap_handler = trap_context.is_some() as u32;
233            result.public_values.trap_context =
234                trap_context.map_or([0, 0, 0], |addr| [addr, addr + 8, addr + 16]);
235            result.public_values.untrusted_memory =
236                untrusted_memory.map_or([0, 0], |(start, end)| [start, end]);
237        }
238        result.global_dependencies_opt = global_dependencies_opt;
239        result
240    }
241
242    /// Take out events from the [`ExecutionRecord`] that should be deferred to a separate shard.
243    ///
244    /// Note: we usually defer events that would increase the recursion cost significantly if
245    /// included in every shard.
246    #[must_use]
247    pub fn defer<'a>(
248        &mut self,
249        retain_presets: impl IntoIterator<Item = &'a RetainedEventsPreset>,
250    ) -> ExecutionRecord {
251        let mut execution_record = ExecutionRecord::new(
252            self.program.clone(),
253            self.public_values.proof_nonce,
254            self.global_dependencies_opt,
255        );
256        execution_record.precompile_events = std::mem::take(&mut self.precompile_events);
257
258        // Take back the events that should be retained.
259        self.precompile_events.events.extend(
260            retain_presets.into_iter().flat_map(RetainedEventsPreset::syscall_codes).filter_map(
261                |code| execution_record.precompile_events.events.remove(code).map(|x| (*code, x)),
262            ),
263        );
264
265        execution_record.global_memory_initialize_events =
266            std::mem::take(&mut self.global_memory_initialize_events);
267        execution_record.global_memory_finalize_events =
268            std::mem::take(&mut self.global_memory_finalize_events);
269        execution_record.global_page_prot_initialize_events =
270            std::mem::take(&mut self.global_page_prot_initialize_events);
271        execution_record.global_page_prot_finalize_events =
272            std::mem::take(&mut self.global_page_prot_finalize_events);
273        execution_record
274    }
275
276    /// Splits the deferred [`ExecutionRecord`] into multiple [`ExecutionRecord`]s, each which
277    /// contain a "reasonable" number of deferred events.
278    #[allow(clippy::too_many_lines)]
279    pub fn split(
280        &mut self,
281        done: bool,
282        last_record: &mut ExecutionRecord,
283        can_pack_global_memory: bool,
284        opts: &SplitOpts,
285    ) -> Vec<ExecutionRecord> {
286        let mut shards = Vec::new();
287
288        let precompile_events = take(&mut self.precompile_events);
289
290        for (syscall_code, events) in precompile_events.into_iter() {
291            let threshold: usize = opts.syscall_threshold[syscall_code];
292
293            let chunks = events.chunks_exact(threshold);
294            if done {
295                let remainder = chunks.remainder().to_vec();
296                if !remainder.is_empty() {
297                    let mut execution_record = ExecutionRecord::new(
298                        self.program.clone(),
299                        self.public_values.proof_nonce,
300                        self.global_dependencies_opt,
301                    );
302                    execution_record.precompile_events.insert(syscall_code, remainder);
303                    execution_record.public_values.update_initialized_state(
304                        self.program.pc_start_abs,
305                        self.program.enable_untrusted_programs,
306                        self.program.trap_context,
307                        self.program.untrusted_memory,
308                    );
309                    shards.push(execution_record);
310                }
311            } else {
312                self.precompile_events.insert(syscall_code, chunks.remainder().to_vec());
313            }
314            let mut event_shards = chunks
315                .map(|chunk| {
316                    let mut execution_record = ExecutionRecord::new(
317                        self.program.clone(),
318                        self.public_values.proof_nonce,
319                        self.global_dependencies_opt,
320                    );
321                    execution_record.precompile_events.insert(syscall_code, chunk.to_vec());
322                    execution_record.public_values.update_initialized_state(
323                        self.program.pc_start_abs,
324                        self.program.enable_untrusted_programs,
325                        self.program.trap_context,
326                        self.program.untrusted_memory,
327                    );
328                    execution_record
329                })
330                .collect::<Vec<_>>();
331            shards.append(&mut event_shards);
332        }
333
334        if done {
335            // If there are no precompile shards, and `last_record` is Some, pack the memory events
336            // into the last record.
337            let pack_memory_events_into_last_record = can_pack_global_memory && shards.is_empty();
338            let mut blank_record = ExecutionRecord::new(
339                self.program.clone(),
340                self.public_values.proof_nonce,
341                self.global_dependencies_opt,
342            );
343
344            // Clone the public values of the last record to update the last record's public values.
345            let last_record_public_values = last_record.public_values;
346
347            // Update the state of the blank record
348            blank_record
349                .public_values
350                .update_finalized_state_from_public_values(&last_record_public_values);
351
352            // If `last_record` is None, use a blank record to store the memory events.
353            let mem_record_ref =
354                if pack_memory_events_into_last_record { last_record } else { &mut blank_record };
355
356            let mut init_page_idx = 0;
357            let mut finalize_page_idx = 0;
358
359            // Put all of the page prot init and finalize events into the last record.
360            if !self.global_page_prot_initialize_events.is_empty()
361                || !self.global_page_prot_finalize_events.is_empty()
362            {
363                self.global_page_prot_initialize_events.sort_by_key(|event| event.page_idx);
364                self.global_page_prot_finalize_events.sort_by_key(|event| event.page_idx);
365
366                let init_iter = self.global_page_prot_initialize_events.iter();
367                let finalize_iter = self.global_page_prot_finalize_events.iter();
368                let mut init_remaining = init_iter.as_slice();
369                let mut finalize_remaining = finalize_iter.as_slice();
370
371                while !init_remaining.is_empty() || !finalize_remaining.is_empty() {
372                    let capacity = 2 * opts.page_prot;
373                    let init_to_take = init_remaining.len().min(capacity);
374                    let finalize_to_take = finalize_remaining.len().min(capacity - init_to_take);
375
376                    let finalize_to_take = if init_to_take < capacity {
377                        finalize_to_take.max(finalize_remaining.len().min(capacity - init_to_take))
378                    } else {
379                        0
380                    };
381
382                    let page_prot_init_chunk = &init_remaining[..init_to_take];
383                    let page_prot_finalize_chunk = &finalize_remaining[..finalize_to_take];
384
385                    mem_record_ref
386                        .global_page_prot_initialize_events
387                        .extend_from_slice(page_prot_init_chunk);
388                    mem_record_ref.public_values.previous_init_page_idx = init_page_idx;
389                    if let Some(last_event) = page_prot_init_chunk.last() {
390                        init_page_idx = last_event.page_idx;
391                    }
392                    mem_record_ref.public_values.last_init_page_idx = init_page_idx;
393
394                    mem_record_ref
395                        .global_page_prot_finalize_events
396                        .extend_from_slice(page_prot_finalize_chunk);
397                    mem_record_ref.public_values.previous_finalize_page_idx = finalize_page_idx;
398                    if let Some(last_event) = page_prot_finalize_chunk.last() {
399                        finalize_page_idx = last_event.page_idx;
400                    }
401                    mem_record_ref.public_values.last_finalize_page_idx = finalize_page_idx;
402
403                    // Because page prot events are non empty, we set the page protect active flag
404                    mem_record_ref.public_values.is_untrusted_programs_enabled = true as u32;
405
406                    init_remaining = &init_remaining[init_to_take..];
407                    finalize_remaining = &finalize_remaining[finalize_to_take..];
408
409                    // Ensure last record has same proof nonce as other shards
410                    mem_record_ref.public_values.proof_nonce = self.public_values.proof_nonce;
411                    mem_record_ref.global_dependencies_opt = self.global_dependencies_opt;
412
413                    if !pack_memory_events_into_last_record {
414                        // If not packing memory events into the last record, add 'last_record_ref'
415                        // to the returned records. `take` replaces `blank_program` with the
416                        // default.
417                        shards.push(take(mem_record_ref));
418
419                        // Reset the last record so its program is the correct one. (The default
420                        // program provided by `take` contains no
421                        // instructions.)
422                        mem_record_ref.program = self.program.clone();
423                        // Reset the public values execution state to match the last record state.
424                        mem_record_ref
425                            .public_values
426                            .update_finalized_state_from_public_values(&last_record_public_values);
427                    }
428                }
429            }
430
431            self.global_memory_initialize_events.sort_by_key(|event| event.addr);
432            self.global_memory_finalize_events.sort_by_key(|event| event.addr);
433
434            let mut init_addr = 0;
435            let mut finalize_addr = 0;
436
437            let mut mem_init_remaining = self.global_memory_initialize_events.as_slice();
438            let mut mem_finalize_remaining = self.global_memory_finalize_events.as_slice();
439
440            while !mem_init_remaining.is_empty() || !mem_finalize_remaining.is_empty() {
441                let capacity = 2 * opts.memory;
442                let init_to_take = mem_init_remaining.len().min(capacity);
443                let finalize_to_take = mem_finalize_remaining.len().min(capacity - init_to_take);
444
445                let finalize_to_take = if init_to_take < capacity {
446                    finalize_to_take.max(mem_finalize_remaining.len().min(capacity - init_to_take))
447                } else {
448                    0
449                };
450
451                let mem_init_chunk = &mem_init_remaining[..init_to_take];
452                let mem_finalize_chunk = &mem_finalize_remaining[..finalize_to_take];
453
454                mem_record_ref.global_memory_initialize_events.extend_from_slice(mem_init_chunk);
455                mem_record_ref.public_values.previous_init_addr = init_addr;
456                if let Some(last_event) = mem_init_chunk.last() {
457                    init_addr = last_event.addr;
458                }
459                mem_record_ref.public_values.last_init_addr = init_addr;
460
461                mem_record_ref.global_memory_finalize_events.extend_from_slice(mem_finalize_chunk);
462                mem_record_ref.public_values.previous_finalize_addr = finalize_addr;
463                if let Some(last_event) = mem_finalize_chunk.last() {
464                    finalize_addr = last_event.addr;
465                }
466                mem_record_ref.public_values.last_finalize_addr = finalize_addr;
467
468                mem_record_ref.public_values.proof_nonce = self.public_values.proof_nonce;
469                mem_record_ref.global_dependencies_opt = self.global_dependencies_opt;
470
471                mem_init_remaining = &mem_init_remaining[init_to_take..];
472                mem_finalize_remaining = &mem_finalize_remaining[finalize_to_take..];
473
474                if !pack_memory_events_into_last_record {
475                    mem_record_ref.public_values.previous_init_page_idx = init_page_idx;
476                    mem_record_ref.public_values.last_init_page_idx = init_page_idx;
477                    mem_record_ref.public_values.previous_finalize_page_idx = finalize_page_idx;
478                    mem_record_ref.public_values.last_finalize_page_idx = finalize_page_idx;
479
480                    // If not packing memory events into the last record, add 'last_record_ref'
481                    // to the returned records. `take` replaces `blank_program` with the default.
482                    shards.push(take(mem_record_ref));
483
484                    // Reset the last record so its program is the correct one. (The default program
485                    // provided by `take` contains no instructions.)
486                    mem_record_ref.program = self.program.clone();
487                    // Reset the public values execution state to match the last record state.
488                    mem_record_ref
489                        .public_values
490                        .update_finalized_state_from_public_values(&last_record_public_values);
491                }
492            }
493        }
494
495        shards
496    }
497
498    /// Return the number of rows needed for a chip, according to the proof shape specified in the
499    /// struct.
500    ///
501    /// **deprecated**: TODO: remove this method.
502    pub fn fixed_log2_rows<F: PrimeField, A: MachineAir<F>>(&self, _air: &A) -> Option<usize> {
503        None
504    }
505
506    /// Determines whether the execution record contains CPU events.
507    #[must_use]
508    pub fn contains_cpu(&self) -> bool {
509        self.cpu_event_count > 0
510    }
511
512    #[inline]
513    /// Add a precompile event to the execution record.
514    pub fn add_precompile_event(
515        &mut self,
516        syscall_code: SyscallCode,
517        syscall_event: SyscallEvent,
518        event: PrecompileEvent,
519    ) {
520        self.precompile_events.add_event(syscall_code, syscall_event, event);
521    }
522
523    /// Get all the precompile events for a syscall code.
524    #[inline]
525    #[must_use]
526    pub fn get_precompile_events(
527        &self,
528        syscall_code: SyscallCode,
529    ) -> &Vec<(SyscallEvent, PrecompileEvent)> {
530        self.precompile_events.get_events(syscall_code).expect("Precompile events not found")
531    }
532
533    /// Get all the local memory events.
534    #[inline]
535    pub fn get_local_mem_events(&self) -> impl Iterator<Item = &MemoryLocalEvent> {
536        let precompile_local_mem_events = self.precompile_events.get_local_mem_events();
537        precompile_local_mem_events.chain(self.cpu_local_memory_access.iter())
538    }
539
540    /// Get all the local page prot events.
541    #[inline]
542    pub fn get_local_page_prot_events(&self) -> impl Iterator<Item = &PageProtLocalEvent> {
543        let precompile_local_page_prot_events = self.precompile_events.get_local_page_prot_events();
544        precompile_local_page_prot_events.chain(self.cpu_local_page_prot_access.iter())
545    }
546
547    /// Reset the record, without deallocating the event vecs.
548    #[inline]
549    pub fn reset(&mut self) {
550        self.alu_x0_events.truncate(0);
551        self.add_events.truncate(0);
552        self.addw_events.truncate(0);
553        self.addi_events.truncate(0);
554        self.mul_events.truncate(0);
555        self.sub_events.truncate(0);
556        self.subw_events.truncate(0);
557        self.bitwise_events.truncate(0);
558        self.shift_left_events.truncate(0);
559        self.shift_right_events.truncate(0);
560        self.divrem_events.truncate(0);
561        self.lt_events.truncate(0);
562        self.memory_load_byte_events.truncate(0);
563        self.memory_load_half_events.truncate(0);
564        self.memory_load_word_events.truncate(0);
565        self.memory_load_x0_events.truncate(0);
566        self.memory_load_double_events.truncate(0);
567        self.memory_store_byte_events.truncate(0);
568        self.memory_store_half_events.truncate(0);
569        self.memory_store_word_events.truncate(0);
570        self.memory_store_double_events.truncate(0);
571        self.utype_events.truncate(0);
572        self.branch_events.truncate(0);
573        self.jal_events.truncate(0);
574        self.jalr_events.truncate(0);
575        self.byte_lookups.clear();
576        self.precompile_events = PrecompileEvents::default();
577        self.global_memory_initialize_events.truncate(0);
578        self.global_memory_finalize_events.truncate(0);
579        self.global_page_prot_initialize_events.truncate(0);
580        self.global_page_prot_finalize_events.truncate(0);
581        self.cpu_local_memory_access.truncate(0);
582        self.cpu_local_page_prot_access.truncate(0);
583        self.syscall_events.truncate(0);
584        self.global_interaction_events.truncate(0);
585        self.instruction_fetch_events.truncate(0);
586        self.instruction_decode_events.truncate(0);
587        let mut cumulative_sum = self.global_cumulative_sum.lock().unwrap();
588        *cumulative_sum = SepticDigest::default();
589        self.global_interaction_event_count = 0;
590        self.bump_memory_events.truncate(0);
591        self.bump_state_events.truncate(0);
592        let _ = self.public_values.reset();
593        self.next_nonce = 0;
594        self.shape = None;
595        self.estimated_trace_area = 0;
596        self.initial_timestamp = 0;
597        self.last_timestamp = 0;
598        self.pc_start = None;
599        self.next_pc = 0;
600        self.exit_code = 0;
601    }
602}
603
/// A memory access record.
///
/// Collects the (optional) memory accesses made while executing a single instruction. Each
/// field is `None` when the corresponding access did not happen.
#[derive(Debug, Copy, Clone, Default, Serialize, Deserialize, DeepSizeOf)]
pub struct MemoryAccessRecord {
    /// The memory access of the `a` register.
    pub a: Option<MemoryRecordEnum>,
    /// The memory access of the `b` register.
    pub b: Option<MemoryRecordEnum>,
    /// The memory access of the `c` register.
    pub c: Option<MemoryRecordEnum>,
    /// The access to memory proper (as opposed to the register accesses `a`, `b` and `c`),
    /// if the instruction made one.
    pub memory: Option<MemoryRecordEnum>,
    /// The memory access that fetched the untrusted instruction.
    ///
    /// When this access occurs, the raw 64-bit read is stored together with the selected
    /// 32 bits that encode the instruction.
    pub untrusted_instruction: Option<(MemoryRecordEnum, u32)>,
}
620
/// Memory record where all three operands are registers.
///
/// Stores each operand together with the memory record of the corresponding register access.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, DeepSizeOf)]
pub struct RTypeRecord {
    /// The a operand.
    pub op_a: u8,
    /// The register `op_a` record.
    pub a: MemoryRecordEnum,
    /// The b operand.
    pub op_b: u64,
    /// The register `op_b` record.
    pub b: MemoryRecordEnum,
    /// The c operand.
    pub op_c: u64,
    /// The register `op_c` record.
    pub c: MemoryRecordEnum,
    /// Whether the instruction is untrusted (i.e. it was fetched through an untrusted
    /// instruction memory access).
    pub is_untrusted: bool,
}
639
640impl RTypeRecord {
641    pub(crate) fn new(value: &MemoryAccessRecord, instruction: &Instruction) -> Self {
642        Self {
643            op_a: instruction.op_a,
644            a: value.a.expect("expected MemoryRecord for op_a in RTypeRecord"),
645            op_b: instruction.op_b,
646            b: value.b.expect("expected MemoryRecord for op_b in RTypeRecord"),
647            op_c: instruction.op_c,
648            c: value.c.expect("expected MemoryRecord for op_c in RTypeRecord"),
649            is_untrusted: value.untrusted_instruction.is_some(),
650        }
651    }
652}
/// Memory record where the first two operands are registers.
///
/// The third operand `op_c` is stored as a raw value without a register access record.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, DeepSizeOf)]
pub struct ITypeRecord {
    /// The a operand.
    pub op_a: u8,
    /// The register `op_a` record.
    pub a: MemoryRecordEnum,
    /// The b operand.
    pub op_b: u64,
    /// The register `op_b` record.
    pub b: MemoryRecordEnum,
    /// The c operand (not a register access; presumably the immediate — TODO confirm).
    pub op_c: u64,
    /// Whether the instruction is untrusted (i.e. it was fetched through an untrusted
    /// instruction memory access).
    pub is_untrusted: bool,
}
669
670impl ITypeRecord {
671    pub(crate) fn new(value: &MemoryAccessRecord, instruction: &Instruction) -> Self {
672        debug_assert!(value.c.is_none());
673        Self {
674            op_a: instruction.op_a,
675            a: value.a.expect("expected MemoryRecord for op_a in ITypeRecord"),
676            op_b: instruction.op_b,
677            b: value.b.expect("expected MemoryRecord for op_b in ITypeRecord"),
678            op_c: instruction.op_c,
679            is_untrusted: value.untrusted_instruction.is_some(),
680        }
681    }
682}
683
/// Memory record where only one operand is a register.
///
/// Only the `a` register access is kept. `JTypeRecord::new` debug-asserts that no `b` or `c`
/// access was recorded.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, DeepSizeOf)]
pub struct JTypeRecord {
    /// The `a` operand, copied from the instruction's `op_a`.
    pub op_a: u8,
    /// The memory access record for register `op_a`.
    pub a: MemoryRecordEnum,
    /// The `b` operand, copied from the instruction's `op_b` (no register access recorded).
    pub op_b: u64,
    /// The `c` operand, copied from the instruction's `op_c` (no register access recorded).
    pub op_c: u64,
    /// Whether the instruction is untrusted, i.e. the source record carried an
    /// `untrusted_instruction` access.
    pub is_untrusted: bool,
}
698
699impl JTypeRecord {
700    pub(crate) fn new(value: &MemoryAccessRecord, instruction: &Instruction) -> Self {
701        debug_assert!(value.b.is_none());
702        debug_assert!(value.c.is_none());
703        Self {
704            op_a: instruction.op_a,
705            a: value.a.expect("expected MemoryRecord for op_a in JTypeRecord"),
706            op_b: instruction.op_b,
707            op_c: instruction.op_c,
708            is_untrusted: value.untrusted_instruction.is_some(),
709        }
710    }
711}
712
/// Memory record where only the first two operands are known to be registers, but the third isn't.
///
/// The `c` access is optional. `is_imm` mirrors the instruction's `imm_c` flag.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, DeepSizeOf)]
pub struct ALUTypeRecord {
    /// The `a` operand, copied from the instruction's `op_a`.
    pub op_a: u8,
    /// The memory access record for register `op_a`.
    pub a: MemoryRecordEnum,
    /// The `b` operand, copied from the instruction's `op_b`.
    pub op_b: u64,
    /// The memory access record for register `op_b`.
    pub b: MemoryRecordEnum,
    /// The `c` operand, copied from the instruction's `op_c`.
    pub op_c: u64,
    /// The memory access record for register `op_c`, if one was recorded.
    pub c: Option<MemoryRecordEnum>,
    /// Whether the instruction has an immediate (`Instruction::imm_c`).
    pub is_imm: bool,
    /// Whether the instruction is untrusted, i.e. the source record carried an
    /// `untrusted_instruction` access.
    pub is_untrusted: bool,
}
733
734impl ALUTypeRecord {
735    pub(crate) fn new(value: &MemoryAccessRecord, instruction: &Instruction) -> Self {
736        Self {
737            op_a: instruction.op_a,
738            a: value.a.expect("expected MemoryRecord for op_a in ALUTypeRecord"),
739            op_b: instruction.op_b,
740            b: value.b.expect("expected MemoryRecord for op_b in ALUTypeRecord"),
741            op_c: instruction.op_c,
742            c: value.c,
743            is_imm: instruction.imm_c,
744            is_untrusted: value.untrusted_instruction.is_some(),
745        }
746    }
747}
748
/// Memory record for an untrusted program instruction fetch.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct UntrustedProgramInstructionRecord {
    /// The memory access record associated with this untrusted instruction.
    pub memory_access_record: MemoryAccessRecord,
    /// The instruction.
    pub instruction: Instruction,
    /// The 32-bit encoding of the instruction.
    pub encoded_instruction: u32,
}
759
760impl MachineRecord for ExecutionRecord {
761    fn stats(&self) -> HashMap<String, usize> {
762        let mut stats = HashMap::new();
763        stats.insert("cpu_events".to_string(), self.cpu_event_count as usize);
764        stats.insert("alu_x0_events".to_string(), self.alu_x0_events.len());
765        stats.insert("add_events".to_string(), self.add_events.len());
766        stats.insert("mul_events".to_string(), self.mul_events.len());
767        stats.insert("sub_events".to_string(), self.sub_events.len());
768        stats.insert("bitwise_events".to_string(), self.bitwise_events.len());
769        stats.insert("shift_left_events".to_string(), self.shift_left_events.len());
770        stats.insert("shift_right_events".to_string(), self.shift_right_events.len());
771        stats.insert("divrem_events".to_string(), self.divrem_events.len());
772        stats.insert("lt_events".to_string(), self.lt_events.len());
773        stats.insert("load_byte_events".to_string(), self.memory_load_byte_events.len());
774        stats.insert("load_half_events".to_string(), self.memory_load_half_events.len());
775        stats.insert("load_word_events".to_string(), self.memory_load_word_events.len());
776        stats.insert("load_x0_events".to_string(), self.memory_load_x0_events.len());
777        stats.insert("store_byte_events".to_string(), self.memory_store_byte_events.len());
778        stats.insert("store_half_events".to_string(), self.memory_store_half_events.len());
779        stats.insert("store_word_events".to_string(), self.memory_store_word_events.len());
780        stats.insert("branch_events".to_string(), self.branch_events.len());
781        stats.insert("jal_events".to_string(), self.jal_events.len());
782        stats.insert("jalr_events".to_string(), self.jalr_events.len());
783        stats.insert("utype_events".to_string(), self.utype_events.len());
784        stats.insert("instruction_decode_events".to_string(), self.instruction_decode_events.len());
785        stats.insert("instruction_fetch_events".to_string(), self.instruction_fetch_events.len());
786
787        for (syscall_code, events) in self.precompile_events.iter() {
788            stats.insert(format!("syscall {syscall_code:?}"), events.len());
789        }
790
791        stats.insert(
792            "global_memory_initialize_events".to_string(),
793            self.global_memory_initialize_events.len(),
794        );
795        stats.insert(
796            "global_memory_finalize_events".to_string(),
797            self.global_memory_finalize_events.len(),
798        );
799        stats.insert("local_memory_access_events".to_string(), self.cpu_local_memory_access.len());
800        stats.insert(
801            "local_page_prot_access_events".to_string(),
802            self.cpu_local_page_prot_access.len(),
803        );
804        if self.contains_cpu() {
805            stats.insert("byte_lookups".to_string(), self.byte_lookups.len());
806        }
807        // Filter out the empty events.
808        stats.retain(|_, v| *v != 0);
809        stats
810    }
811
812    fn append(&mut self, other: &mut ExecutionRecord) {
813        self.cpu_event_count += other.cpu_event_count;
814        other.cpu_event_count = 0;
815        self.public_values.global_count += other.public_values.global_count;
816        other.public_values.global_count = 0;
817        self.public_values.global_init_count += other.public_values.global_init_count;
818        other.public_values.global_init_count = 0;
819        self.public_values.global_finalize_count += other.public_values.global_finalize_count;
820        other.public_values.global_finalize_count = 0;
821        self.public_values.global_page_prot_init_count +=
822            other.public_values.global_page_prot_init_count;
823        other.public_values.global_page_prot_init_count = 0;
824        self.public_values.global_page_prot_finalize_count +=
825            other.public_values.global_page_prot_finalize_count;
826        other.public_values.global_page_prot_finalize_count = 0;
827        self.estimated_trace_area += other.estimated_trace_area;
828        other.estimated_trace_area = 0;
829        self.alu_x0_events.append(&mut other.alu_x0_events);
830        self.add_events.append(&mut other.add_events);
831        self.sub_events.append(&mut other.sub_events);
832        self.mul_events.append(&mut other.mul_events);
833        self.bitwise_events.append(&mut other.bitwise_events);
834        self.shift_left_events.append(&mut other.shift_left_events);
835        self.shift_right_events.append(&mut other.shift_right_events);
836        self.divrem_events.append(&mut other.divrem_events);
837        self.lt_events.append(&mut other.lt_events);
838        self.memory_load_byte_events.append(&mut other.memory_load_byte_events);
839        self.memory_load_half_events.append(&mut other.memory_load_half_events);
840        self.memory_load_word_events.append(&mut other.memory_load_word_events);
841        self.memory_load_x0_events.append(&mut other.memory_load_x0_events);
842        self.memory_store_byte_events.append(&mut other.memory_store_byte_events);
843        self.memory_store_half_events.append(&mut other.memory_store_half_events);
844        self.memory_store_word_events.append(&mut other.memory_store_word_events);
845        self.branch_events.append(&mut other.branch_events);
846        self.jal_events.append(&mut other.jal_events);
847        self.jalr_events.append(&mut other.jalr_events);
848        self.utype_events.append(&mut other.utype_events);
849        self.syscall_events.append(&mut other.syscall_events);
850        self.bump_memory_events.append(&mut other.bump_memory_events);
851        self.bump_state_events.append(&mut other.bump_state_events);
852        self.precompile_events.append(&mut other.precompile_events);
853        self.instruction_fetch_events.append(&mut other.instruction_fetch_events);
854        self.instruction_decode_events.append(&mut other.instruction_decode_events);
855
856        if self.byte_lookups.is_empty() {
857            self.byte_lookups = std::mem::take(&mut other.byte_lookups);
858        } else {
859            self.add_byte_lookup_events_from_maps(vec![&other.byte_lookups]);
860        }
861
862        self.global_memory_initialize_events.append(&mut other.global_memory_initialize_events);
863        self.global_memory_finalize_events.append(&mut other.global_memory_finalize_events);
864        self.global_page_prot_initialize_events
865            .append(&mut other.global_page_prot_initialize_events);
866        self.global_page_prot_finalize_events.append(&mut other.global_page_prot_finalize_events);
867        self.cpu_local_memory_access.append(&mut other.cpu_local_memory_access);
868        self.cpu_local_page_prot_access.append(&mut other.cpu_local_page_prot_access);
869        self.global_interaction_events.append(&mut other.global_interaction_events);
870    }
871
872    /// Retrieves the public values.  This method is needed for the `MachineRecord` trait, since
873    fn public_values<F: AbstractField>(&self) -> Vec<F> {
874        let mut public_values = self.public_values;
875        public_values.global_cumulative_sum = *self.global_cumulative_sum.lock().unwrap();
876        public_values.to_vec()
877    }
878
879    /// Constrains the public values.
880    #[allow(clippy::type_complexity)]
881    fn eval_public_values<AB: SP1AirBuilder>(builder: &mut AB) {
882        let public_values_slice: [AB::PublicVar; SP1_PROOF_NUM_PV_ELTS] =
883            core::array::from_fn(|i| builder.public_values()[i]);
884        let public_values: &PublicValues<
885            [AB::PublicVar; 4],
886            [AB::PublicVar; 3],
887            [AB::PublicVar; 4],
888            AB::PublicVar,
889        > = public_values_slice.as_slice().borrow();
890
891        for var in public_values.empty {
892            builder.assert_zero(var);
893        }
894
895        Self::eval_state(public_values, builder);
896        Self::eval_first_execution_shard(public_values, builder);
897        Self::eval_exit_code(public_values, builder);
898        Self::eval_committed_value_digest(public_values, builder);
899        Self::eval_deferred_proofs_digest(public_values, builder);
900        Self::eval_global_sum(public_values, builder);
901        Self::eval_global_memory_init(public_values, builder);
902        Self::eval_global_memory_finalize(public_values, builder);
903        Self::eval_global_page_prot_init(public_values, builder);
904        Self::eval_global_page_prot_finalize(public_values, builder);
905        #[cfg(feature = "mprotect")]
906        Self::eval_trap_handler(public_values, builder);
907    }
908
909    fn interactions_in_public_values() -> Vec<InteractionKind> {
910        InteractionKind::all_kinds()
911            .iter()
912            .filter(|kind| kind.appears_in_eval_public_values())
913            .copied()
914            .collect()
915    }
916}
917
918impl ByteRecord for ExecutionRecord {
919    fn add_byte_lookup_event(&mut self, blu_event: ByteLookupEvent) {
920        *self.byte_lookups.entry(blu_event).or_insert(0) += 1;
921    }
922
923    #[inline]
924    fn add_byte_lookup_events_from_maps(
925        &mut self,
926        new_events: Vec<&HashMap<ByteLookupEvent, usize>>,
927    ) {
928        for new_blu_map in new_events {
929            for (blu_event, count) in new_blu_map.iter() {
930                *self.byte_lookups.entry(*blu_event).or_insert(0) += count;
931            }
932        }
933    }
934}
935
936impl ExecutionRecord {
    /// Constrains the shard's boundary state (timestamps and program counters).
    ///
    /// Each timestamp is four limbs, recombined as `high = limb[0] * 2^8 + limb[1]` and
    /// `low = limb[2] * 2^16 + limb[3]`. All limbs, including the PC limbs, are range-checked
    /// through byte lookups. The initial state is sent and the final state is received on the
    /// state interaction. `is_execution_shard` is forced to be `1` exactly when the timestamp
    /// changes across the shard.
    #[allow(clippy::type_complexity)]
    fn eval_state<AB: SP1AirBuilder>(
        public_values: &PublicValues<
            [AB::PublicVar; 4],
            [AB::PublicVar; 3],
            [AB::PublicVar; 4],
            AB::PublicVar,
        >,
        builder: &mut AB,
    ) {
        // Recombine the four timestamp limbs into a high half and a low half.
        let initial_timestamp_high = public_values.initial_timestamp[1].into()
            + public_values.initial_timestamp[0].into() * AB::Expr::from_canonical_u32(1 << 8);
        let initial_timestamp_low = public_values.initial_timestamp[3].into()
            + public_values.initial_timestamp[2].into() * AB::Expr::from_canonical_u32(1 << 16);
        let last_timestamp_high = public_values.last_timestamp[1].into()
            + public_values.last_timestamp[0].into() * AB::Expr::from_canonical_u32(1 << 8);
        let last_timestamp_low = public_values.last_timestamp[3].into()
            + public_values.last_timestamp[2].into() * AB::Expr::from_canonical_u32(1 << 16);

        // Range check all the timestamp limbs.
        // Limb 0 is a 16-bit value.
        builder.send_byte(
            AB::Expr::from_canonical_u32(ByteOpcode::Range as u32),
            public_values.initial_timestamp[0].into(),
            AB::Expr::from_canonical_u32(16),
            AB::Expr::zero(),
            AB::Expr::one(),
        );
        // Limb 3: `(limb - 1) / 8` must fit in 13 bits, so `limb == 1 (mod 8)` and
        // `limb < 2^16`.
        builder.send_byte(
            AB::Expr::from_canonical_u32(ByteOpcode::Range as u32),
            (public_values.initial_timestamp[3].into() - AB::Expr::one())
                * AB::F::from_canonical_u8(8).inverse(),
            AB::Expr::from_canonical_u32(13),
            AB::Expr::zero(),
            AB::Expr::one(),
        );
        builder.send_byte(
            AB::Expr::from_canonical_u32(ByteOpcode::Range as u32),
            public_values.last_timestamp[0].into(),
            AB::Expr::from_canonical_u32(16),
            AB::Expr::zero(),
            AB::Expr::one(),
        );
        builder.send_byte(
            AB::Expr::from_canonical_u32(ByteOpcode::Range as u32),
            (public_values.last_timestamp[3].into() - AB::Expr::one())
                * AB::F::from_canonical_u8(8).inverse(),
            AB::Expr::from_canonical_u32(13),
            AB::Expr::zero(),
            AB::Expr::one(),
        );
        // Limbs 1 and 2 are bytes.
        builder.send_byte(
            AB::Expr::from_canonical_u8(ByteOpcode::U8Range as u8),
            AB::Expr::zero(),
            public_values.initial_timestamp[1],
            public_values.initial_timestamp[2],
            AB::Expr::one(),
        );
        builder.send_byte(
            AB::Expr::from_canonical_u8(ByteOpcode::U8Range as u8),
            AB::Expr::zero(),
            public_values.last_timestamp[1],
            public_values.last_timestamp[2],
            AB::Expr::one(),
        );

        // Range check all the initial, final program counter limbs (16 bits each).
        for i in 0..3 {
            builder.send_byte(
                AB::Expr::from_canonical_u32(ByteOpcode::Range as u32),
                public_values.pc_start[i].into(),
                AB::Expr::from_canonical_u32(16),
                AB::Expr::zero(),
                AB::Expr::one(),
            );
            builder.send_byte(
                AB::Expr::from_canonical_u32(ByteOpcode::Range as u32),
                public_values.next_pc[i].into(),
                AB::Expr::from_canonical_u32(16),
                AB::Expr::zero(),
                AB::Expr::one(),
            );
        }

        // Send and receive the initial and last state.
        builder.send_state(
            initial_timestamp_high.clone(),
            initial_timestamp_low.clone(),
            public_values.pc_start,
            AB::Expr::one(),
        );
        builder.receive_state(
            last_timestamp_high.clone(),
            last_timestamp_low.clone(),
            public_values.next_pc,
            AB::Expr::one(),
        );

        // If the shard is not execution shard, assert that timestamp and pc remains equal.
        let is_execution_shard = public_values.is_execution_shard.into();
        builder.assert_bool(is_execution_shard.clone());
        builder
            .when_not(is_execution_shard.clone())
            .assert_eq(initial_timestamp_low.clone(), last_timestamp_low.clone());
        builder
            .when_not(is_execution_shard.clone())
            .assert_eq(initial_timestamp_high.clone(), last_timestamp_high.clone());
        builder
            .when_not(is_execution_shard.clone())
            .assert_all_eq(public_values.pc_start, public_values.next_pc);

        // IsZeroOperation on the high bits of the timestamp, using `inv_timestamp_high` as the
        // witness inverse of the difference.
        builder.assert_bool(public_values.is_timestamp_high_eq);
        // If high bits are equal, then `is_timestamp_high_eq == 1`.
        builder.assert_eq(
            (last_timestamp_high.clone() - initial_timestamp_high.clone())
                * public_values.inv_timestamp_high.into(),
            AB::Expr::one() - public_values.is_timestamp_high_eq.into(),
        );
        // If high bits are distinct, then `is_timestamp_high_eq == 0`.
        builder.assert_zero(
            (last_timestamp_high.clone() - initial_timestamp_high.clone())
                * public_values.is_timestamp_high_eq.into(),
        );

        // IsZeroOperation on the low bits of the timestamp, using `inv_timestamp_low` as the
        // witness inverse of the difference.
        builder.assert_bool(public_values.is_timestamp_low_eq);
        // If low bits are equal, then `is_timestamp_low_eq == 1`.
        builder.assert_eq(
            (last_timestamp_low.clone() - initial_timestamp_low.clone())
                * public_values.inv_timestamp_low.into(),
            AB::Expr::one() - public_values.is_timestamp_low_eq.into(),
        );
        // If low bits are distinct, then `is_timestamp_low_eq == 0`.
        builder.assert_zero(
            (last_timestamp_low.clone() - initial_timestamp_low.clone())
                * public_values.is_timestamp_low_eq.into(),
        );

        // `is_execution_shard == 1` exactly when the timestamp changed, i.e. when not both
        // halves are equal.
        builder.assert_eq(
            AB::Expr::one() - is_execution_shard.clone(),
            public_values.is_timestamp_high_eq.into() * public_values.is_timestamp_low_eq.into(),
        );

        // Check that an execution shard has `last_timestamp != 1` by providing an inverse.
        // The `high + low` value cannot overflow, as they were range checked to be 24 bits.
        // `high == 1, low == 0` is impossible, because the limb-3 range check above forces
        // `low == 1 (mod 8)`. So `high + low == 1` only when `high == 0, low == 1`.
        builder.when(is_execution_shard.clone()).assert_eq(
            (last_timestamp_high + last_timestamp_low - AB::Expr::one())
                * public_values.last_timestamp_inv.into(),
            AB::Expr::one(),
        );
    }
1090
    /// Constrains the boundary conditions of the first execution shard.
    ///
    /// When `is_first_execution_shard == 1`, the shard must be an execution shard starting at
    /// timestamp 1. Every "previous" accumulator must also be zero: digests, exit code,
    /// init/finalize addresses and page indices, and commit flags.
    #[allow(clippy::type_complexity)]
    fn eval_first_execution_shard<AB: SP1AirBuilder>(
        public_values: &PublicValues<
            [AB::PublicVar; 4],
            [AB::PublicVar; 3],
            [AB::PublicVar; 4],
            AB::PublicVar,
        >,
        builder: &mut AB,
    ) {
        // Check that `is_first_execution_shard` is boolean.
        builder.assert_bool(public_values.is_first_execution_shard.into());

        // Timestamp constraints.
        //
        // We want to assert that `is_first_execution_shard == 1` corresponds exactly to the unique
        // execution shard with initial timestamp 1. We are assuming that there is a unique
        // shard with `is_first_execution_shard == 1`. This is enforced in the verifier and
        // in recursion. Given this, it is enough to impose that for this unique shard,
        // `initial_timestamp == 1` (limbs `[0, 0, 0, 1]`).
        builder.when(public_values.is_first_execution_shard.into()).assert_all_eq(
            public_values.initial_timestamp,
            [AB::Expr::zero(), AB::Expr::zero(), AB::Expr::zero(), AB::Expr::one()],
        );

        // If `is_first_execution_shard` is true, check `is_execution_shard == 1`.
        builder
            .when(public_values.is_first_execution_shard.into())
            .assert_one(public_values.is_execution_shard);

        // If `is_first_execution_shard` is true, assert the initial boundary conditions.

        // Check `prev_committed_value_digest == 0`.
        for i in 0..PV_DIGEST_NUM_WORDS {
            builder
                .when(public_values.is_first_execution_shard.into())
                .assert_all_zero(public_values.prev_committed_value_digest[i]);
        }

        // Check `prev_deferred_proofs_digest == 0`.
        builder
            .when(public_values.is_first_execution_shard.into())
            .assert_all_zero(public_values.prev_deferred_proofs_digest);

        // Check `prev_exit_code == 0`.
        builder
            .when(public_values.is_first_execution_shard.into())
            .assert_zero(public_values.prev_exit_code);

        // Check `previous_init_addr == 0`.
        builder
            .when(public_values.is_first_execution_shard.into())
            .assert_all_zero(public_values.previous_init_addr);

        // Check `previous_finalize_addr == 0`.
        builder
            .when(public_values.is_first_execution_shard.into())
            .assert_all_zero(public_values.previous_finalize_addr);

        // Check `previous_init_page_idx == 0`
        builder
            .when(public_values.is_first_execution_shard.into())
            .assert_all_zero(public_values.previous_init_page_idx);

        // Check `previous_finalize_page_idx == 0`
        builder
            .when(public_values.is_first_execution_shard.into())
            .assert_all_zero(public_values.previous_finalize_page_idx);

        // Check `prev_commit_syscall == 0`.
        builder
            .when(public_values.is_first_execution_shard.into())
            .assert_zero(public_values.prev_commit_syscall);

        // Check `prev_commit_deferred_syscall == 0`.
        builder
            .when(public_values.is_first_execution_shard.into())
            .assert_zero(public_values.prev_commit_deferred_syscall);
    }
1170
    /// Constrains the exit code across a shard.
    ///
    /// A non-zero `prev_exit_code` can no longer change. A non-execution shard must leave the
    /// exit code unchanged.
    #[allow(clippy::type_complexity)]
    fn eval_exit_code<AB: SP1AirBuilder>(
        public_values: &PublicValues<
            [AB::PublicVar; 4],
            [AB::PublicVar; 3],
            [AB::PublicVar; 4],
            AB::PublicVar,
        >,
        builder: &mut AB,
    ) {
        let is_execution_shard = public_values.is_execution_shard.into();

        // If the `prev_exit_code` is non-zero, then the `exit_code` must be equal to it.
        // Expressed as `prev * (cur - prev) == 0`.
        builder.assert_zero(
            public_values.prev_exit_code.into()
                * (public_values.exit_code.into() - public_values.prev_exit_code.into()),
        );

        // If it's not an execution shard, assert that `exit_code` will not change in that shard.
        builder
            .when_not(is_execution_shard.clone())
            .assert_eq(public_values.prev_exit_code, public_values.exit_code);
    }
1194
    /// Constrains how `committed_value_digest` and `commit_syscall` may evolve across a shard.
    ///
    /// All digest limbs are range-checked as bytes, and the commit flags are boolean and
    /// monotone. The digest is frozen in any of three cases: in a non-execution shard, once any
    /// limb of the previous digest is non-zero, or once `prev_commit_syscall` is set.
    #[allow(clippy::type_complexity)]
    fn eval_committed_value_digest<AB: SP1AirBuilder>(
        public_values: &PublicValues<
            [AB::PublicVar; 4],
            [AB::PublicVar; 3],
            [AB::PublicVar; 4],
            AB::PublicVar,
        >,
        builder: &mut AB,
    ) {
        let is_execution_shard = public_values.is_execution_shard.into();

        // Assert that both `prev_committed_value_digest` and `committed_value_digest` are bytes.
        // Each `U8Range` lookup checks two limbs at once.
        for i in 0..PV_DIGEST_NUM_WORDS {
            builder.send_byte(
                AB::Expr::from_canonical_u8(ByteOpcode::U8Range as u8),
                AB::Expr::zero(),
                public_values.prev_committed_value_digest[i][0],
                public_values.prev_committed_value_digest[i][1],
                AB::Expr::one(),
            );
            builder.send_byte(
                AB::Expr::from_canonical_u8(ByteOpcode::U8Range as u8),
                AB::Expr::zero(),
                public_values.prev_committed_value_digest[i][2],
                public_values.prev_committed_value_digest[i][3],
                AB::Expr::one(),
            );
            builder.send_byte(
                AB::Expr::from_canonical_u8(ByteOpcode::U8Range as u8),
                AB::Expr::zero(),
                public_values.committed_value_digest[i][0],
                public_values.committed_value_digest[i][1],
                AB::Expr::one(),
            );
            builder.send_byte(
                AB::Expr::from_canonical_u8(ByteOpcode::U8Range as u8),
                AB::Expr::zero(),
                public_values.committed_value_digest[i][2],
                public_values.committed_value_digest[i][3],
                AB::Expr::one(),
            );
        }

        // Assert that both `prev_commit_syscall` and `commit_syscall` are boolean.
        builder.assert_bool(public_values.prev_commit_syscall);
        builder.assert_bool(public_values.commit_syscall);

        // Assert that `prev_commit_syscall == 1` implies `commit_syscall == 1`.
        builder.when(public_values.prev_commit_syscall).assert_one(public_values.commit_syscall);

        // Assert that the `commit_syscall` value doesn't change in a non-execution shard.
        builder
            .when_not(is_execution_shard.clone())
            .assert_eq(public_values.prev_commit_syscall, public_values.commit_syscall);

        // Assert that `committed_value_digest` will not change in a non-execution shard.
        for i in 0..PV_DIGEST_NUM_WORDS {
            builder.when_not(is_execution_shard.clone()).assert_all_eq(
                public_values.prev_committed_value_digest[i],
                public_values.committed_value_digest[i],
            );
        }

        // Assert that `prev_committed_value_digest != [0u8; 32]` implies `committed_value_digest`
        // must remain equal to the `prev_committed_value_digest`. Each limb is used as a
        // selector, so any single non-zero limb activates the equality constraints.
        for word in public_values.prev_committed_value_digest {
            for limb in word {
                for i in 0..PV_DIGEST_NUM_WORDS {
                    builder.when(limb).assert_all_eq(
                        public_values.prev_committed_value_digest[i],
                        public_values.committed_value_digest[i],
                    );
                }
            }
        }

        // Assert that if `prev_commit_syscall` is true, `committed_value_digest` doesn't change.
        for i in 0..PV_DIGEST_NUM_WORDS {
            builder.when(public_values.prev_commit_syscall).assert_all_eq(
                public_values.prev_committed_value_digest[i],
                public_values.committed_value_digest[i],
            );
        }
    }
1280
    /// Constrains how `deferred_proofs_digest` and `commit_deferred_syscall` may evolve.
    ///
    /// The commit flags are boolean and monotone. The digest is frozen in any of three cases:
    /// in a non-execution shard, once any limb of the previous digest is non-zero, or once
    /// `prev_commit_deferred_syscall` is set. This function applies no range check to the
    /// digest limbs.
    #[allow(clippy::type_complexity)]
    fn eval_deferred_proofs_digest<AB: SP1AirBuilder>(
        public_values: &PublicValues<
            [AB::PublicVar; 4],
            [AB::PublicVar; 3],
            [AB::PublicVar; 4],
            AB::PublicVar,
        >,
        builder: &mut AB,
    ) {
        let is_execution_shard = public_values.is_execution_shard.into();

        // Assert that `prev_commit_deferred_syscall` and `commit_deferred_syscall` are boolean.
        builder.assert_bool(public_values.prev_commit_deferred_syscall);
        builder.assert_bool(public_values.commit_deferred_syscall);

        // Assert that `prev_commit_deferred_syscall == 1` implies `commit_deferred_syscall == 1`.
        builder
            .when(public_values.prev_commit_deferred_syscall)
            .assert_one(public_values.commit_deferred_syscall);

        // Assert that the `commit_deferred_syscall` value doesn't change in a non-execution shard.
        builder.when_not(is_execution_shard.clone()).assert_eq(
            public_values.prev_commit_deferred_syscall,
            public_values.commit_deferred_syscall,
        );

        // Assert that `deferred_proofs_digest` will not change in a non-execution shard.
        builder.when_not(is_execution_shard.clone()).assert_all_eq(
            public_values.prev_deferred_proofs_digest,
            public_values.deferred_proofs_digest,
        );

        // Assert that `prev_deferred_proofs_digest != 0` implies `deferred_proofs_digest` must
        // remain equal to the `prev_deferred_proofs_digest`. Each limb acts as a selector, so
        // any single non-zero limb activates the constraint.
        for limb in public_values.prev_deferred_proofs_digest {
            builder.when(limb).assert_all_eq(
                public_values.prev_deferred_proofs_digest,
                public_values.deferred_proofs_digest,
            );
        }

        // If `prev_commit_deferred_syscall` is true, `deferred_proofs_digest` doesn't change.
        builder.when(public_values.prev_commit_deferred_syscall).assert_all_eq(
            public_values.prev_deferred_proofs_digest,
            public_values.deferred_proofs_digest,
        );
    }
1329
    /// Anchors the global cumulative sum on the local `GlobalAccumulation` interaction.
    ///
    /// Sends the starting point `(0, SepticDigest::zero())` and receives
    /// `(global_count, global_cumulative_sum)`. Other `GlobalAccumulation` interactions in the
    /// shard must link the two for the local bus to balance (presumably the global interaction
    /// chip's accumulation; confirm).
    #[allow(clippy::type_complexity)]
    fn eval_global_sum<AB: SP1AirBuilder>(
        public_values: &PublicValues<
            [AB::PublicVar; 4],
            [AB::PublicVar; 3],
            [AB::PublicVar; 4],
            AB::PublicVar,
        >,
        builder: &mut AB,
    ) {
        // Starting point of the accumulation: count 0 and the zero digest's coordinates.
        let initial_sum = SepticDigest::<AB::F>::zero().0;
        builder.send(
            AirInteraction::new(
                once(AB::Expr::zero())
                    .chain(initial_sum.x.0.into_iter().map(Into::into))
                    .chain(initial_sum.y.0.into_iter().map(Into::into))
                    .collect(),
                AB::Expr::one(),
                InteractionKind::GlobalAccumulation,
            ),
            InteractionScope::Local,
        );
        // End point of the accumulation: the claimed count and cumulative sum.
        builder.receive(
            AirInteraction::new(
                once(public_values.global_count.into())
                    .chain(public_values.global_cumulative_sum.0.x.0.map(Into::into))
                    .chain(public_values.global_cumulative_sum.0.y.0.map(Into::into))
                    .collect(),
                AB::Expr::one(),
                InteractionKind::GlobalAccumulation,
            ),
            InteractionScope::Local,
        );
    }
1364
1365    #[allow(clippy::type_complexity)]
1366    fn eval_global_memory_init<AB: SP1AirBuilder>(
1367        public_values: &PublicValues<
1368            [AB::PublicVar; 4],
1369            [AB::PublicVar; 3],
1370            [AB::PublicVar; 4],
1371            AB::PublicVar,
1372        >,
1373        builder: &mut AB,
1374    ) {
1375        // Check the addresses are of valid u16 limbs.
1376        for i in 0..3 {
1377            builder.send_byte(
1378                AB::Expr::from_canonical_u32(ByteOpcode::Range as u32),
1379                public_values.previous_init_addr[i].into(),
1380                AB::Expr::from_canonical_u32(16),
1381                AB::Expr::zero(),
1382                AB::Expr::one(),
1383            );
1384            builder.send_byte(
1385                AB::Expr::from_canonical_u32(ByteOpcode::Range as u32),
1386                public_values.last_init_addr[i].into(),
1387                AB::Expr::from_canonical_u32(16),
1388                AB::Expr::zero(),
1389                AB::Expr::one(),
1390            );
1391        }
1392
1393        builder.send(
1394            AirInteraction::new(
1395                once(AB::Expr::zero())
1396                    .chain(public_values.previous_init_addr.into_iter().map(Into::into))
1397                    .chain(once(AB::Expr::one()))
1398                    .collect(),
1399                AB::Expr::one(),
1400                InteractionKind::MemoryGlobalInitControl,
1401            ),
1402            InteractionScope::Local,
1403        );
1404        builder.receive(
1405            AirInteraction::new(
1406                once(public_values.global_init_count.into())
1407                    .chain(public_values.last_init_addr.into_iter().map(Into::into))
1408                    .chain(once(AB::Expr::one()))
1409                    .collect(),
1410                AB::Expr::one(),
1411                InteractionKind::MemoryGlobalInitControl,
1412            ),
1413            InteractionScope::Local,
1414        );
1415    }
1416
1417    #[allow(clippy::type_complexity)]
1418    fn eval_global_memory_finalize<AB: SP1AirBuilder>(
1419        public_values: &PublicValues<
1420            [AB::PublicVar; 4],
1421            [AB::PublicVar; 3],
1422            [AB::PublicVar; 4],
1423            AB::PublicVar,
1424        >,
1425        builder: &mut AB,
1426    ) {
1427        // Check the addresses are of valid u16 limbs.
1428        for i in 0..3 {
1429            builder.send_byte(
1430                AB::Expr::from_canonical_u32(ByteOpcode::Range as u32),
1431                public_values.previous_finalize_addr[i].into(),
1432                AB::Expr::from_canonical_u32(16),
1433                AB::Expr::zero(),
1434                AB::Expr::one(),
1435            );
1436            builder.send_byte(
1437                AB::Expr::from_canonical_u32(ByteOpcode::Range as u32),
1438                public_values.last_finalize_addr[i].into(),
1439                AB::Expr::from_canonical_u32(16),
1440                AB::Expr::zero(),
1441                AB::Expr::one(),
1442            );
1443        }
1444
1445        builder.send(
1446            AirInteraction::new(
1447                once(AB::Expr::zero())
1448                    .chain(public_values.previous_finalize_addr.into_iter().map(Into::into))
1449                    .chain(once(AB::Expr::one()))
1450                    .collect(),
1451                AB::Expr::one(),
1452                InteractionKind::MemoryGlobalFinalizeControl,
1453            ),
1454            InteractionScope::Local,
1455        );
1456        builder.receive(
1457            AirInteraction::new(
1458                once(public_values.global_finalize_count.into())
1459                    .chain(public_values.last_finalize_addr.into_iter().map(Into::into))
1460                    .chain(once(AB::Expr::one()))
1461                    .collect(),
1462                AB::Expr::one(),
1463                InteractionKind::MemoryGlobalFinalizeControl,
1464            ),
1465            InteractionScope::Local,
1466        );
1467    }
1468
1469    #[allow(clippy::type_complexity)]
1470    fn eval_global_page_prot_init<AB: SP1AirBuilder>(
1471        public_values: &PublicValues<
1472            [AB::PublicVar; 4],
1473            [AB::PublicVar; 3],
1474            [AB::PublicVar; 4],
1475            AB::PublicVar,
1476        >,
1477        builder: &mut AB,
1478    ) {
1479        builder.assert_bool(public_values.is_untrusted_programs_enabled.into());
1480        builder.send(
1481            AirInteraction::new(
1482                once(AB::Expr::zero())
1483                    .chain(public_values.previous_init_page_idx.into_iter().map(Into::into))
1484                    .chain(once(AB::Expr::one()))
1485                    .collect(),
1486                public_values.is_untrusted_programs_enabled.into(),
1487                InteractionKind::PageProtGlobalInitControl,
1488            ),
1489            InteractionScope::Local,
1490        );
1491        builder.receive(
1492            AirInteraction::new(
1493                once(public_values.global_page_prot_init_count.into())
1494                    .chain(public_values.last_init_page_idx.into_iter().map(Into::into))
1495                    .chain(once(AB::Expr::one()))
1496                    .collect(),
1497                public_values.is_untrusted_programs_enabled.into(),
1498                InteractionKind::PageProtGlobalInitControl,
1499            ),
1500            InteractionScope::Local,
1501        );
1502    }
1503
1504    #[allow(clippy::type_complexity)]
1505    fn eval_global_page_prot_finalize<AB: SP1AirBuilder>(
1506        public_values: &PublicValues<
1507            [AB::PublicVar; 4],
1508            [AB::PublicVar; 3],
1509            [AB::PublicVar; 4],
1510            AB::PublicVar,
1511        >,
1512        builder: &mut AB,
1513    ) {
1514        builder.assert_bool(public_values.is_untrusted_programs_enabled.into());
1515        builder.send(
1516            AirInteraction::new(
1517                once(AB::Expr::zero())
1518                    .chain(public_values.previous_finalize_page_idx.into_iter().map(Into::into))
1519                    .chain(once(AB::Expr::one()))
1520                    .collect(),
1521                public_values.is_untrusted_programs_enabled.into(),
1522                InteractionKind::PageProtGlobalFinalizeControl,
1523            ),
1524            InteractionScope::Local,
1525        );
1526        builder.receive(
1527            AirInteraction::new(
1528                once(public_values.global_page_prot_finalize_count.into())
1529                    .chain(public_values.last_finalize_page_idx.into_iter().map(Into::into))
1530                    .chain(once(AB::Expr::one()))
1531                    .collect(),
1532                public_values.is_untrusted_programs_enabled.into(),
1533                InteractionKind::PageProtGlobalFinalizeControl,
1534            ),
1535            InteractionScope::Local,
1536        );
1537    }
1538
1539    #[cfg(feature = "mprotect")]
1540    #[allow(clippy::type_complexity)]
1541    fn eval_trap_handler<AB: SP1AirBuilder>(
1542        public_values: &PublicValues<
1543            [AB::PublicVar; 4],
1544            [AB::PublicVar; 3],
1545            [AB::PublicVar; 4],
1546            AB::PublicVar,
1547        >,
1548        builder: &mut AB,
1549    ) {
1550        // `is_untrusted_programs_enabled` must be boolean.
1551        builder.assert_bool(public_values.is_untrusted_programs_enabled);
1552        // `enable_trap_handler` must be boolean.
1553        builder.assert_bool(public_values.enable_trap_handler);
1554
1555        // If untrusted programs are not enabled, there are no trap handlers.
1556        builder
1557            .when_not(public_values.is_untrusted_programs_enabled)
1558            .assert_zero(public_values.enable_trap_handler);
1559
1560        // The `trap_context` is with 16-bit limbs.
1561        // If there are no trap handlers, `trap_context` is all zero.
1562        for addr_idx in 0..3 {
1563            builder
1564                .when_not(public_values.enable_trap_handler)
1565                .assert_all_zero(public_values.trap_context[addr_idx]);
1566            for idx in 0..3 {
1567                builder.send_byte(
1568                    AB::Expr::from_canonical_u32(ByteOpcode::Range as u32),
1569                    public_values.trap_context[addr_idx][idx].into(),
1570                    AB::Expr::from_canonical_u32(16),
1571                    AB::Expr::zero(),
1572                    AB::Expr::one(),
1573                );
1574            }
1575        }
1576
1577        // The `untrusted_memory` is with 16-bit limbs.
1578        // If untrusted programs are not enabled, `untrusted_memory` is all zero.
1579        for addr_idx in 0..2 {
1580            builder
1581                .when_not(public_values.is_untrusted_programs_enabled)
1582                .assert_all_zero(public_values.untrusted_memory[addr_idx]);
1583            for idx in 0..3 {
1584                builder.send_byte(
1585                    AB::Expr::from_canonical_u32(ByteOpcode::Range as u32),
1586                    public_values.untrusted_memory[addr_idx][idx].into(),
1587                    AB::Expr::from_canonical_u32(16),
1588                    AB::Expr::zero(),
1589                    AB::Expr::one(),
1590                );
1591            }
1592        }
1593    }
1594
1595    /// Finalize the public values.
1596    pub fn finalize_public_values<F: PrimeField32>(&mut self, is_execution_shard: bool) {
1597        let state = &mut self.public_values;
1598        state.is_execution_shard = is_execution_shard as u32;
1599
1600        let initial_timestamp_high = (state.initial_timestamp >> 24) as u32;
1601        let initial_timestamp_low = (state.initial_timestamp & 0xFFFFFF) as u32;
1602        let last_timestamp_high = (state.last_timestamp >> 24) as u32;
1603        let last_timestamp_low = (state.last_timestamp & 0xFFFFFF) as u32;
1604
1605        state.initial_timestamp_inv = if state.initial_timestamp == 1 {
1606            0
1607        } else {
1608            F::from_canonical_u32(initial_timestamp_high + initial_timestamp_low - 1)
1609                .inverse()
1610                .as_canonical_u32()
1611        };
1612
1613        state.last_timestamp_inv =
1614            F::from_canonical_u32(last_timestamp_high + last_timestamp_low - 1)
1615                .inverse()
1616                .as_canonical_u32();
1617
1618        if initial_timestamp_high == last_timestamp_high {
1619            state.is_timestamp_high_eq = 1;
1620        } else {
1621            state.is_timestamp_high_eq = 0;
1622            state.inv_timestamp_high = (F::from_canonical_u32(last_timestamp_high)
1623                - F::from_canonical_u32(initial_timestamp_high))
1624            .inverse()
1625            .as_canonical_u32();
1626        }
1627
1628        if initial_timestamp_low == last_timestamp_low {
1629            state.is_timestamp_low_eq = 1;
1630        } else {
1631            state.is_timestamp_low_eq = 0;
1632            state.inv_timestamp_low = (F::from_canonical_u32(last_timestamp_low)
1633                - F::from_canonical_u32(initial_timestamp_low))
1634            .inverse()
1635            .as_canonical_u32();
1636        }
1637        state.is_first_execution_shard = (state.initial_timestamp == 1) as u32;
1638    }
1639}