// sp1_core_executor/vm.rs

1#![allow(unknown_lints)]
2#![allow(clippy::manual_checked_ops)]
3
4use crate::{
5    events::{MemoryAccessPosition, MemoryReadRecord, MemoryRecord, MemoryWriteRecord},
6    vm::{
7        results::{
8            AluResult, BranchResult, CycleResult, EcallResult, JumpResult, LoadResult,
9            MaybeImmediate, StoreResult, UTypeResult,
10        },
11        syscall::{sp1_ecall_handler, SyscallRuntime},
12    },
13    ExecutionError, Instruction, Opcode, Program, Register, RetainedEventsPreset, SP1CoreOpts,
14    SyscallCode, CLK_INC as CLK_INC_32, HALT_PC, PC_INC as PC_INC_32,
15};
16use sp1_hypercube::air::{PROOF_NONCE_NUM_WORDS, PV_DIGEST_NUM_WORDS};
17use sp1_jit::{MemReads, MinimalTrace};
18use std::{mem::MaybeUninit, num::Wrapping, ptr::addr_of_mut, sync::Arc};
19
pub(crate) mod gas;
pub(crate) mod memory;
pub(crate) mod results;
pub(crate) mod shapes;
pub(crate) mod syscall;

/// Per-cycle clock increment, widened to `u64` for 64-bit clock arithmetic.
const CLK_INC: u64 = CLK_INC_32 as u64;
/// Per-instruction program-counter increment, widened to `u64`.
const PC_INC: u64 = PC_INC_32 as u64;
28
/// A RISC-V VM that uses a [`MinimalTrace`] to oracle memory access.
pub struct CoreVM<'a> {
    /// The shadow register file; each entry tracks the register's current value and the
    /// timestamp of its last access.
    registers: [MemoryRecord; 32],
    /// The current clock of the VM.
    clk: u64,
    /// The global clock of the VM.
    global_clk: u64,
    /// The current program counter of the VM.
    pc: u64,
    /// The current exit code of the VM.
    exit_code: u32,
    /// The memory reads cursor.
    pub mem_reads: MemReads<'a>,
    /// The next program counter that will be set in [`CoreVM::advance`].
    next_pc: u64,
    /// The next clock that will be set in [`CoreVM::advance`].
    next_clk: u64,
    /// The program that is being executed.
    pub program: Arc<Program>,
    /// The syscalls that are not marked as external, ie. they stay in the same shard.
    pub(crate) retained_syscall_codes: Vec<SyscallCode>,
    /// The options to configure the VM, mostly for syscall / shard handling.
    pub opts: SP1CoreOpts,
    /// The end clk of the trace chunk.
    pub clk_end: u64,
    /// The public value digest.
    pub public_value_digest: [u32; PV_DIGEST_NUM_WORDS],
    /// The nonce associated with the proof.
    pub proof_nonce: [u32; PROOF_NONCE_NUM_WORDS],
}
59
impl<'a> CoreVM<'a> {
    /// Create a [`CoreVM`] from a [`MinimalTrace`] and a [`Program`].
    pub fn new<T: MinimalTrace>(
        trace: &'a T,
        program: Arc<Program>,
        opts: SP1CoreOpts,
        proof_nonce: [u32; PROOF_NONCE_NUM_WORDS],
    ) -> Self {
        let start_clk = trace.clk_start();

        // Seed the register file from the trace's starting registers, timestamped one tick
        // before the chunk starts.
        // NOTE(review): assumes `start_clk >= 1`, otherwise the `start_clk - 1` below
        // underflows — confirm against the trace invariants.
        //
        // SAFETY: We're mapping a [T; 32] -> [T; 32] infallibly.
        let registers = unsafe {
            trace
                .start_registers()
                .into_iter()
                .map(|v| MemoryRecord { timestamp: start_clk - 1, value: v })
                .collect::<Vec<_>>()
                .try_into()
                .unwrap_unchecked()
        };
        let start_pc = trace.pc_start();

        // Collect the syscall codes from every retained-events preset; these syscalls are
        // handled inside the current shard instead of being sent out.
        let retained_syscall_codes = opts
            .retained_events_presets
            .iter()
            .flat_map(RetainedEventsPreset::syscall_codes)
            .copied()
            .collect();

        tracing::trace!("start_clk: {}", start_clk);
        tracing::trace!("start_pc: {}", start_pc);
        tracing::trace!("trace.clk_end(): {}", trace.clk_end());
        tracing::trace!("trace.num_mem_reads(): {}", trace.num_mem_reads());
        tracing::trace!("trace.start_registers(): {:?}", trace.start_registers());

        // A chunk starting at the very first clock must start at the program entry point.
        if trace.clk_start() == 1 {
            assert_eq!(trace.pc_start(), program.pc_start_abs);
        }

        Self {
            registers,
            global_clk: 0,
            clk: start_clk,
            pc: start_pc,
            program,
            mem_reads: trace.mem_reads(),
            next_pc: start_pc.wrapping_add(PC_INC),
            next_clk: start_clk.wrapping_add(CLK_INC),
            exit_code: 0,
            retained_syscall_codes,
            opts,
            clk_end: trace.clk_end(),
            public_value_digest: [0; PV_DIGEST_NUM_WORDS],
            proof_nonce,
        }
    }

