// sp1_core_executor — vm.rs

1use crate::{
2    events::{MemoryAccessPosition, MemoryReadRecord, MemoryRecord, MemoryWriteRecord},
3    vm::{
4        results::{
5            AluResult, BranchResult, CycleResult, EcallResult, JumpResult, LoadResult,
6            MaybeImmediate, StoreResult, UTypeResult,
7        },
8        syscall::{sp1_ecall_handler, SyscallRuntime},
9    },
10    ExecutionError, Instruction, Opcode, Program, Register, RetainedEventsPreset, SP1CoreOpts,
11    SyscallCode, CLK_INC as CLK_INC_32, HALT_PC, PC_INC as PC_INC_32,
12};
13use sp1_hypercube::air::{PROOF_NONCE_NUM_WORDS, PV_DIGEST_NUM_WORDS};
14use sp1_jit::{MemReads, MinimalTrace};
15use std::{mem::MaybeUninit, num::Wrapping, ptr::addr_of_mut, sync::Arc};
16
pub(crate) mod gas;
pub(crate) mod memory;
pub(crate) mod results;
pub(crate) mod shapes;
pub(crate) mod syscall;

/// Clock increment per cycle, widened to `u64` for the VM's 64-bit counters.
const CLK_INC: u64 = CLK_INC_32 as u64;
/// Program-counter increment per instruction, widened to `u64`.
const PC_INC: u64 = PC_INC_32 as u64;
25
/// A RISC-V VM that uses a [`MinimalTrace`] to oracle memory access.
pub struct CoreVM<'a> {
    /// The architectural registers, each tagged with the timestamp of its last access.
    registers: [MemoryRecord; 32],
    /// The current clock of the VM.
    clk: u64,
    /// The global clock of the VM.
    global_clk: u64,
    /// The current program counter of the VM.
    pc: u64,
    /// The current exit code of the VM.
    exit_code: u32,
    /// The memory reads cursor.
    pub mem_reads: MemReads<'a>,
    /// The next program counter that will be set in [`CoreVM::advance`].
    next_pc: u64,
    /// The next clock that will be set in [`CoreVM::advance`].
    next_clk: u64,
    /// The hint lengths that are read from within the vm.
    hint_lens: std::slice::Iter<'a, usize>,
    /// The program that is being executed.
    pub program: Arc<Program>,
    /// The syscalls that are not marked as external, ie. they stay in the same shard.
    pub(crate) retained_syscall_codes: Vec<SyscallCode>,
    /// The options to configure the VM, mostly for syscall / shard handling.
    pub opts: SP1CoreOpts,
    /// The end clk of the trace chunk.
    pub clk_end: u64,
    /// The public value digest.
    pub public_value_digest: [u32; PV_DIGEST_NUM_WORDS],
    /// The nonce associated with the proof.
    pub proof_nonce: [u32; PROOF_NONCE_NUM_WORDS],
}
58
59impl<'a> CoreVM<'a> {
60    /// Create a [`CoreVM`] from a [`MinimalTrace`] and a [`Program`].
61    pub fn new<T: MinimalTrace>(
62        trace: &'a T,
63        program: Arc<Program>,
64        opts: SP1CoreOpts,
65        proof_nonce: [u32; PROOF_NONCE_NUM_WORDS],
66    ) -> Self {
67        let start_clk = trace.clk_start();
68
69        // SAFETY: We're mapping a [T; 32] -> [T; 32] infallibly.
70        let registers = unsafe {
71            trace
72                .start_registers()
73                .into_iter()
74                .map(|v| MemoryRecord { timestamp: start_clk - 1, value: v })
75                .collect::<Vec<_>>()
76                .try_into()
77                .unwrap_unchecked()
78        };
79        let start_pc = trace.pc_start();
80
81        let retained_syscall_codes = opts
82            .retained_events_presets
83            .iter()
84            .flat_map(RetainedEventsPreset::syscall_codes)
85            .copied()
86            .collect();
87
88        tracing::trace!("start_clk: {}", start_clk);
89        tracing::trace!("start_pc: {}", start_pc);
90        tracing::trace!("trace.clk_end(): {}", trace.clk_end());
91        tracing::trace!("trace.num_mem_reads(): {}", trace.num_mem_reads());
92        tracing::trace!("trace.hint_lens(): {:?}", trace.hint_lens().len());
93        tracing::trace!("trace.start_registers(): {:?}", trace.start_registers());
94
95        if trace.clk_start() == 1 {
96            assert_eq!(trace.pc_start(), program.pc_start_abs);
97        }
98
99        Self {
100            registers,
101            global_clk: 0,
102            clk: start_clk,
103            pc: start_pc,
104            program,
105            mem_reads: trace.mem_reads(),
106            next_pc: start_pc.wrapping_add(PC_INC),
107            next_clk: start_clk.wrapping_add(CLK_INC),
108            hint_lens: trace.hint_lens().iter(),
109            exit_code: 0,
110            retained_syscall_codes,
111            opts,
112            clk_end: trace.clk_end(),
113            public_value_digest: [0; PV_DIGEST_NUM_WORDS],
114            proof_nonce,
115        }
116    }
117
118    /// Fetch the next instruction from the program.
119    #[inline]
120    pub fn fetch(&mut self) -> Option<&Instruction> {
121        // todo: mprotect / kernel mode logic.
122        self.program.fetch(self.pc)
123    }
124
125    #[inline]
126    /// Increment the state of the VM by one cycle.
127    /// Calling this method will update the pc and the clk to the next cycle.
128    pub fn advance(&mut self) -> CycleResult {
129        self.clk = self.next_clk;
130        self.pc = self.next_pc;
131
132        // Reset the next_clk and next_pc to the next cycle.
133        self.next_clk = self.clk.wrapping_add(CLK_INC);
134        self.next_pc = self.pc.wrapping_add(PC_INC);
135        self.global_clk = self.global_clk.wrapping_add(1);
136
137        // Check if the program has halted.
138        if self.pc == HALT_PC {
139            return CycleResult::Done(true);
140        }
141
142        // Check if the shard limit has been reached.
143        if self.is_trace_end() {
144            return CycleResult::TraceEnd;
145        }
146
147        // Return that the program is still running.
148        CycleResult::Done(false)
149    }
150
    /// Execute a load instruction.
    ///
    /// Reads the base register, fetches the enclosing 8-byte word from the memory oracle,
    /// extracts and zero- or sign-extends the addressed sub-word, and writes it to `rd`.
    ///
    /// # Errors
    ///
    /// Returns [`ExecutionError::InvalidMemoryAccess`] when the effective address is not
    /// aligned to the access width (2/4/8 bytes for half/word/double loads).
    #[inline]
    pub fn execute_load(
        &mut self,
        instruction: &Instruction,
    ) -> Result<LoadResult, ExecutionError> {
        let (rd, rs1, imm) = instruction.i_type();

        // Base register read at access position B.
        let rr_record = self.rr(rs1, MemoryAccessPosition::B);
        let b = rr_record.value;

        // Compute the address.
        let addr = b.wrapping_add(imm);
        // The oracle returns the full 8-byte word containing `addr`.
        let mr_record = self.mr(addr);
        let word = mr_record.value;

        // Select the addressed sub-word within the 8-byte word (the shift is the byte
        // offset within the word), then extend it according to the opcode.
        let a = match instruction.opcode {
            Opcode::LB => ((word >> ((addr % 8) * 8)) & 0xFF) as i8 as i64 as u64,
            Opcode::LH => {
                if !addr.is_multiple_of(2) {
                    return Err(ExecutionError::InvalidMemoryAccess(Opcode::LH, addr));
                }

                ((word >> (((addr / 2) % 4) * 16)) & 0xFFFF) as i16 as i64 as u64
            }
            Opcode::LW => {
                if !addr.is_multiple_of(4) {
                    return Err(ExecutionError::InvalidMemoryAccess(Opcode::LW, addr));
                }

                ((word >> (((addr / 4) % 2) * 32)) & 0xFFFFFFFF) as i32 as u64
            }
            Opcode::LBU => ((word >> ((addr % 8) * 8)) & 0xFF) as u8 as u64,
            Opcode::LHU => {
                if !addr.is_multiple_of(2) {
                    return Err(ExecutionError::InvalidMemoryAccess(Opcode::LHU, addr));
                }

                ((word >> (((addr / 2) % 4) * 16)) & 0xFFFF) as u16 as u64
            }
            // RISCV-64
            Opcode::LWU => {
                if !addr.is_multiple_of(4) {
                    return Err(ExecutionError::InvalidMemoryAccess(Opcode::LWU, addr));
                }

                (word >> (((addr / 4) % 2) * 32)) & 0xFFFFFFFF
            }
            Opcode::LD => {
                if !addr.is_multiple_of(8) {
                    return Err(ExecutionError::InvalidMemoryAccess(Opcode::LD, addr));
                }

                word
            }
            _ => unreachable!("Invalid opcode for `execute_load`: {:?}", instruction.opcode),
        };

        let rw_record = self.rw(rd, a);

        Ok(LoadResult { a, b, c: imm, addr, rs1, rd, rr_record, rw_record, mr_record })
    }
213
    /// Execute a store instruction.
    ///
    /// Reads both source registers, fetches the enclosing 8-byte word, splices the stored
    /// sub-word into it at the addressed offset, and writes the merged word back.
    ///
    /// # Errors
    ///
    /// Returns [`ExecutionError::InvalidMemoryAccess`] when the effective address is not
    /// aligned to the access width (2/4/8 bytes for half/word/double stores).
    #[inline]
    pub fn execute_store(
        &mut self,
        instruction: &Instruction,
    ) -> Result<StoreResult, ExecutionError> {
        let (rs1, rs2, imm) = instruction.s_type();

        let c = imm;
        // Register read order fixes the access-position timestamps: rs2 at B, rs1 at A.
        let rs2_record = self.rr(rs2, MemoryAccessPosition::B);
        let rs1_record = self.rr(rs1, MemoryAccessPosition::A);

        // In this VM's store convention, the address base comes from rs2 (`b`) and the
        // stored value from rs1 (`a`) — note this is what `s_type()` decoding yields here.
        let b = rs2_record.value;
        let a = rs1_record.value;
        let addr = b.wrapping_add(c);
        // Read the current 8-byte word so the untouched bytes can be preserved.
        let mr_record = self.mr(addr);
        let word = mr_record.value;

        // Merge the stored bytes into the word at the addressed offset.
        let memory_store_value = match instruction.opcode {
            Opcode::SB => {
                let shift = (addr % 8) * 8;
                ((a & 0xFF) << shift) | (word & !(0xFF << shift))
            }
            Opcode::SH => {
                if !addr.is_multiple_of(2) {
                    return Err(ExecutionError::InvalidMemoryAccess(Opcode::SH, addr));
                }
                let shift = ((addr / 2) % 4) * 16;
                ((a & 0xFFFF) << shift) | (word & !(0xFFFF << shift))
            }
            Opcode::SW => {
                if !addr.is_multiple_of(4) {
                    return Err(ExecutionError::InvalidMemoryAccess(Opcode::SW, addr));
                }
                let shift = ((addr / 4) % 2) * 32;
                ((a & 0xFFFFFFFF) << shift) | (word & !(0xFFFFFFFF << shift))
            }
            // RISCV-64
            Opcode::SD => {
                if !addr.is_multiple_of(8) {
                    return Err(ExecutionError::InvalidMemoryAccess(Opcode::SD, addr));
                }
                a
            }
            _ => unreachable!(),
        };

        let mw_record = self.mw(mr_record, memory_store_value);

        Ok(StoreResult { a, b, c, addr, rs1, rs1_record, rs2, rs2_record, mw_record })
    }
265
266    /// Execute an ALU instruction.
267    #[inline]
268    #[allow(clippy::too_many_lines)]
269    pub fn execute_alu(&mut self, instruction: &Instruction) -> AluResult {
270        let mut result = MaybeUninit::<AluResult>::uninit();
271        let result_ptr = result.as_mut_ptr();
272
273        let (rd, b, c) = if !instruction.imm_c {
274            let (rd, rs1, rs2) = instruction.r_type();
275            let c = self.rr(rs2, MemoryAccessPosition::C);
276            let b = self.rr(rs1, MemoryAccessPosition::B);
277
278            // SAFETY: We're writing to a valid pointer as we just created the pointer from the
279            // `result`.
280            unsafe { addr_of_mut!((*result_ptr).rs1).write(MaybeImmediate::Register(rs1, b)) };
281            unsafe { addr_of_mut!((*result_ptr).rs2).write(MaybeImmediate::Register(rs2, c)) };
282
283            (rd, b.value, c.value)
284        } else if !instruction.imm_b && instruction.imm_c {
285            let (rd, rs1, imm) = instruction.i_type();
286            let (rd, b, c) = (rd, self.rr(rs1, MemoryAccessPosition::B), imm);
287
288            // SAFETY: We're writing to a valid pointer as we just created the pointer from the
289            // `result`.
290            unsafe { addr_of_mut!((*result_ptr).rs1).write(MaybeImmediate::Register(rs1, b)) };
291            unsafe { addr_of_mut!((*result_ptr).rs2).write(MaybeImmediate::Immediate(c)) };
292
293            (rd, b.value, c)
294        } else {
295            debug_assert!(instruction.imm_b && instruction.imm_c);
296            let (rd, b, c) =
297                (Register::from_u8(instruction.op_a), instruction.op_b, instruction.op_c);
298
299            // SAFETY: We're writing to a valid pointer as we just created the pointer from the
300            // `result`.
301            unsafe { addr_of_mut!((*result_ptr).rs1).write(MaybeImmediate::Immediate(b)) };
302            unsafe { addr_of_mut!((*result_ptr).rs2).write(MaybeImmediate::Immediate(c)) };
303
304            (rd, b, c)
305        };
306
307        let a = match instruction.opcode {
308            Opcode::ADD | Opcode::ADDI => (Wrapping(b) + Wrapping(c)).0,
309            Opcode::SUB => (Wrapping(b) - Wrapping(c)).0,
310            Opcode::XOR => b ^ c,
311            Opcode::OR => b | c,
312            Opcode::AND => b & c,
313            Opcode::SLL => b << (c & 0x3f),
314            Opcode::SRL => b >> (c & 0x3f),
315            Opcode::SRA => ((b as i64) >> (c & 0x3f)) as u64,
316            Opcode::SLT => {
317                if (b as i64) < (c as i64) {
318                    1
319                } else {
320                    0
321                }
322            }
323            Opcode::SLTU => {
324                if b < c {
325                    1
326                } else {
327                    0
328                }
329            }
330            Opcode::MUL => (Wrapping(b as i64) * Wrapping(c as i64)).0 as u64,
331            Opcode::MULH => (((b as i64) as i128).wrapping_mul((c as i64) as i128) >> 64) as u64,
332            Opcode::MULHU => ((b as u128 * c as u128) >> 64) as u64,
333            Opcode::MULHSU => ((((b as i64) as i128) * (c as i128)) >> 64) as u64,
334            Opcode::DIV => {
335                if c == 0 {
336                    u64::MAX
337                } else {
338                    (b as i64).wrapping_div(c as i64) as u64
339                }
340            }
341            Opcode::DIVU => {
342                if c == 0 {
343                    u64::MAX
344                } else {
345                    b / c
346                }
347            }
348            Opcode::REM => {
349                if c == 0 {
350                    b
351                } else {
352                    (b as i64).wrapping_rem(c as i64) as u64
353                }
354            }
355            Opcode::REMU => {
356                if c == 0 {
357                    b
358                } else {
359                    b % c
360                }
361            }
362            // RISCV-64 word operations
363            Opcode::ADDW => (Wrapping(b as i32) + Wrapping(c as i32)).0 as i64 as u64,
364            Opcode::SUBW => (Wrapping(b as i32) - Wrapping(c as i32)).0 as i64 as u64,
365            Opcode::MULW => (Wrapping(b as i32) * Wrapping(c as i32)).0 as i64 as u64,
366            Opcode::DIVW => {
367                if c as i32 == 0 {
368                    u64::MAX
369                } else {
370                    (b as i32).wrapping_div(c as i32) as i64 as u64
371                }
372            }
373            Opcode::DIVUW => {
374                if c as i32 == 0 {
375                    u64::MAX
376                } else {
377                    ((b as u32 / c as u32) as i32) as i64 as u64
378                }
379            }
380            Opcode::REMW => {
381                if c as i32 == 0 {
382                    (b as i32) as u64
383                } else {
384                    (b as i32).wrapping_rem(c as i32) as i64 as u64
385                }
386            }
387            Opcode::REMUW => {
388                if c as u32 == 0 {
389                    (b as i32) as u64
390                } else {
391                    (((b as u32) % (c as u32)) as i32) as i64 as u64
392                }
393            }
394            // RISCV-64 bit operations
395            Opcode::SLLW => (((b as i64) << (c & 0x1f)) as i32) as i64 as u64,
396            Opcode::SRLW => (((b as u32) >> ((c & 0x1f) as u32)) as i32) as u64,
397            Opcode::SRAW => {
398                (b as i32).wrapping_shr(((c as i64 & 0x1f) as i32) as u32) as i64 as u64
399            }
400            _ => unreachable!(),
401        };
402
403        let rw_record = self.rw(rd, a);
404
405        // SAFETY: We're writing to a valid pointer as we just created the pointer from the
406        // `result`.
407        unsafe { addr_of_mut!((*result_ptr).a).write(a) };
408        unsafe { addr_of_mut!((*result_ptr).b).write(b) };
409        unsafe { addr_of_mut!((*result_ptr).c).write(c) };
410        unsafe { addr_of_mut!((*result_ptr).rd).write(rd) };
411        unsafe { addr_of_mut!((*result_ptr).rw_record).write(rw_record) };
412
413        // SAFETY: All fields have been initialized by this point.
414        unsafe { result.assume_init() }
415    }
416
417    /// Execute a jump instruction.
418    pub fn execute_jump(&mut self, instruction: &Instruction) -> JumpResult {
419        match instruction.opcode {
420            Opcode::JAL => {
421                let (rd, imm) = instruction.j_type();
422                let imm_se = sign_extend_imm(imm, 21);
423                let a = self.pc.wrapping_add(4);
424                let rd_record = self.rw(rd, a);
425
426                let next_pc = ((self.pc as i64).wrapping_add(imm_se)) as u64;
427                let b = imm_se as u64;
428                let c = 0;
429
430                self.next_pc = next_pc;
431
432                JumpResult { a, b, c, rd, rd_record, rs1: MaybeImmediate::Immediate(b) }
433            }
434            Opcode::JALR => {
435                let (rd, rs1, c) = instruction.i_type();
436                let imm_se = sign_extend_imm(c, 12);
437                let b_record = self.rr(rs1, MemoryAccessPosition::B);
438                let a = self.pc.wrapping_add(4);
439
440                // Calculate next PC: (rs1 + imm) & ~1
441                let next_pc = ((b_record.value as i64).wrapping_add(imm_se) as u64) & !1_u64;
442                let rd_record = self.rw(rd, a);
443
444                self.next_pc = next_pc;
445
446                JumpResult {
447                    a,
448                    b: b_record.value,
449                    c,
450                    rd,
451                    rd_record,
452                    rs1: MaybeImmediate::Register(rs1, b_record),
453                }
454            }
455            _ => unreachable!("Invalid opcode for `execute_jump`: {:?}", instruction.opcode),
456        }
457    }
458
459    /// Execute a branch instruction.
460    pub fn execute_branch(&mut self, instruction: &Instruction) -> BranchResult {
461        let (rs1, rs2, imm) = instruction.b_type();
462
463        let c = imm;
464        let b_record = self.rr(rs2, MemoryAccessPosition::B);
465        let a_record = self.rr(rs1, MemoryAccessPosition::A);
466
467        let a = a_record.value;
468        let b = b_record.value;
469
470        let branch = match instruction.opcode {
471            Opcode::BEQ => a == b,
472            Opcode::BNE => a != b,
473            Opcode::BLT => (a as i64) < (b as i64),
474            Opcode::BGE => (a as i64) >= (b as i64),
475            Opcode::BLTU => a < b,
476            Opcode::BGEU => a >= b,
477            _ => {
478                unreachable!()
479            }
480        };
481
482        if branch {
483            self.next_pc = self.pc.wrapping_add(c);
484        }
485
486        BranchResult { a, rs1, a_record, b, rs2, b_record, c }
487    }
488
489    /// Execute a U-type instruction.
490    #[inline]
491    pub fn execute_utype(&mut self, instruction: &Instruction) -> UTypeResult {
492        let (rd, imm) = instruction.u_type();
493        let (b, c) = (imm, imm);
494        let a = if instruction.opcode == Opcode::AUIPC { self.pc.wrapping_add(imm) } else { imm };
495        let a_record = self.rw(rd, a);
496
497        UTypeResult { a, b, c, rd, rw_record: a_record }
498    }
499
    #[inline]
    /// Execute an ecall instruction.
    ///
    /// # WARNING:
    ///
    /// It's up to the syscall handler to update the shape checker about sent/internal ecalls.
    pub fn execute_ecall<RT>(
        rt: &mut RT,
        instruction: &Instruction,
        code: SyscallCode,
    ) -> Result<EcallResult, ExecutionError>
    where
        RT: SyscallRuntime<'a>,
    {
        if !instruction.is_ecall_instruction() {
            unreachable!("Invalid opcode for `execute_ecall`: {:?}", instruction.opcode);
        }

        let core = rt.core_mut();

        // Syscall arguments: x11 is read at access position C, x10 at position B.
        let c_record = core.rr(Register::X11, MemoryAccessPosition::C);
        let b_record = core.rr(Register::X10, MemoryAccessPosition::B);
        let c = c_record.value;
        let b = b_record.value;

        // The only way unconstrained mode interacts with the parts of the program that are
        // proven is via hints; this means during tracing and splicing we can just "skip"
        // the whole set of unconstrained cycles and rely on the fact that the hints are
        // already part of the minimal trace.
        let a = if code == SyscallCode::ENTER_UNCONSTRAINED {
            0
        } else {
            sp1_ecall_handler(rt, code, b, c).unwrap_or(code as u64)
        };

        // Re-borrow the core: the handler above needed the whole runtime mutably.
        let core = rt.core_mut();

        // Write the syscall result back to the x5 register (which held the syscall code).
        let a_record = core.rw(Register::X5, a);

        // Add 256 to the next clock to account for the ecall.
        core.set_next_clk(core.next_clk() + 256);

        Ok(EcallResult { a, a_record, b, b_record, c, c_record })
    }
546
547    /// Peek to get the code from the x5 register.
548    #[must_use]
549    pub fn read_code(&self) -> SyscallCode {
550        // We peek at register x5 to get the syscall id. The reason we don't `self.rr` this
551        // register is that we write to it later.
552        let t0 = Register::X5;
553
554        // Peek at the register, we dont care about the read here.
555        let syscall_id = self.registers[t0 as usize].value;
556
557        // Convert the raw value to a SyscallCode.
558        SyscallCode::from_u32(syscall_id as u32)
559    }
560}
561
562impl CoreVM<'_> {
563    /// Read the next required memory read from the trace.
564    #[inline]
565    fn mr(&mut self, addr: u64) -> MemoryReadRecord {
566        #[allow(clippy::manual_let_else)]
567        let record = match self.mem_reads.next() {
568            Some(next) => next,
569            None => {
570                unreachable!("memory reads unexpectdely exhausted at {addr}, clk {}", self.clk);
571            }
572        };
573
574        MemoryReadRecord {
575            value: record.value,
576            timestamp: self.timestamp(MemoryAccessPosition::Memory),
577            prev_timestamp: record.clk,
578            prev_page_prot_record: None,
579        }
580    }
581
582    #[inline]
583    pub(crate) fn mr_slice_unsafe(&mut self, len: usize) -> Vec<u64> {
584        let mem_reads = self.mem_reads();
585
586        mem_reads.take(len).map(|value| value.value).collect()
587    }
588
589    #[inline]
590    pub(crate) fn mr_slice(&mut self, _addr: u64, len: usize) -> Vec<MemoryReadRecord> {
591        let current_clk = self.clk();
592        let mem_reads = self.mem_reads();
593
594        let records: Vec<MemoryReadRecord> = mem_reads
595            .take(len)
596            .map(|value| MemoryReadRecord {
597                value: value.value,
598                timestamp: current_clk,
599                prev_timestamp: value.clk,
600                prev_page_prot_record: None,
601            })
602            .collect();
603
604        records
605    }
606
607    #[inline]
608    pub(crate) fn mw_slice(&mut self, _addr: u64, len: usize) -> Vec<MemoryWriteRecord> {
609        let mem_writes = self.mem_reads();
610
611        let raw_records: Vec<_> = mem_writes.take(len * 2).collect();
612        let records: Vec<MemoryWriteRecord> = raw_records
613            .chunks(2)
614            .map(|chunk| {
615                #[allow(clippy::manual_let_else)]
616                let (old, new) = match (chunk.first(), chunk.last()) {
617                    (Some(old), Some(new)) => (old, new),
618                    _ => unreachable!("Precompile memory write out of bounds"),
619                };
620
621                MemoryWriteRecord {
622                    prev_timestamp: old.clk,
623                    prev_value: old.value,
624                    timestamp: new.clk,
625                    value: new.value,
626                    prev_page_prot_record: None,
627                }
628            })
629            .collect();
630
631        records
632    }
633
    /// Build a write record on top of a prior read of the same address.
    ///
    /// The read record supplies the previous value/timestamp; the new timestamp is the
    /// current cycle's memory-access slot.
    #[inline]
    fn mw(&mut self, read_record: MemoryReadRecord, value: u64) -> MemoryWriteRecord {
        MemoryWriteRecord {
            prev_timestamp: read_record.prev_timestamp,
            prev_value: read_record.value,
            timestamp: self.timestamp(MemoryAccessPosition::Memory),
            value,
            prev_page_prot_record: None,
        }
    }
