//! sp1_core_executor — vm.rs: a RISC-V VM driven by a minimal trace.

1#![allow(unknown_lints)]
2#![allow(clippy::manual_checked_ops)]
3
4use crate::{
5    events::{MemoryAccessPosition, MemoryReadRecord, MemoryRecord, MemoryWriteRecord},
6    vm::{
7        results::{
8            AluResult, BranchResult, CycleResult, EcallResult, JumpResult, LoadResult,
9            MaybeImmediate, StoreResult, UTypeResult,
10        },
11        syscall::{sp1_ecall_handler, SyscallRuntime},
12    },
13    ExecutionError, Instruction, Opcode, Program, Register, RetainedEventsPreset, SP1CoreOpts,
14    SyscallCode, CLK_INC as CLK_INC_32, HALT_PC, PC_INC as PC_INC_32,
15};
16use sp1_hypercube::air::{PROOF_NONCE_NUM_WORDS, PV_DIGEST_NUM_WORDS};
17use sp1_jit::{MemReads, MinimalTrace};
18use std::{mem::MaybeUninit, num::Wrapping, ptr::addr_of_mut, sync::Arc};
19
// Submodules of the VM implementation.
pub(crate) mod gas;
pub(crate) mod memory;
pub(crate) mod results;
pub(crate) mod shapes;
pub(crate) mod syscall;