    /// Fetch the next instruction from the program.
    ///
    /// Returns `None` if the current pc does not map to a program instruction.
    #[inline]
    pub fn fetch(&mut self) -> Option<&Instruction> {
        // todo: mprotect / kernel mode logic.
        self.program.fetch(self.pc)
    }

    #[inline]
    /// Increment the state of the VM by one cycle.
    /// Calling this method will update the pc and the clk to the next cycle.
    pub fn advance(&mut self) -> CycleResult {
        // Commit the pending pc/clk computed during the previous cycle.
        self.clk = self.next_clk;
        self.pc = self.next_pc;

        // Reset the next_clk and next_pc to the next cycle.
        self.next_clk = self.clk.wrapping_add(CLK_INC);
        self.next_pc = self.pc.wrapping_add(PC_INC);
        self.global_clk = self.global_clk.wrapping_add(1);

        // Check if the program has halted.
        if self.pc == HALT_PC {
            return CycleResult::Done(true);
        }

        // Check if the shard limit has been reached.
        if self.is_trace_end() {
            return CycleResult::TraceEnd;
        }

        // Return that the program is still running.
        CycleResult::Done(false)
    }

    /// Execute a load instruction.
    ///
    /// # Errors
    ///
    /// Returns [`ExecutionError::InvalidMemoryAccess`] when the effective address is not
    /// aligned to the access width (2 for LH/LHU, 4 for LW/LWU, 8 for LD).
    #[inline]
    pub fn execute_load(
        &mut self,
        instruction: &Instruction,
    ) -> Result<LoadResult, ExecutionError> {
        let (rd, rs1, imm) = instruction.i_type();

        let rr_record = self.rr(rs1, MemoryAccessPosition::B);
        let b = rr_record.value;

        // Compute the address.
        let addr = b.wrapping_add(imm);
        let mr_record = self.mr(addr);
        let word = mr_record.value;

        // Extract the addressed sub-word from the 64-bit memory word, then sign- or
        // zero-extend it according to the opcode.
        let a = match instruction.opcode {
            Opcode::LB => ((word >> ((addr % 8) * 8)) & 0xFF) as i8 as i64 as u64,
            Opcode::LH => {
                if !addr.is_multiple_of(2) {
                    return Err(ExecutionError::InvalidMemoryAccess(Opcode::LH, addr));
                }

                ((word >> (((addr / 2) % 4) * 16)) & 0xFFFF) as i16 as i64 as u64
            }
            Opcode::LW => {
                if !addr.is_multiple_of(4) {
                    return Err(ExecutionError::InvalidMemoryAccess(Opcode::LW, addr));
                }

                ((word >> (((addr / 4) % 2) * 32)) & 0xFFFFFFFF) as i32 as u64
            }
            Opcode::LBU => ((word >> ((addr % 8) * 8)) & 0xFF) as u8 as u64,
            Opcode::LHU => {
                if !addr.is_multiple_of(2) {
                    return Err(ExecutionError::InvalidMemoryAccess(Opcode::LHU, addr));
                }

                ((word >> (((addr / 2) % 4) * 16)) & 0xFFFF) as u16 as u64
            }
            // RISCV-64
            Opcode::LWU => {
                if !addr.is_multiple_of(4) {
                    return Err(ExecutionError::InvalidMemoryAccess(Opcode::LWU, addr));
                }

                (word >> (((addr / 4) % 2) * 32)) & 0xFFFFFFFF
            }
            Opcode::LD => {
                if !addr.is_multiple_of(8) {
                    return Err(ExecutionError::InvalidMemoryAccess(Opcode::LD, addr));
                }

                word
            }
            _ => unreachable!("Invalid opcode for `execute_load`: {:?}", instruction.opcode),
        };

        let rw_record = self.rw(rd, a);

        Ok(LoadResult { a, b, c: imm, addr, rs1, rd, rr_record, rw_record, mr_record })
    }