644
645    /// Read a value from a register, updating the register entry and returning the record.
646    #[inline]
647    fn rr(&mut self, register: Register, position: MemoryAccessPosition) -> MemoryReadRecord {
648        let prev_record = self.registers[register as usize];
649        let new_record =
650            MemoryRecord { timestamp: self.timestamp(position), value: prev_record.value };
651
652        self.registers[register as usize] = new_record;
653
654        MemoryReadRecord {
655            value: new_record.value,
656            timestamp: new_record.timestamp,
657            prev_timestamp: prev_record.timestamp,
658            prev_page_prot_record: None,
659        }
660    }
661
662    /// Read a value from a register, updating the register entry and returning the record.
663    #[inline]
664    fn rr_precompile(&mut self, register: usize) -> MemoryReadRecord {
665        debug_assert!(register < 32, "out of bounds register: {register}");
666
667        let prev_record = self.registers[register];
668        let new_record = MemoryRecord { timestamp: self.clk(), value: prev_record.value };
669
670        self.registers[register] = new_record;
671
672        MemoryReadRecord {
673            value: new_record.value,
674            timestamp: new_record.timestamp,
675            prev_timestamp: prev_record.timestamp,
676            prev_page_prot_record: None,
677        }
678    }
679
680    /// Touch all the registers in the VM, bumping thier clock to `self.clk - 1`.
681    pub fn register_refresh(&mut self) -> [MemoryReadRecord; 32] {
682        fn bump_register(vm: &mut CoreVM, register: usize) -> MemoryReadRecord {
683            let prev_record = vm.registers[register];
684            let new_record = MemoryRecord { timestamp: vm.clk - 1, value: prev_record.value };
685
686            vm.registers[register] = new_record;
687
688            MemoryReadRecord {
689                value: new_record.value,
690                timestamp: new_record.timestamp,
691                prev_timestamp: prev_record.timestamp,
692                prev_page_prot_record: None,
693            }
694        }
695
696        tracing::trace!("register refresh to: {}", self.clk - 1);
697
698        let mut out = [MaybeUninit::uninit(); 32];
699        for (i, record) in out.iter_mut().enumerate() {
700            *record = MaybeUninit::new(bump_register(self, i));
701        }
702
703        // SAFETY: We're transmuting a [MaybeUninit<MemoryReadRecord>; 32] to a [MemoryReadRecord;
704        // 32], which we just initialized.
705        //
706        // These types are guaranteed to have the same representation.
707        unsafe { std::mem::transmute(out) }
708    }
709
710    /// Write a value to a register, updating the register entry and returning the record.
711    #[inline]
712    fn rw(&mut self, register: Register, value: u64) -> MemoryWriteRecord {
713        let value = if register == Register::X0 { 0 } else { value };
714
715        let prev_record = self.registers[register as usize];
716        let new_record = MemoryRecord { timestamp: self.timestamp(MemoryAccessPosition::A), value };
717
718        self.registers[register as usize] = new_record;
719
720        // if SHAPE_CHECKING {
721        //     self.shape_checker.handle_mem_event(register as u64, prev_record.timestamp);
722        // }
723
724        // if REPORT_GENERATING {
725        //     self.gas_calculator.handle_mem_event(register as u64, prev_record.timestamp);
726        // }
727
728        MemoryWriteRecord {
729            value: new_record.value,
730            timestamp: new_record.timestamp,
731            prev_timestamp: prev_record.timestamp,
732            prev_value: prev_record.value,
733            prev_page_prot_record: None,
734        }
735    }
736}
737
impl CoreVM<'_> {
    /// Get the current timestamp for a given memory access position.
    ///
    /// Each cycle reserves one timestamp slot per access position; the slot index is
    /// the position's numeric value added to the current clock.
    #[inline]
    #[must_use]
    pub const fn timestamp(&self, position: MemoryAccessPosition) -> u64 {
        self.clk + position as u64
    }

    /// Check if the top 24 bits have changed, which imply a `state bump` event needs to be emitted.
    #[inline]
    #[must_use]
    pub const fn needs_bump_clk_high(&self) -> bool {
        // `^` binds tighter than `>`, so this tests whether any bit above the low 24
        // differs between the next and current clock.
        (self.next_clk() >> 24) ^ (self.clk() >> 24) > 0
    }

    /// Check if the state needs to be bumped, which implies a `state bump` event needs to be
    /// emitted.
    #[inline]
    #[must_use]
    pub const fn needs_state_bump(&self, instruction: &Instruction) -> bool {
        let next_pc = self.next_pc();
        // NOTE(review): the `+ 8` presumably reserves headroom for the cycle's remaining
        // timestamp slots — confirm against the state-bump chip's constraints.
        let increment = self.next_clk() + 8 - self.clk();

        // Bump when the low 24-bit clock counter would overflow within this cycle...
        let bump1 = self.clk() % (1 << 24) + increment >= (1 << 24);
        // ...or when a sequential pc step crosses a 16-bit boundary that the
        // instruction's own constraints do not track.
        let bump2 = !instruction.is_with_correct_next_pc()
            && next_pc == self.pc().wrapping_add(4)
            && (next_pc >> 16) != (self.pc() >> 16);

        bump1 || bump2
    }
}
769
impl<'a> CoreVM<'a> {
    #[inline]
    #[must_use]
    /// Get the current clock, this clock is incremented by [`CLK_INC`] each cycle.
    pub const fn clk(&self) -> u64 {
        self.clk
    }