/// Per-cycle clock increment, widened to `u64` for the 64-bit VM state.
const CLK_INC: u64 = CLK_INC_32 as u64;
/// Per-instruction program-counter increment, widened to `u64`.
const PC_INC: u64 = PC_INC_32 as u64;
28
/// A RISC-V VM that uses a [`MinimalTrace`] to oracle memory access.
pub struct CoreVM<'a> {
    /// The register file; each entry tracks a value and the timestamp of its last access.
    registers: [MemoryRecord; 32],
    /// The current clock of the VM.
    clk: u64,
    /// The global clock of the VM.
    global_clk: u64,
    /// The current program counter of the VM.
    pc: u64,
    /// The current exit code of the VM.
    exit_code: u32,
    /// The memory reads cursor.
    pub mem_reads: MemReads<'a>,
    /// The next program counter that will be set in [`CoreVM::advance`].
    next_pc: u64,
    /// The next clock that will be set in [`CoreVM::advance`].
    next_clk: u64,
    /// The hint lengths that are read from within the vm.
    hint_lens: std::slice::Iter<'a, usize>,
    /// The program that is being executed.
    pub program: Arc<Program>,
    /// The syscalls that are not marked as external, ie. they stay in the same shard.
    pub(crate) retained_syscall_codes: Vec<SyscallCode>,
    /// The options to configure the VM, mostly for syscall / shard handling.
    pub opts: SP1CoreOpts,
    /// The end clk of the trace chunk.
    pub clk_end: u64,
    /// The public value digest.
    pub public_value_digest: [u32; PV_DIGEST_NUM_WORDS],
    /// The nonce associated with the proof.
    pub proof_nonce: [u32; PROOF_NONCE_NUM_WORDS],
}
61
62impl<'a> CoreVM<'a> {
    /// Create a [`CoreVM`] from a [`MinimalTrace`] and a [`Program`].
    ///
    /// The VM starts at the trace's starting pc/clk, with all registers seeded
    /// from the trace's starting register values, timestamped `start_clk - 1`
    /// (i.e. just before the first cycle of this chunk).
    pub fn new<T: MinimalTrace>(
        trace: &'a T,
        program: Arc<Program>,
        opts: SP1CoreOpts,
        proof_nonce: [u32; PROOF_NONCE_NUM_WORDS],
    ) -> Self {
        let start_clk = trace.clk_start();

        // SAFETY: We're mapping a [T; 32] -> [T; 32] infallibly: the collected
        // `Vec` always has exactly 32 elements, so `try_into` cannot fail.
        // (Assumes `start_registers()` yields exactly 32 values — established
        // by the `[MemoryRecord; 32]` register file this initializes.)
        let registers = unsafe {
            trace
                .start_registers()
                .into_iter()
                .map(|v| MemoryRecord { timestamp: start_clk - 1, value: v })
                .collect::<Vec<_>>()
                .try_into()
                .unwrap_unchecked()
        };
        let start_pc = trace.pc_start();

        // Flatten every retained-events preset into the concrete list of
        // syscall codes that stay in-shard.
        let retained_syscall_codes = opts
            .retained_events_presets
            .iter()
            .flat_map(RetainedEventsPreset::syscall_codes)
            .copied()
            .collect();

        tracing::trace!("start_clk: {}", start_clk);
        tracing::trace!("start_pc: {}", start_pc);
        tracing::trace!("trace.clk_end(): {}", trace.clk_end());
        tracing::trace!("trace.num_mem_reads(): {}", trace.num_mem_reads());
        tracing::trace!("trace.hint_lens(): {:?}", trace.hint_lens().len());
        tracing::trace!("trace.start_registers(): {:?}", trace.start_registers());

        // A chunk starting at clk 1 is the very first chunk, so it must begin
        // at the program's absolute entry point.
        if trace.clk_start() == 1 {
            assert_eq!(trace.pc_start(), program.pc_start_abs);
        }

        Self {
            registers,
            global_clk: 0,
            clk: start_clk,
            pc: start_pc,
            program,
            mem_reads: trace.mem_reads(),
            next_pc: start_pc.wrapping_add(PC_INC),
            next_clk: start_clk.wrapping_add(CLK_INC),
            hint_lens: trace.hint_lens().iter(),
            exit_code: 0,
            retained_syscall_codes,
            opts,
            clk_end: trace.clk_end(),
            public_value_digest: [0; PV_DIGEST_NUM_WORDS],
            proof_nonce,
        }
    }
120
    /// Fetch the next instruction from the program.
    ///
    /// Returns `None` when `self.pc` does not map to an instruction.
    #[inline]
    pub fn fetch(&mut self) -> Option<&Instruction> {
        // todo: mprotect / kernel mode logic.
        self.program.fetch(self.pc)
    }
127
128    #[inline]
129    /// Increment the state of the VM by one cycle.
130    /// Calling this method will update the pc and the clk to the next cycle.
131    pub fn advance(&mut self) -> CycleResult {
132        self.clk = self.next_clk;
133        self.pc = self.next_pc;
134
135        // Reset the next_clk and next_pc to the next cycle.
136        self.next_clk = self.clk.wrapping_add(CLK_INC);
137        self.next_pc = self.pc.wrapping_add(PC_INC);
138        self.global_clk = self.global_clk.wrapping_add(1);
139
140        // Check if the program has halted.
141        if self.pc == HALT_PC {
142            return CycleResult::Done(true);
143        }
144
145        // Check if the shard limit has been reached.
146        if self.is_trace_end() {
147            return CycleResult::TraceEnd;
148        }
149
150        // Return that the program is still running.
151        CycleResult::Done(false)
152    }
153
    /// Execute a load instruction.
    ///
    /// Memory is read as 8-byte words; sub-word loads (`LB`/`LH`/`LW`/`LWU`
    /// and the unsigned byte/half variants) extract the addressed bytes from
    /// the containing word, then sign- or zero-extend. Misaligned half/word/
    /// doubleword accesses return [`ExecutionError::InvalidMemoryAccess`].
    #[inline]
    pub fn execute_load(
        &mut self,
        instruction: &Instruction,
    ) -> Result<LoadResult, ExecutionError> {
        let (rd, rs1, imm) = instruction.i_type();

        // Read the base register at position B.
        let rr_record = self.rr(rs1, MemoryAccessPosition::B);
        let b = rr_record.value;

        // Compute the address.
        let addr = b.wrapping_add(imm);
        let mr_record = self.mr(addr);
        let word = mr_record.value;

        // Select and extend the addressed portion of the 8-byte word.
        let a = match instruction.opcode {
            Opcode::LB => ((word >> ((addr % 8) * 8)) & 0xFF) as i8 as i64 as u64,
            Opcode::LH => {
                if !addr.is_multiple_of(2) {
                    return Err(ExecutionError::InvalidMemoryAccess(Opcode::LH, addr));
                }

                ((word >> (((addr / 2) % 4) * 16)) & 0xFFFF) as i16 as i64 as u64
            }
            Opcode::LW => {
                if !addr.is_multiple_of(4) {
                    return Err(ExecutionError::InvalidMemoryAccess(Opcode::LW, addr));
                }

                ((word >> (((addr / 4) % 2) * 32)) & 0xFFFFFFFF) as i32 as u64
            }
            Opcode::LBU => ((word >> ((addr % 8) * 8)) & 0xFF) as u8 as u64,
            Opcode::LHU => {
                if !addr.is_multiple_of(2) {
                    return Err(ExecutionError::InvalidMemoryAccess(Opcode::LHU, addr));
                }

                ((word >> (((addr / 2) % 4) * 16)) & 0xFFFF) as u16 as u64
            }
            // RISCV-64
            Opcode::LWU => {
                if !addr.is_multiple_of(4) {
                    return Err(ExecutionError::InvalidMemoryAccess(Opcode::LWU, addr));
                }

                (word >> (((addr / 4) % 2) * 32)) & 0xFFFFFFFF
            }
            Opcode::LD => {
                if !addr.is_multiple_of(8) {
                    return Err(ExecutionError::InvalidMemoryAccess(Opcode::LD, addr));
                }

                word
            }
            _ => unreachable!("Invalid opcode for `execute_load`: {:?}", instruction.opcode),
        };

        // Write the loaded value into the destination register.
        let rw_record = self.rw(rd, a);

        Ok(LoadResult { a, b, c: imm, addr, rs1, rd, rr_record, rw_record, mr_record })
    }
216
    /// Execute a store instruction.
    ///
    /// Stores are a read-modify-write of the containing 8-byte word: the old
    /// word is read, the addressed bytes are spliced in from `a`, and the
    /// merged word is written back. Note the operand roles in this encoding:
    /// `b` (rs2's value) is the address base, and `a` (rs1's value) is the
    /// data being stored. Misaligned half/word/doubleword stores return
    /// [`ExecutionError::InvalidMemoryAccess`].
    #[inline]
    pub fn execute_store(
        &mut self,
        instruction: &Instruction,
    ) -> Result<StoreResult, ExecutionError> {
        let (rs1, rs2, imm) = instruction.s_type();

        let c = imm;
        // Register reads happen in position order: B (rs2) before A (rs1).
        let rs2_record = self.rr(rs2, MemoryAccessPosition::B);
        let rs1_record = self.rr(rs1, MemoryAccessPosition::A);

        let b = rs2_record.value;
        let a = rs1_record.value;
        let addr = b.wrapping_add(c);
        let mr_record = self.mr(addr);
        let word = mr_record.value;

        // Splice the stored bytes into the old word at the addressed offset.
        let memory_store_value = match instruction.opcode {
            Opcode::SB => {
                let shift = (addr % 8) * 8;
                ((a & 0xFF) << shift) | (word & !(0xFF << shift))
            }
            Opcode::SH => {
                if !addr.is_multiple_of(2) {
                    return Err(ExecutionError::InvalidMemoryAccess(Opcode::SH, addr));
                }
                let shift = ((addr / 2) % 4) * 16;
                ((a & 0xFFFF) << shift) | (word & !(0xFFFF << shift))
            }
            Opcode::SW => {
                if !addr.is_multiple_of(4) {
                    return Err(ExecutionError::InvalidMemoryAccess(Opcode::SW, addr));
                }
                let shift = ((addr / 4) % 2) * 32;
                ((a & 0xFFFFFFFF) << shift) | (word & !(0xFFFFFFFF << shift))
            }
            // RISCV-64
            Opcode::SD => {
                if !addr.is_multiple_of(8) {
                    return Err(ExecutionError::InvalidMemoryAccess(Opcode::SD, addr));
                }
                a
            }
            _ => unreachable!(),
        };

        let mw_record = self.mw(mr_record, memory_store_value);

        Ok(StoreResult { a, b, c, addr, rs1, rs1_record, rs2, rs2_record, mw_record })
    }
268
269    /// Execute an ALU instruction.
270    #[inline]
271    #[allow(clippy::too_many_lines)]
272    pub fn execute_alu(&mut self, instruction: &Instruction) -> AluResult {
273        let mut result = MaybeUninit::<AluResult>::uninit();
274        let result_ptr = result.as_mut_ptr();
275
276        let (rd, b, c) = if !instruction.imm_c {
277            let (rd, rs1, rs2) = instruction.r_type();
278            let c = self.rr(rs2, MemoryAccessPosition::C);
279            let b = self.rr(rs1, MemoryAccessPosition::B);
280
281            // SAFETY: We're writing to a valid pointer as we just created the pointer from the
282            // `result`.
283            unsafe { addr_of_mut!((*result_ptr).rs1).write(MaybeImmediate::Register(rs1, b)) };
284            unsafe { addr_of_mut!((*result_ptr).rs2).write(MaybeImmediate::Register(rs2, c)) };
285
286            (rd, b.value, c.value)
287        } else if !instruction.imm_b && instruction.imm_c {
288            let (rd, rs1, imm) = instruction.i_type();
289            let (rd, b, c) = (rd, self.rr(rs1, MemoryAccessPosition::B), imm);
290
291            // SAFETY: We're writing to a valid pointer as we just created the pointer from the
292            // `result`.
293            unsafe { addr_of_mut!((*result_ptr).rs1).write(MaybeImmediate::Register(rs1, b)) };
294            unsafe { addr_of_mut!((*result_ptr).rs2).write(MaybeImmediate::Immediate(c)) };
295
296            (rd, b.value, c)
297        } else {
298            debug_assert!(instruction.imm_b && instruction.imm_c);
299            let (rd, b, c) =
300                (Register::from_u8(instruction.op_a), instruction.op_b, instruction.op_c);
301
302            // SAFETY: We're writing to a valid pointer as we just created the pointer from the
303            // `result`.
304            unsafe { addr_of_mut!((*result_ptr).rs1).write(MaybeImmediate::Immediate(b)) };
305            unsafe { addr_of_mut!((*result_ptr).rs2).write(MaybeImmediate::Immediate(c)) };
306
307            (rd, b, c)
308        };
309
310        let a = match instruction.opcode {
311            Opcode::ADD | Opcode::ADDI => (Wrapping(b) + Wrapping(c)).0,
312            Opcode::SUB => (Wrapping(b) - Wrapping(c)).0,
313            Opcode::XOR => b ^ c,
314            Opcode::OR => b | c,
315            Opcode::AND => b & c,
316            Opcode::SLL => b << (c & 0x3f),
317            Opcode::SRL => b >> (c & 0x3f),
318            Opcode::SRA => ((b as i64) >> (c & 0x3f)) as u64,
319            Opcode::SLT => {
320                if (b as i64) < (c as i64) {
321                    1
322                } else {
323                    0
324                }
325            }
326            Opcode::SLTU => {
327                if b < c {
328                    1
329                } else {
330                    0
331                }
332            }
333            Opcode::MUL => (Wrapping(b as i64) * Wrapping(c as i64)).0 as u64,
334            Opcode::MULH => (((b as i64) as i128).wrapping_mul((c as i64) as i128) >> 64) as u64,
335            Opcode::MULHU => ((b as u128 * c as u128) >> 64) as u64,
336            Opcode::MULHSU => ((((b as i64) as i128) * (c as i128)) >> 64) as u64,
337            Opcode::DIV => {
338                if c == 0 {
339                    u64::MAX
340                } else {
341                    (b as i64).wrapping_div(c as i64) as u64
342                }
343            }
344            Opcode::DIVU => {
345                if c == 0 {
346                    u64::MAX
347                } else {
348                    b / c
349                }
350            }
351            Opcode::REM => {
352                if c == 0 {
353                    b
354                } else {
355                    (b as i64).wrapping_rem(c as i64) as u64
356                }
357            }
358            Opcode::REMU => {
359                if c == 0 {
360                    b
361                } else {
362                    b % c
363                }
364            }
365            // RISCV-64 word operations
366            Opcode::ADDW => (Wrapping(b as i32) + Wrapping(c as i32)).0 as i64 as u64,
367            Opcode::SUBW => (Wrapping(b as i32) - Wrapping(c as i32)).0 as i64 as u64,
368            Opcode::MULW => (Wrapping(b as i32) * Wrapping(c as i32)).0 as i64 as u64,
369            Opcode::DIVW => {
370                if c as i32 == 0 {
371                    u64::MAX
372                } else {
373                    (b as i32).wrapping_div(c as i32) as i64 as u64
374                }
375            }
376            Opcode::DIVUW => {
377                if c as i32 == 0 {
378                    u64::MAX
379                } else {
380                    ((b as u32 / c as u32) as i32) as i64 as u64
381                }
382            }
383            Opcode::REMW => {
384                if c as i32 == 0 {
385                    (b as i32) as u64
386                } else {
387                    (b as i32).wrapping_rem(c as i32) as i64 as u64
388                }
389            }
390            Opcode::REMUW => {
391                if c as u32 == 0 {
392                    (b as i32) as u64
393                } else {
394                    (((b as u32) % (c as u32)) as i32) as i64 as u64
395                }
396            }
397            // RISCV-64 bit operations
398            Opcode::SLLW => (((b as i64) << (c & 0x1f)) as i32) as i64 as u64,
399            Opcode::SRLW => (((b as u32) >> ((c & 0x1f) as u32)) as i32) as u64,
400            Opcode::SRAW => {
401                (b as i32).wrapping_shr(((c as i64 & 0x1f) as i32) as u32) as i64 as u64
402            }
403            _ => unreachable!(),
404        };
405
406        let rw_record = self.rw(rd, a);
407
408        // SAFETY: We're writing to a valid pointer as we just created the pointer from the
409        // `result`.
410        unsafe { addr_of_mut!((*result_ptr).a).write(a) };
411        unsafe { addr_of_mut!((*result_ptr).b).write(b) };
412        unsafe { addr_of_mut!((*result_ptr).c).write(c) };
413        unsafe { addr_of_mut!((*result_ptr).rd).write(rd) };
414        unsafe { addr_of_mut!((*result_ptr).rw_record).write(rw_record) };
415
416        // SAFETY: All fields have been initialized by this point.
417        unsafe { result.assume_init() }
418    }
419
    /// Execute a jump instruction.
    ///
    /// Both `JAL` and `JALR` write the return address (`pc + 4`) to `rd` and
    /// redirect `next_pc` to the jump target.
    pub fn execute_jump(&mut self, instruction: &Instruction) -> JumpResult {
        match instruction.opcode {
            Opcode::JAL => {
                // J-type: 21-bit signed pc-relative offset.
                let (rd, imm) = instruction.j_type();
                let imm_se = sign_extend_imm(imm, 21);
                let a = self.pc.wrapping_add(4);
                let rd_record = self.rw(rd, a);

                // Target is pc + sign-extended offset.
                let next_pc = ((self.pc as i64).wrapping_add(imm_se)) as u64;
                let b = imm_se as u64;
                let c = 0;

                self.next_pc = next_pc;

                JumpResult { a, b, c, rd, rd_record, rs1: MaybeImmediate::Immediate(b) }
            }
            Opcode::JALR => {
                // I-type: 12-bit signed offset relative to rs1.
                let (rd, rs1, c) = instruction.i_type();
                let imm_se = sign_extend_imm(c, 12);
                let b_record = self.rr(rs1, MemoryAccessPosition::B);
                let a = self.pc.wrapping_add(4);

                // Calculate next PC: (rs1 + imm) & ~1
                let next_pc = ((b_record.value as i64).wrapping_add(imm_se) as u64) & !1_u64;
                let rd_record = self.rw(rd, a);

                self.next_pc = next_pc;

                JumpResult {
                    a,
                    b: b_record.value,
                    c,
                    rd,
                    rd_record,
                    rs1: MaybeImmediate::Register(rs1, b_record),
                }
            }
            _ => unreachable!("Invalid opcode for `execute_jump`: {:?}", instruction.opcode),
        }
    }
461
462    /// Execute a branch instruction.
463    pub fn execute_branch(&mut self, instruction: &Instruction) -> BranchResult {
464        let (rs1, rs2, imm) = instruction.b_type();
465
466        let c = imm;
467        let b_record = self.rr(rs2, MemoryAccessPosition::B);
468        let a_record = self.rr(rs1, MemoryAccessPosition::A);
469
470        let a = a_record.value;
471        let b = b_record.value;
472
473        let branch = match instruction.opcode {
474            Opcode::BEQ => a == b,
475            Opcode::BNE => a != b,
476            Opcode::BLT => (a as i64) < (b as i64),
477            Opcode::BGE => (a as i64) >= (b as i64),
478            Opcode::BLTU => a < b,
479            Opcode::BGEU => a >= b,
480            _ => {
481                unreachable!()
482            }
483        };
484
485        if branch {
486            self.next_pc = self.pc.wrapping_add(c);
487        }
488
489        BranchResult { a, rs1, a_record, b, rs2, b_record, c }
490    }
491
492    /// Execute a U-type instruction.
493    #[inline]
494    pub fn execute_utype(&mut self, instruction: &Instruction) -> UTypeResult {
495        let (rd, imm) = instruction.u_type();
496        let (b, c) = (imm, imm);
497        let a = if instruction.opcode == Opcode::AUIPC { self.pc.wrapping_add(imm) } else { imm };
498        let a_record = self.rw(rd, a);
499
500        UTypeResult { a, b, c, rd, rw_record: a_record }
501    }
502
    #[inline]
    /// Execute an ecall instruction.
    ///
    /// # WARNING:
    ///
    /// It's up to the syscall handler to update the shape checker about sent/internal ecalls.
    pub fn execute_ecall<RT>(
        rt: &mut RT,
        instruction: &Instruction,
        code: SyscallCode,
    ) -> Result<EcallResult, ExecutionError>
    where
        RT: SyscallRuntime<'a>,
    {
        if !instruction.is_ecall_instruction() {
            unreachable!("Invalid opcode for `execute_ecall`: {:?}", instruction.opcode);
        }

        let core = rt.core_mut();

        // Syscall arguments come from x11 and x10, read in position order C then B.
        let c_record = core.rr(Register::X11, MemoryAccessPosition::C);
        let b_record = core.rr(Register::X10, MemoryAccessPosition::B);
        let c = c_record.value;
        let b = b_record.value;

        // The only way unconstrained mode interacts with the parts of the program that are
        // proven is via hints; this means during tracing and splicing, we can just "skip" the
        // whole set of unconstrained cycles, and rely on the fact that the hints are already
        // a part of the minimal trace.
        let a = if code == SyscallCode::ENTER_UNCONSTRAINED {
            0
        } else {
            // Handlers that return no value leave the code itself in `a`.
            sp1_ecall_handler(rt, code, b, c).unwrap_or(code as u64)
        };

        // Re-borrow the core: `rt` was mutably borrowed by the handler above.
        let core = rt.core_mut();

        // Write the syscall result to the x5 register (which held the code).
        let a_record = core.rw(Register::X5, a);

        // Add 256 to the next clock to account for the ecall.
        core.set_next_clk(core.next_clk() + 256);

        Ok(EcallResult { a, a_record, b, b_record, c, c_record })
    }
549
550    /// Peek to get the code from the x5 register.
551    #[must_use]
552    pub fn read_code(&self) -> SyscallCode {
553        // We peek at register x5 to get the syscall id. The reason we don't `self.rr` this
554        // register is that we write to it later.
555        let t0 = Register::X5;
556
557        // Peek at the register, we dont care about the read here.
558        let syscall_id = self.registers[t0 as usize].value;
559
560        // Convert the raw value to a SyscallCode.
561        SyscallCode::from_u32(syscall_id as u32)
562    }
563}
564
565impl CoreVM<'_> {
566    /// Read the next required memory read from the trace.
567    #[inline]
568    fn mr(&mut self, addr: u64) -> MemoryReadRecord {
569        #[allow(clippy::manual_let_else)]
570        let record = match self.mem_reads.next() {
571            Some(next) => next,
572            None => {
573                unreachable!("memory reads unexpectdely exhausted at {addr}, clk {}", self.clk);
574            }
575        };
576
577        MemoryReadRecord {
578            value: record.value,
579            timestamp: self.timestamp(MemoryAccessPosition::Memory),
580            prev_timestamp: record.clk,
581            prev_page_prot_record: None,
582        }
583    }
584
585    #[inline]
586    pub(crate) fn mr_slice_unsafe(&mut self, len: usize) -> Vec<u64> {
587        let mem_reads = self.mem_reads();
588
589        mem_reads.take(len).map(|value| value.value).collect()
590    }
591
592    #[inline]
593    pub(crate) fn mr_slice(&mut self, _addr: u64, len: usize) -> Vec<MemoryReadRecord> {
594        let current_clk = self.clk();
595        let mem_reads = self.mem_reads();
596
597        let records: Vec<MemoryReadRecord> = mem_reads
598            .take(len)
599            .map(|value| MemoryReadRecord {
600                value: value.value,
601                timestamp: current_clk,
602                prev_timestamp: value.clk,
603                prev_page_prot_record: None,
604            })
605            .collect();
606
607        records
608    }
609
610    #[inline]
611    pub(crate) fn mw_slice(&mut self, _addr: u64, len: usize) -> Vec<MemoryWriteRecord> {
612        let mem_writes = self.mem_reads();
613
614        let raw_records: Vec<_> = mem_writes.take(len * 2).collect();
615        let records: Vec<MemoryWriteRecord> = raw_records
616            .chunks(2)
617            .map(|chunk| {
618                #[allow(clippy::manual_let_else)]
619                let (old, new) = match (chunk.first(), chunk.last()) {
620                    (Some(old), Some(new)) => (old, new),
621                    _ => unreachable!("Precompile memory write out of bounds"),
622                };
623
624                MemoryWriteRecord {
625                    prev_timestamp: old.clk,
626                    prev_value: old.value,
627                    timestamp: new.clk,
628                    value: new.value,
629                    prev_page_prot_record: None,
630                }
631            })
632            .collect();
633
634        records
635    }
636
637    #[inline]
638    fn mw(&mut self, read_record: MemoryReadRecord, value: u64) -> MemoryWriteRecord {
639        MemoryWriteRecord {
640            prev_timestamp: read_record.prev_timestamp,
641            prev_value: read_record.value,
642            timestamp: self.timestamp(MemoryAccessPosition::Memory),
643            value,
644            prev_page_prot_record: None,
645        }
646    }
647
648    /// Read a value from a register, updating the register entry and returning the record.
649    #[inline]
650    fn rr(&mut self, register: Register, position: MemoryAccessPosition) -> MemoryReadRecord {
651        let prev_record = self.registers[register as usize];
652        let new_record =
653            MemoryRecord { timestamp: self.timestamp(position), value: prev_record.value };
654
655        self.registers[register as usize] = new_record;
656
657        MemoryReadRecord {
658            value: new_record.value,
659            timestamp: new_record.timestamp,
660            prev_timestamp: prev_record.timestamp,
661            prev_page_prot_record: None,
662        }
663    }
664
665    /// Read a value from a register, updating the register entry and returning the record.
666    #[inline]
667    fn rr_precompile(&mut self, register: usize) -> MemoryReadRecord {
668        debug_assert!(register < 32, "out of bounds register: {register}");
669
670        let prev_record = self.registers[register];
671        let new_record = MemoryRecord { timestamp: self.clk(), value: prev_record.value };
672
673        self.registers[register] = new_record;
674
675        MemoryReadRecord {
676            value: new_record.value,
677            timestamp: new_record.timestamp,
678            prev_timestamp: prev_record.timestamp,
679            prev_page_prot_record: None,
680        }
681    }
682
683    /// Touch all the registers in the VM, bumping thier clock to `self.clk - 1`.
684    pub fn register_refresh(&mut self) -> [MemoryReadRecord; 32] {
685        fn bump_register(vm: &mut CoreVM, register: usize) -> MemoryReadRecord {
686            let prev_record = vm.registers[register];
687            let new_record = MemoryRecord { timestamp: vm.clk - 1, value: prev_record.value };
688
689            vm.registers[register] = new_record;
690
691            MemoryReadRecord {
692                value: new_record.value,
693                timestamp: new_record.timestamp,
694                prev_timestamp: prev_record.timestamp,
695                prev_page_prot_record: None,
696            }
697        }
698
699        tracing::trace!("register refresh to: {}", self.clk - 1);
700
701        let mut out = [MaybeUninit::uninit(); 32];
702        for (i, record) in out.iter_mut().enumerate() {
703            *record = MaybeUninit::new(bump_register(self, i));
704        }
705
706        // SAFETY: We're transmuting a [MaybeUninit<MemoryReadRecord>; 32] to a [MemoryReadRecord;
707        // 32], which we just initialized.
708        //
709        // These types are guaranteed to have the same representation.
710        unsafe { std::mem::transmute(out) }
711    }
712
713    /// Write a value to a register, updating the register entry and returning the record.
714    #[inline]
715    fn rw(&mut self, register: Register, value: u64) -> MemoryWriteRecord {
716        let value = if register == Register::X0 { 0 } else { value };
717
718        let prev_record = self.registers[register as usize];
719        let new_record = MemoryRecord { timestamp: self.timestamp(MemoryAccessPosition::A), value };
720
721        self.registers[register as usize] = new_record;
722
723        // if SHAPE_CHECKING {
724        //     self.shape_checker.handle_mem_event(register as u64, prev_record.timestamp);
725        // }
726
727        // if REPORT_GENERATING {
728        //     self.gas_calculator.handle_mem_event(register as u64, prev_record.timestamp);
729        // }
730
731        MemoryWriteRecord {
732            value: new_record.value,
733            timestamp: new_record.timestamp,
734            prev_timestamp: prev_record.timestamp,
735            prev_value: prev_record.value,
736            prev_page_prot_record: None,
737        }
738    }
739}
740
impl CoreVM<'_> {
    /// Get the current timestamp for a given memory access position.
    ///
    /// The position is a small offset added to `clk` so the A/B/C/memory
    /// accesses within one cycle get distinct timestamps.
    #[inline]
    #[must_use]
    pub const fn timestamp(&self, position: MemoryAccessPosition) -> u64 {
        self.clk + position as u64
    }

    /// Check if the clock's high bits (bit 24 and above) differ between `clk` and
    /// `next_clk`, which implies a `state bump` event needs to be emitted.
    #[inline]
    #[must_use]
    pub const fn needs_bump_clk_high(&self) -> bool {
        (self.next_clk() >> 24) ^ (self.clk() >> 24) > 0
    }

    /// Check if the state needs to be bumped, which implies a `state bump` event needs to be
    /// emitted.
    #[inline]
    #[must_use]
    pub const fn needs_state_bump(&self, instruction: &Instruction) -> bool {
        let next_pc = self.next_pc();
        // NOTE(review): the `+ 8` appears to reserve extra clock headroom on
        // top of this cycle's increment — confirm against the state-bump chip.
        let increment = self.next_clk() + 8 - self.clk();

        // Bump when the low 24 bits of the clock would overflow this cycle...
        let bump1 = self.clk() % (1 << 24) + increment >= (1 << 24);
        // ...or when a fall-through pc crosses a 16-bit boundary on an
        // instruction that does not constrain its own next pc.
        let bump2 = !instruction.is_with_correct_next_pc()
            && next_pc == self.pc().wrapping_add(4)
            && (next_pc >> 16) != (self.pc() >> 16);

        bump1 || bump2
    }
}
772
impl<'a> CoreVM<'a> {
    #[inline]
    #[must_use]
    /// Get the current clock, this clock is incremented by [`CLK_INC`] each cycle.
    pub const fn clk(&self) -> u64 {
        self.clk
    }