    /// Execute a store instruction.
    ///
    /// # Errors
    ///
    /// Returns [`ExecutionError::InvalidMemoryAccess`] when the effective address is not
    /// aligned to the access width (2 for SH, 4 for SW, 8 for SD).
    #[inline]
    pub fn execute_store(
        &mut self,
        instruction: &Instruction,
    ) -> Result<StoreResult, ExecutionError> {
        let (rs1, rs2, imm) = instruction.s_type();

        let c = imm;
        let rs2_record = self.rr(rs2, MemoryAccessPosition::B);
        let rs1_record = self.rr(rs1, MemoryAccessPosition::A);

        let b = rs2_record.value;
        let a = rs1_record.value;
        let addr = b.wrapping_add(c);
        // A store is modeled as read-modify-write of the containing 64-bit word.
        let mr_record = self.mr(addr);
        let word = mr_record.value;

        // Splice the stored sub-word into the existing memory word at the addressed lane.
        let memory_store_value = match instruction.opcode {
            Opcode::SB => {
                let shift = (addr % 8) * 8;
                ((a & 0xFF) << shift) | (word & !(0xFF << shift))
            }
            Opcode::SH => {
                if !addr.is_multiple_of(2) {
                    return Err(ExecutionError::InvalidMemoryAccess(Opcode::SH, addr));
                }
                let shift = ((addr / 2) % 4) * 16;
                ((a & 0xFFFF) << shift) | (word & !(0xFFFF << shift))
            }
            Opcode::SW => {
                if !addr.is_multiple_of(4) {
                    return Err(ExecutionError::InvalidMemoryAccess(Opcode::SW, addr));
                }
                let shift = ((addr / 4) % 2) * 32;
                ((a & 0xFFFFFFFF) << shift) | (word & !(0xFFFFFFFF << shift))
            }
            // RISCV-64
            Opcode::SD => {
                if !addr.is_multiple_of(8) {
                    return Err(ExecutionError::InvalidMemoryAccess(Opcode::SD, addr));
                }
                a
            }
            _ => unreachable!(),
        };

        let mw_record = self.mw(mr_record, memory_store_value);

        Ok(StoreResult { a, b, c, addr, rs1, rs1_record, rs2, rs2_record, mw_record })
    }