    #[inline]
    /// Set the current clock.
    pub fn set_clk(&mut self, new_clk: u64) {
        self.clk = new_clk;
    }

    #[inline]
    /// Set the next clock (overrides the default `clk + CLK_INC` scheduled by `advance`).
    pub fn set_next_clk(&mut self, clk: u64) {
        self.next_clk = clk;
    }

    #[inline]
    #[must_use]
    /// Get the global clock, this clock is incremented by 1 each cycle.
    pub fn global_clk(&self) -> u64 {
        self.global_clk
    }

    #[inline]
    #[must_use]
    /// Get the current program counter.
    pub const fn pc(&self) -> u64 {
        self.pc
    }

    #[inline]
    #[must_use]
    /// Get the next program counter that will be set in [`CoreVM::advance`].
    pub const fn next_pc(&self) -> u64 {
        self.next_pc
    }

    #[inline]
    /// Set the next program counter (overrides the default `pc + PC_INC`).
    pub fn set_next_pc(&mut self, pc: u64) {
        self.next_pc = pc;
    }

    #[inline]
    #[must_use]
    /// Get the exit code.
    pub fn exit_code(&self) -> u32 {
        self.exit_code
    }

    #[inline]
    /// Set the exit code.
    pub fn set_exit_code(&mut self, exit_code: u32) {
        self.exit_code = exit_code;
    }