    #[inline]
    /// Set the current clock.
    pub fn set_clk(&mut self, new_clk: u64) {
        self.clk = new_clk;
    }

    #[inline]
    /// Set the next clock.
    pub fn set_next_clk(&mut self, clk: u64) {
        self.next_clk = clk;
    }

    #[inline]
    #[must_use]
    /// Get the global clock, this clock is incremented by 1 each cycle.
    pub fn global_clk(&self) -> u64 {
        self.global_clk
    }

    #[inline]
    #[must_use]
    /// Get the current program counter.
    pub const fn pc(&self) -> u64 {
        self.pc
    }

    #[inline]
    #[must_use]
    /// Get the next program counter that will be set in [`CoreVM::advance`].
    pub const fn next_pc(&self) -> u64 {
        self.next_pc
    }

    #[inline]
    /// Set the next program counter.
    pub fn set_next_pc(&mut self, pc: u64) {
        self.next_pc = pc;
    }

    #[inline]
    #[must_use]
    /// Get the exit code.
    pub fn exit_code(&self) -> u32 {
        self.exit_code
    }

    #[inline]
    /// Set the exit code.
    pub fn set_exit_code(&mut self, exit_code: u32) {
        self.exit_code = exit_code;
    }

    #[inline]
    /// Set the program counter.
    pub fn set_pc(&mut self, pc: u64) {
        self.pc = pc;
    }