    /// Execute an ALU instruction.
    ///
    /// The result is built in place via `MaybeUninit` because the `rs1`/`rs2` fields are
    /// initialized differently depending on which operands are immediates.
    #[inline]
    #[allow(clippy::too_many_lines)]
    pub fn execute_alu(&mut self, instruction: &Instruction) -> AluResult {
        let mut result = MaybeUninit::<AluResult>::uninit();
        let result_ptr = result.as_mut_ptr();

        // Decode the operands: register-register, register-immediate, or
        // immediate-immediate, recording the operand provenance into the result.
        let (rd, b, c) = if !instruction.imm_c {
            let (rd, rs1, rs2) = instruction.r_type();
            let c = self.rr(rs2, MemoryAccessPosition::C);
            let b = self.rr(rs1, MemoryAccessPosition::B);

            // SAFETY: We're writing to a valid pointer as we just created the pointer from the
            // `result`.
            unsafe { addr_of_mut!((*result_ptr).rs1).write(MaybeImmediate::Register(rs1, b)) };
            unsafe { addr_of_mut!((*result_ptr).rs2).write(MaybeImmediate::Register(rs2, c)) };

            (rd, b.value, c.value)
        } else if !instruction.imm_b && instruction.imm_c {
            let (rd, rs1, imm) = instruction.i_type();
            let (rd, b, c) = (rd, self.rr(rs1, MemoryAccessPosition::B), imm);

            // SAFETY: We're writing to a valid pointer as we just created the pointer from the
            // `result`.
            unsafe { addr_of_mut!((*result_ptr).rs1).write(MaybeImmediate::Register(rs1, b)) };
            unsafe { addr_of_mut!((*result_ptr).rs2).write(MaybeImmediate::Immediate(c)) };

            (rd, b.value, c)
        } else {
            debug_assert!(instruction.imm_b && instruction.imm_c);
            let (rd, b, c) =
                (Register::from_u8(instruction.op_a), instruction.op_b, instruction.op_c);

            // SAFETY: We're writing to a valid pointer as we just created the pointer from the
            // `result`.
            unsafe { addr_of_mut!((*result_ptr).rs1).write(MaybeImmediate::Immediate(b)) };
            unsafe { addr_of_mut!((*result_ptr).rs2).write(MaybeImmediate::Immediate(c)) };

            (rd, b, c)
        };

        // RV64 ALU semantics; division/remainder by zero follow the RISC-V convention of
        // returning all-ones / the dividend instead of trapping.
        let a = match instruction.opcode {
            Opcode::ADD | Opcode::ADDI => (Wrapping(b) + Wrapping(c)).0,
            Opcode::SUB => (Wrapping(b) - Wrapping(c)).0,
            Opcode::XOR => b ^ c,
            Opcode::OR => b | c,
            Opcode::AND => b & c,
            Opcode::SLL => b << (c & 0x3f),
            Opcode::SRL => b >> (c & 0x3f),
            Opcode::SRA => ((b as i64) >> (c & 0x3f)) as u64,
            Opcode::SLT => {
                if (b as i64) < (c as i64) {
                    1
                } else {
                    0
                }
            }
            Opcode::SLTU => {
                if b < c {
                    1
                } else {
                    0
                }
            }
            Opcode::MUL => (Wrapping(b as i64) * Wrapping(c as i64)).0 as u64,
            Opcode::MULH => (((b as i64) as i128).wrapping_mul((c as i64) as i128) >> 64) as u64,
            Opcode::MULHU => ((b as u128 * c as u128) >> 64) as u64,
            Opcode::MULHSU => ((((b as i64) as i128) * (c as i128)) >> 64) as u64,
            Opcode::DIV => {
                if c == 0 {
                    u64::MAX
                } else {
                    (b as i64).wrapping_div(c as i64) as u64
                }
            }
            Opcode::DIVU => {
                if c == 0 {
                    u64::MAX
                } else {
                    b / c
                }
            }
            Opcode::REM => {
                if c == 0 {
                    b
                } else {
                    (b as i64).wrapping_rem(c as i64) as u64
                }
            }
            Opcode::REMU => {
                if c == 0 {
                    b
                } else {
                    b % c
                }
            }
            // RISCV-64 word operations (operate on the low 32 bits, sign-extend the result).
            Opcode::ADDW => (Wrapping(b as i32) + Wrapping(c as i32)).0 as i64 as u64,
            Opcode::SUBW => (Wrapping(b as i32) - Wrapping(c as i32)).0 as i64 as u64,
            Opcode::MULW => (Wrapping(b as i32) * Wrapping(c as i32)).0 as i64 as u64,
            Opcode::DIVW => {
                if c as i32 == 0 {
                    u64::MAX
                } else {
                    (b as i32).wrapping_div(c as i32) as i64 as u64
                }
            }
            Opcode::DIVUW => {
                if c as i32 == 0 {
                    u64::MAX
                } else {
                    ((b as u32 / c as u32) as i32) as i64 as u64
                }
            }
            Opcode::REMW => {
                if c as i32 == 0 {
                    (b as i32) as u64
                } else {
                    (b as i32).wrapping_rem(c as i32) as i64 as u64
                }
            }
            Opcode::REMUW => {
                if c as u32 == 0 {
                    (b as i32) as u64
                } else {
                    (((b as u32) % (c as u32)) as i32) as i64 as u64
                }
            }
            // RISCV-64 bit operations
            Opcode::SLLW => (((b as i64) << (c & 0x1f)) as i32) as i64 as u64,
            Opcode::SRLW => (((b as u32) >> ((c & 0x1f) as u32)) as i32) as u64,
            Opcode::SRAW => {
                (b as i32).wrapping_shr(((c as i64 & 0x1f) as i32) as u32) as i64 as u64
            }
            _ => unreachable!(),
        };

        let rw_record = self.rw(rd, a);

        // SAFETY: We're writing to a valid pointer as we just created the pointer from the
        // `result`.
        unsafe { addr_of_mut!((*result_ptr).a).write(a) };
        unsafe { addr_of_mut!((*result_ptr).b).write(b) };
        unsafe { addr_of_mut!((*result_ptr).c).write(c) };
        unsafe { addr_of_mut!((*result_ptr).rd).write(rd) };
        unsafe { addr_of_mut!((*result_ptr).rw_record).write(rw_record) };

        // SAFETY: All fields have been initialized by this point.
        unsafe { result.assume_init() }
    }