    #[inline]
    /// Set the program counter.
    pub fn set_pc(&mut self, pc: u64) {
        self.pc = pc;
    }

    #[inline]
    /// Set the global clock.
    pub fn set_global_clk(&mut self, global_clk: u64) {
        self.global_clk = global_clk;
    }

    #[inline]
    #[must_use]
    /// Get the next clock that will be set in [`CoreVM::advance`].
    pub const fn next_clk(&self) -> u64 {
        self.next_clk
    }

    #[inline]
    #[must_use]
    /// Get the current registers (immutable).
    pub fn registers(&self) -> &[MemoryRecord; 32] {
        &self.registers
    }

    #[inline]
    #[must_use]
    /// Get the current registers (mutable).
    pub fn registers_mut(&mut self) -> &mut [MemoryRecord; 32] {
        &mut self.registers
    }

    #[inline]
    /// Get the memory reads iterator.
    pub fn mem_reads(&mut self) -> &mut MemReads<'a> {
        &mut self.mem_reads
    }

    /// Check if the syscall is retained, i.e. its events stay in the current shard.
    #[inline]
    #[must_use]
    pub fn is_retained_syscall(&self, syscall_code: SyscallCode) -> bool {
        self.retained_syscall_codes.contains(&syscall_code)
    }

    /// Check if the trace has ended (the clock has reached the chunk's end clock).
    #[inline]
    #[must_use]
    pub const fn is_trace_end(&self) -> bool {
        self.clk_end == self.clk()
    }

    /// Check if the program has halted.
    #[inline]
    #[must_use]
    pub const fn is_done(&self) -> bool {
        self.pc() == HALT_PC
    }
}
890
/// Sign-extend the low `bits` bits of `value` into an `i64`.
///
/// Bits at or above `bits` are discarded before the sign is propagated.
fn sign_extend_imm(value: u64, bits: u8) -> i64 {
    let unused = 64 - bits;
    // Shifting left drops the unused high bits; the arithmetic right shift then
    // replicates the sign bit back down.
    ((value as i64) << unused) >> unused
}