    #[inline]
    /// Set the global clock.
    pub fn set_global_clk(&mut self, global_clk: u64) {
        self.global_clk = global_clk;
    }

    #[inline]
    #[must_use]
    /// Get the next clock that will be set in [`CoreVM::advance`].
    pub const fn next_clk(&self) -> u64 {
        self.next_clk
    }

    #[inline]
    #[must_use]
    /// Get the current registers (immutable).
    pub fn registers(&self) -> &[MemoryRecord; 32] {
        &self.registers
    }

    #[inline]
    #[must_use]
    /// Get the current registers (mutable).
    pub fn registers_mut(&mut self) -> &mut [MemoryRecord; 32] {
        &mut self.registers
    }

    #[inline]
    /// Get the memory reads iterator.
    pub fn mem_reads(&mut self) -> &mut MemReads<'a> {
        &mut self.mem_reads
    }

    /// Check if the syscall is retained.
    #[inline]
    #[must_use]
    pub fn is_retained_syscall(&self, syscall_code: SyscallCode) -> bool {
        self.retained_syscall_codes.contains(&syscall_code)
    }

    /// Check if the trace has ended.
    ///
    /// Note this is an exact equality check against `clk_end`, so it only
    /// fires on the cycle where the clocks line up precisely.
    #[inline]
    #[must_use]
    pub const fn is_trace_end(&self) -> bool {
        self.clk_end == self.clk()
    }

    /// Check if the program has halted.
    #[inline]
    #[must_use]
    pub const fn is_done(&self) -> bool {
        self.pc() == HALT_PC
    }
}
893
/// Sign-extend the low `bits` bits of `value` to a full 64-bit signed integer.
fn sign_extend_imm(value: u64, bits: u8) -> i64 {
    // Shift the immediate's sign bit up to bit 63, then arithmetic-shift back
    // down so the sign is replicated through the upper bits.
    let unused = 64 - bits;
    ((value << unused) as i64) >> unused
}