    /// Execute a jump instruction (JAL or JALR), writing the return address to `rd` and
    /// redirecting `next_pc` to the jump target.
    pub fn execute_jump(&mut self, instruction: &Instruction) -> JumpResult {
        match instruction.opcode {
            Opcode::JAL => {
                let (rd, imm) = instruction.j_type();
                let imm_se = sign_extend_imm(imm, 21);
                // The link value is the address of the following instruction.
                let a = self.pc.wrapping_add(4);
                let rd_record = self.rw(rd, a);

                let next_pc = ((self.pc as i64).wrapping_add(imm_se)) as u64;
                let b = imm_se as u64;
                let c = 0;

                self.next_pc = next_pc;

                JumpResult { a, b, c, rd, rd_record, rs1: MaybeImmediate::Immediate(b) }
            }
            Opcode::JALR => {
                let (rd, rs1, c) = instruction.i_type();
                let imm_se = sign_extend_imm(c, 12);
                let b_record = self.rr(rs1, MemoryAccessPosition::B);
                let a = self.pc.wrapping_add(4);

                // Calculate next PC: (rs1 + imm) & ~1
                let next_pc = ((b_record.value as i64).wrapping_add(imm_se) as u64) & !1_u64;
                let rd_record = self.rw(rd, a);

                self.next_pc = next_pc;

                JumpResult {
                    a,
                    b: b_record.value,
                    c,
                    rd,
                    rd_record,
                    rs1: MaybeImmediate::Register(rs1, b_record),
                }
            }
            _ => unreachable!("Invalid opcode for `execute_jump`: {:?}", instruction.opcode),
        }
    }

    /// Execute a branch instruction, redirecting `next_pc` to `pc + imm` when taken.
    pub fn execute_branch(&mut self, instruction: &Instruction) -> BranchResult {
        let (rs1, rs2, imm) = instruction.b_type();

        let c = imm;
        let b_record = self.rr(rs2, MemoryAccessPosition::B);
        let a_record = self.rr(rs1, MemoryAccessPosition::A);

        let a = a_record.value;
        let b = b_record.value;

        let branch = match instruction.opcode {
            Opcode::BEQ => a == b,
            Opcode::BNE => a != b,
            Opcode::BLT => (a as i64) < (b as i64),
            Opcode::BGE => (a as i64) >= (b as i64),
            Opcode::BLTU => a < b,
            Opcode::BGEU => a >= b,
            _ => {
                unreachable!()
            }
        };

        if branch {
            // A negative offset is encoded in two's complement, so the wrapping add also
            // handles backward branches.
            self.next_pc = self.pc.wrapping_add(c);
        }

        BranchResult { a, rs1, a_record, b, rs2, b_record, c }
    }

    /// Execute a U-type instruction (LUI or AUIPC).
    #[inline]
    pub fn execute_utype(&mut self, instruction: &Instruction) -> UTypeResult {
        let (rd, imm) = instruction.u_type();
        let (b, c) = (imm, imm);
        // AUIPC adds the immediate to the current pc; LUI uses the immediate directly.
        let a = if instruction.opcode == Opcode::AUIPC { self.pc.wrapping_add(imm) } else { imm };
        let a_record = self.rw(rd, a);

        UTypeResult { a, b, c, rd, rw_record: a_record }
    }

    #[inline]
    /// Execute an ecall instruction.
    ///
    /// # WARNING:
    ///
    /// It's up to the syscall handler to update the shape checker about sent/internal ecalls.
    pub fn execute_ecall<RT>(
        rt: &mut RT,
        instruction: &Instruction,
        code: SyscallCode,
    ) -> Result<EcallResult, ExecutionError>
    where
        RT: SyscallRuntime<'a>,
    {
        if !instruction.is_ecall_instruction() {
            unreachable!("Invalid opcode for `execute_ecall`: {:?}", instruction.opcode);
        }

        let core = rt.core_mut();

        // Syscall arguments live in x10/x11 per the RISC-V calling convention.
        let c_record = core.rr(Register::X11, MemoryAccessPosition::C);
        let b_record = core.rr(Register::X10, MemoryAccessPosition::B);
        let c = c_record.value;
        let b = b_record.value;

        // The only way unconstrained mode interacts with the parts of the program that are
        // proven is via hints, this means during tracing and splicing, we can just "skip"
        // the whole set of unconstrained cycles, and rely on the fact that the hints are
        // already a part of the minimal trace.
        let a = if code == SyscallCode::ENTER_UNCONSTRAINED {
            0
        } else {
            sp1_ecall_handler(rt, code, b, c).unwrap_or(code as u64)
        };

        // Bad borrow checker!
        let core = rt.core_mut();

        // Write the syscall result back into the x5 register.
        let a_record = core.rw(Register::X5, a);

        // Add 256 to the next clock to account for the ecall.
        core.set_next_clk(core.next_clk() + 256);

        Ok(EcallResult { a, a_record, b, b_record, c, c_record })
    }

    /// Peek to get the code from the x5 register.
    #[must_use]
    pub fn read_code(&self) -> SyscallCode {
        // We peek at register x5 to get the syscall id. The reason we don't `self.rr` this
        // register is that we write to it later.
        let t0 = Register::X5;

        // Peek at the register, we don't care about the read here.
        let syscall_id = self.registers[t0 as usize].value;

        // Convert the raw value to a SyscallCode.
        SyscallCode::from_u32(syscall_id as u32)
    }
}
560
561impl CoreVM<'_> {
562    /// Read the next required memory read from the trace.
563    #[inline]
564    fn mr(&mut self, addr: u64) -> MemoryReadRecord {
565        #[allow(clippy::manual_let_else)]
566        let record = match self.mem_reads.next() {
567            Some(next) => next,
568            None => {
569                unreachable!("memory reads unexpectdely exhausted at {addr}, clk {}", self.clk);
570            }
571        };
572
573        MemoryReadRecord {
574            value: record.value,
575            timestamp: self.timestamp(MemoryAccessPosition::Memory),
576            prev_timestamp: record.clk,
577            prev_page_prot_record: None,
578        }
579    }
580
581    #[inline]
582    pub(crate) fn mr_slice_unsafe(&mut self, len: usize) -> Vec<u64> {
583        let mem_reads = self.mem_reads();
584
585        mem_reads.take(len).map(|value| value.value).collect()
586    }
587
588    #[inline]
589    pub(crate) fn mr_slice(&mut self, _addr: u64, len: usize) -> Vec<MemoryReadRecord> {
590        let current_clk = self.clk();
591        let mem_reads = self.mem_reads();
592
593        let records: Vec<MemoryReadRecord> = mem_reads
594            .take(len)
595            .map(|value| MemoryReadRecord {
596                value: value.value,
597                timestamp: current_clk,
598                prev_timestamp: value.clk,
599                prev_page_prot_record: None,
600            })
601            .collect();
602
603        records
604    }
605
606    #[inline]
607    pub(crate) fn mw_slice(&mut self, _addr: u64, len: usize) -> Vec<MemoryWriteRecord> {
608        let mem_writes = self.mem_reads();
609
610        let raw_records: Vec<_> = mem_writes.take(len * 2).collect();
611        let records: Vec<MemoryWriteRecord> = raw_records
612            .chunks(2)
613            .map(|chunk| {
614                #[allow(clippy::manual_let_else)]
615                let (old, new) = match (chunk.first(), chunk.last()) {
616                    (Some(old), Some(new)) => (old, new),
617                    _ => unreachable!("Precompile memory write out of bounds"),
618                };
619
620                MemoryWriteRecord {
621                    prev_timestamp: old.clk,
622                    prev_value: old.value,
623                    timestamp: new.clk,
624                    value: new.value,
625                    prev_page_prot_record: None,
626                }
627            })
628            .collect();
629
630        records
631    }
632
633    #[inline]
634    fn mw(&mut self, read_record: MemoryReadRecord, value: u64) -> MemoryWriteRecord {
635        MemoryWriteRecord {
636            prev_timestamp: read_record.prev_timestamp,
637            prev_value: read_record.value,
638            timestamp: self.timestamp(MemoryAccessPosition::Memory),
639            value,
640            prev_page_prot_record: None,
641        }
642    }
643
644    /// Read a value from a register, updating the register entry and returning the record.
645    #[inline]
646    fn rr(&mut self, register: Register, position: MemoryAccessPosition) -> MemoryReadRecord {
647        let prev_record = self.registers[register as usize];
648        let new_record =
649            MemoryRecord { timestamp: self.timestamp(position), value: prev_record.value };
650
651        self.registers[register as usize] = new_record;
652
653        MemoryReadRecord {
654            value: new_record.value,
655            timestamp: new_record.timestamp,
656            prev_timestamp: prev_record.timestamp,
657            prev_page_prot_record: None,
658        }
659    }
660
661    /// Read a value from a register, updating the register entry and returning the record.
662    #[inline]
663    fn rr_precompile(&mut self, register: usize) -> MemoryReadRecord {
664        debug_assert!(register < 32, "out of bounds register: {register}");
665
666        let prev_record = self.registers[register];
667        let new_record = MemoryRecord { timestamp: self.clk(), value: prev_record.value };
668
669        self.registers[register] = new_record;
670
671        MemoryReadRecord {
672            value: new_record.value,
673            timestamp: new_record.timestamp,
674            prev_timestamp: prev_record.timestamp,
675            prev_page_prot_record: None,
676        }
677    }
678
679    /// Touch all the registers in the VM, bumping thier clock to `self.clk - 1`.
680    pub fn register_refresh(&mut self) -> [MemoryReadRecord; 32] {
681        fn bump_register(vm: &mut CoreVM, register: usize) -> MemoryReadRecord {
682            let prev_record = vm.registers[register];
683            let new_record = MemoryRecord { timestamp: vm.clk - 1, value: prev_record.value };
684
685            vm.registers[register] = new_record;
686
687            MemoryReadRecord {
688                value: new_record.value,
689                timestamp: new_record.timestamp,
690                prev_timestamp: prev_record.timestamp,
691                prev_page_prot_record: None,
692            }
693        }
694
695        tracing::trace!("register refresh to: {}", self.clk - 1);
696
697        let mut out = [MaybeUninit::uninit(); 32];
698        for (i, record) in out.iter_mut().enumerate() {
699            *record = MaybeUninit::new(bump_register(self, i));
700        }
701
702        // SAFETY: We're transmuting a [MaybeUninit<MemoryReadRecord>; 32] to a [MemoryReadRecord;
703        // 32], which we just initialized.
704        //
705        // These types are guaranteed to have the same representation.
706        unsafe { std::mem::transmute(out) }
707    }
708
709    /// Write a value to a register, updating the register entry and returning the record.
710    #[inline]
711    fn rw(&mut self, register: Register, value: u64) -> MemoryWriteRecord {
712        let value = if register == Register::X0 { 0 } else { value };
713
714        let prev_record = self.registers[register as usize];
715        let new_record = MemoryRecord { timestamp: self.timestamp(MemoryAccessPosition::A), value };
716
717        self.registers[register as usize] = new_record;
718
719        // if SHAPE_CHECKING {
720        //     self.shape_checker.handle_mem_event(register as u64, prev_record.timestamp);
721        // }
722
723        // if REPORT_GENERATING {
724        //     self.gas_calculator.handle_mem_event(register as u64, prev_record.timestamp);
725        // }
726
727        MemoryWriteRecord {
728            value: new_record.value,
729            timestamp: new_record.timestamp,
730            prev_timestamp: prev_record.timestamp,
731            prev_value: prev_record.value,
732            prev_page_prot_record: None,
733        }
734    }
735}
736
impl CoreVM<'_> {
    /// Get the current timestamp for a given memory access position.
    #[inline]
    #[must_use]
    pub const fn timestamp(&self, position: MemoryAccessPosition) -> u64 {
        self.clk + position as u64
    }

    /// Check if the top 24 bits have changed, which imply a `state bump` event needs to be emitted.
    #[inline]
    #[must_use]
    pub const fn needs_bump_clk_high(&self) -> bool {
        // Nonzero XOR of the high parts means some bit above the low 24 differs between
        // the current and next clock. (In Rust, `^` binds tighter than `>`, so this is
        // `(xor) > 0`.)
        (self.next_clk() >> 24) ^ (self.clk() >> 24) > 0
    }

    /// Check if the state needs to be bumped, which implies a `state bump` event needs to be
    /// emitted.
    #[inline]
    #[must_use]
    pub const fn needs_state_bump(&self, instruction: &Instruction) -> bool {
        let next_pc = self.next_pc();
        // Worst-case clock advance for this cycle.
        // NOTE(review): the `+ 8` headroom presumably covers accesses stamped after this
        // check — confirm against the event emitter.
        let increment = self.next_clk() + 8 - self.clk();

        // Bump if the low 24 bits of the clock would overflow within this cycle...
        let bump1 = self.clk() % (1 << 24) + increment >= (1 << 24);
        // ...or if a fall-through pc crosses a 16-bit boundary on an instruction that does
        // not itself constrain the next pc.
        let bump2 = !instruction.is_with_correct_next_pc()
            && next_pc == self.pc().wrapping_add(4)
            && (next_pc >> 16) != (self.pc() >> 16);

        bump1 || bump2
    }
}
768
impl<'a> CoreVM<'a> {
    #[inline]
    #[must_use]
    /// Get the current clock, this clock is incremented by [`CLK_INC`] each cycle.
    pub const fn clk(&self) -> u64 {
        self.clk
    }

    #[inline]
    /// Set the current clock.
    pub fn set_clk(&mut self, new_clk: u64) {
        self.clk = new_clk;
    }

    #[inline]
    /// Set the next clock.
    pub fn set_next_clk(&mut self, clk: u64) {
        self.next_clk = clk;
    }

    #[inline]
    #[must_use]
    /// Get the global clock, this clock is incremented by 1 each cycle.
    pub fn global_clk(&self) -> u64 {
        self.global_clk
    }

    #[inline]
    #[must_use]
    /// Get the current program counter.
    pub const fn pc(&self) -> u64 {
        self.pc
    }

    #[inline]
    #[must_use]
    /// Get the next program counter that will be set in [`CoreVM::advance`].
    pub const fn next_pc(&self) -> u64 {
        self.next_pc
    }

    #[inline]
    /// Set the next program counter.
    pub fn set_next_pc(&mut self, pc: u64) {
        self.next_pc = pc;
    }

    #[inline]
    #[must_use]
    /// Get the exit code.
    pub fn exit_code(&self) -> u32 {
        self.exit_code
    }

    #[inline]
    /// Set the exit code.
    pub fn set_exit_code(&mut self, exit_code: u32) {
        self.exit_code = exit_code;
    }

    #[inline]
    /// Set the program counter.
    pub fn set_pc(&mut self, pc: u64) {
        self.pc = pc;
    }

    #[inline]
    /// Set the global clock.
    pub fn set_global_clk(&mut self, global_clk: u64) {
        self.global_clk = global_clk;
    }

    #[inline]
    #[must_use]
    /// Get the next clock that will be set in [`CoreVM::advance`].
    pub const fn next_clk(&self) -> u64 {
        self.next_clk
    }

    #[inline]
    #[must_use]
    /// Get the current registers (immutable).
    pub fn registers(&self) -> &[MemoryRecord; 32] {
        &self.registers
    }

    #[inline]
    #[must_use]
    /// Get the current registers (mutable).
    pub fn registers_mut(&mut self) -> &mut [MemoryRecord; 32] {
        &mut self.registers
    }

    #[inline]
    /// Get the memory reads iterator.
    pub fn mem_reads(&mut self) -> &mut MemReads<'a> {
        &mut self.mem_reads
    }

    /// Check if the syscall is retained.
    #[inline]
    #[must_use]
    pub fn is_retained_syscall(&self, syscall_code: SyscallCode) -> bool {
        self.retained_syscall_codes.contains(&syscall_code)
    }

    /// Check if the trace has ended.
    ///
    /// Note this is an exact equality check: the clock is expected to land exactly on
    /// `clk_end`. NOTE(review): confirm ecall clock jumps (`next_clk + 256`) cannot skip
    /// past `clk_end`.
    #[inline]
    #[must_use]
    pub const fn is_trace_end(&self) -> bool {
        self.clk_end == self.clk()
    }

    /// Check if the program has halted.
    #[inline]
    #[must_use]
    pub const fn is_done(&self) -> bool {
        self.pc() == HALT_PC
    }
}
889
/// Sign-extend the low `bits` bits of `value` into a full-width `i64`.
///
/// Bits of `value` above `bits` are discarded; bit `bits - 1` is replicated into all
/// higher positions of the result.
fn sign_extend_imm(value: u64, bits: u8) -> i64 {
    // Push the field to the top of the word, then arithmetic-shift it back down so the
    // sign bit fills the vacated high bits.
    let unused = 64 - u32::from(bits);
    (value as i64).wrapping_shl(unused).wrapping_shr(unused)
}