Skip to main content

sp1_core_executor/
vm.rs

1#![allow(unknown_lints)]
2#![allow(clippy::manual_checked_ops)]
3
4use crate::{
5    disassembler::InstructionTranspiler,
6    events::{MemoryAccessPosition, MemoryReadRecord, MemoryRecord, MemoryWriteRecord},
7    vm::{
8        results::{
9            AluResult, BranchResult, CycleResult, EcallResult, FetchResult, JumpResult, LoadResult,
10            MaybeImmediate, StoreResult, TrapResult, UTypeResult,
11        },
12        syscall::{sp1_ecall_handler, SyscallRuntime},
13    },
14    ExecutionError, ExecutionMode, Instruction, Opcode, Program, Register, RetainedEventsPreset,
15    SP1CoreOpts, SupervisorMode, SyscallCode, TrapError, UserMode, CLK_INC as CLK_INC_32, HALT_PC,
16    PC_INC as PC_INC_32,
17};
18use hashbrown::HashMap;
19use results::{LoadResultSupervisor, StoreResultSupervisor};
20use rrs_lib::process_instruction;
21use std::{marker::PhantomData, mem::MaybeUninit, num::Wrapping, ptr::addr_of_mut, sync::Arc};
22
23use sp1_hypercube::air::{PROOF_NONCE_NUM_WORDS, PV_DIGEST_NUM_WORDS};
24use sp1_jit::{MemReads, MinimalTrace};
25use sp1_primitives::consts::{LOG_PAGE_SIZE, PROT_EXEC, PROT_READ, PROT_WRITE};
26
27pub(crate) mod gas;
28pub(crate) mod memory;
29pub(crate) mod results;
30pub(crate) mod shapes;
31pub(crate) mod syscall;
32
33const CLK_INC: u64 = CLK_INC_32 as u64;
34const PC_INC: u64 = PC_INC_32 as u64;
35
36/// A RISC-V VM that uses a [`MinimalTrace`] to oracle memory & page permission access.
37///
38/// The type parameter `M` determines whether page protection checks are enabled.
39pub struct CoreVM<'a, M: ExecutionMode> {
40    registers: [MemoryRecord; 32],
41    /// The current clock of the VM.
42    clk: u64,
43    /// The global clock of the VM.
44    global_clk: u64,
45    /// The current program counter of the VM.
46    pc: u64,
47    /// The current exit code of the VM.
48    exit_code: u32,
49    /// The memory reads cursor.
50    pub mem_reads: MemReads<'a>,
51    /// The next program counter that will be set in [`CoreVM::advance`].
52    next_pc: u64,
53    /// The next clock that will be set in [`CoreVM::advance`].
54    next_clk: u64,
55    /// The program that is being executed.
56    pub program: Arc<Program>,
57    /// The syscalls that are not marked as external, ie. they stay in the same shard.
58    pub(crate) retained_syscall_codes: Vec<SyscallCode>,
59    /// The options to configure the VM, mostly for syscall / shard handling.
60    pub opts: SP1CoreOpts,
61    /// The end clk of the trace chunk.
62    pub clk_end: u64,
63    /// The public value digest.
64    pub public_value_digest: [u32; PV_DIGEST_NUM_WORDS],
65    /// The nonce associated with the proof.
66    pub proof_nonce: [u32; PROOF_NONCE_NUM_WORDS],
67    /// The transpiler of the program.
68    transpiler: InstructionTranspiler,
69    /// Decoded instruction cache.
70    decoded_instruction_cache: HashMap<u32, Instruction>,
71    /// Phantom data for the execution mode.
72    _mode: PhantomData<M>,
73}
74
75impl<'a, M: ExecutionMode> CoreVM<'a, M> {
76    /// Create a [`CoreVM`] from a [`MinimalTrace`] and a [`Program`].
77    pub fn new<T: MinimalTrace>(
78        trace: &'a T,
79        program: Arc<Program>,
80        opts: SP1CoreOpts,
81        proof_nonce: [u32; PROOF_NONCE_NUM_WORDS],
82    ) -> Self {
83        let start_clk = trace.clk_start();
84
85        // SAFETY: We're mapping a [T; 32] -> [T; 32] infallibly.
86        let registers = unsafe {
87            trace
88                .start_registers()
89                .into_iter()
90                .map(|v| MemoryRecord { timestamp: start_clk - 1, value: v })
91                .collect::<Vec<_>>()
92                .try_into()
93                .unwrap_unchecked()
94        };
95        let start_pc = trace.pc_start();
96
97        let retained_syscall_codes = opts
98            .retained_events_presets
99            .iter()
100            .flat_map(RetainedEventsPreset::syscall_codes)
101            .copied()
102            .collect();
103
104        tracing::trace!("start_clk: {}", start_clk);
105        tracing::trace!("start_pc: {}", start_pc);
106        tracing::trace!("trace.clk_end(): {}", trace.clk_end());
107        tracing::trace!("trace.num_mem_reads(): {}", trace.num_mem_reads());
108        tracing::trace!("trace.start_registers(): {:?}", trace.start_registers());
109
110        if trace.clk_start() == 1 {
111            assert_eq!(trace.pc_start(), program.pc_start_abs);
112        }
113
114        Self {
115            registers,
116            global_clk: 0,
117            clk: start_clk,
118            pc: start_pc,
119            program,
120            mem_reads: trace.mem_reads(),
121            next_pc: start_pc.wrapping_add(PC_INC),
122            next_clk: start_clk.wrapping_add(CLK_INC),
123            exit_code: 0,
124            retained_syscall_codes,
125            opts,
126            clk_end: trace.clk_end(),
127            public_value_digest: [0; PV_DIGEST_NUM_WORDS],
128            proof_nonce,
129            transpiler: InstructionTranspiler,
130            decoded_instruction_cache: HashMap::new(),
131            _mode: PhantomData,
132        }
133    }
134
135    #[inline]
136    /// Certain execution errors could be handled internally. For example,
137    /// when trapping is enabled, page permission faults simply traps. This
138    /// method shall be called before `advance` to give the VM a chance to
139    /// handle some errors.
140    pub fn handle_error(&mut self, e: TrapError) -> Result<TrapResult, ExecutionError> {
141        #[allow(irrefutable_let_patterns)]
142        if let TrapError::PagePermissionViolation(code) = e {
143            if let Some(trap_context_address) = self.program.trap_context {
144                // As discussed in MinimalExecutor, page permissions are ignored
145                // when handling traps
146                let handler_record = self.mr_without_prot(trap_context_address);
147                self.next_pc = handler_record.value;
148                let code_record = self.mw_without_prot(trap_context_address + 8);
149                assert_eq!(code_record.value, code);
150                let pc_record = self.mw_without_prot(trap_context_address + 16);
151
152                return Ok(TrapResult {
153                    context: trap_context_address,
154                    code_record,
155                    pc_record,
156                    handler_record,
157                });
158            }
159        }
160        Err(ExecutionError::UnhandledTrap(e))
161    }
162
163    #[inline]
164    /// Increment the state of the VM by one cycle.
165    /// Calling this method will update the pc and the clk to the next cycle.
166    pub fn advance(&mut self) -> CycleResult {
167        self.clk = self.next_clk;
168        self.pc = self.next_pc;
169
170        // Reset the next_clk and next_pc to the next cycle.
171        self.next_clk = self.clk.wrapping_add(CLK_INC);
172        self.next_pc = self.pc.wrapping_add(PC_INC);
173        self.global_clk = self.global_clk.wrapping_add(1);
174
175        // Check if the program has halted.
176        if self.pc == HALT_PC {
177            return CycleResult::Done(true);
178        }
179
180        // Check if the shard limit has been reached.
181        if self.is_trace_end() {
182            return CycleResult::TraceEnd;
183        }
184
185        // Return that the program is still running.
186        CycleResult::Done(false)
187    }
188
189    /// Execute an ALU instruction.
190    #[inline]
191    #[allow(clippy::too_many_lines)]
192    pub fn execute_alu(&mut self, instruction: &Instruction) -> AluResult {
193        let mut result = MaybeUninit::<AluResult>::uninit();
194        let result_ptr = result.as_mut_ptr();
195
196        let (rd, b, c) = if !instruction.imm_c {
197            let (rd, rs1, rs2) = instruction.r_type();
198            let c = self.rr(rs2, MemoryAccessPosition::C);
199            let b = self.rr(rs1, MemoryAccessPosition::B);
200
201            // SAFETY: We're writing to a valid pointer as we just created the pointer from the
202            // `result`.
203            unsafe { addr_of_mut!((*result_ptr).rs1).write(MaybeImmediate::Register(rs1, b)) };
204            unsafe { addr_of_mut!((*result_ptr).rs2).write(MaybeImmediate::Register(rs2, c)) };
205
206            (rd, b.value, c.value)
207        } else if !instruction.imm_b && instruction.imm_c {
208            let (rd, rs1, imm) = instruction.i_type();
209            let (rd, b, c) = (rd, self.rr(rs1, MemoryAccessPosition::B), imm);
210
211            // SAFETY: We're writing to a valid pointer as we just created the pointer from the
212            // `result`.
213            unsafe { addr_of_mut!((*result_ptr).rs1).write(MaybeImmediate::Register(rs1, b)) };
214            unsafe { addr_of_mut!((*result_ptr).rs2).write(MaybeImmediate::Immediate(c)) };
215
216            (rd, b.value, c)
217        } else {
218            debug_assert!(instruction.imm_b && instruction.imm_c);
219            let (rd, b, c) =
220                (Register::from_u8(instruction.op_a), instruction.op_b, instruction.op_c);
221
222            // SAFETY: We're writing to a valid pointer as we just created the pointer from the
223            // `result`.
224            unsafe { addr_of_mut!((*result_ptr).rs1).write(MaybeImmediate::Immediate(b)) };
225            unsafe { addr_of_mut!((*result_ptr).rs2).write(MaybeImmediate::Immediate(c)) };
226
227            (rd, b, c)
228        };
229
230        let a = match instruction.opcode {
231            Opcode::ADD | Opcode::ADDI => (Wrapping(b) + Wrapping(c)).0,
232            Opcode::SUB => (Wrapping(b) - Wrapping(c)).0,
233            Opcode::XOR => b ^ c,
234            Opcode::OR => b | c,
235            Opcode::AND => b & c,
236            Opcode::SLL => b << (c & 0x3f),
237            Opcode::SRL => b >> (c & 0x3f),
238            Opcode::SRA => ((b as i64) >> (c & 0x3f)) as u64,
239            Opcode::SLT => {
240                if (b as i64) < (c as i64) {
241                    1
242                } else {
243                    0
244                }
245            }
246            Opcode::SLTU => {
247                if b < c {
248                    1
249                } else {
250                    0
251                }
252            }
253            Opcode::MUL => (Wrapping(b as i64) * Wrapping(c as i64)).0 as u64,
254            Opcode::MULH => (((b as i64) as i128).wrapping_mul((c as i64) as i128) >> 64) as u64,
255            Opcode::MULHU => ((b as u128 * c as u128) >> 64) as u64,
256            Opcode::MULHSU => ((((b as i64) as i128) * (c as i128)) >> 64) as u64,
257            Opcode::DIV => {
258                if c == 0 {
259                    u64::MAX
260                } else {
261                    (b as i64).wrapping_div(c as i64) as u64
262                }
263            }
264            Opcode::DIVU => {
265                if c == 0 {
266                    u64::MAX
267                } else {
268                    b / c
269                }
270            }
271            Opcode::REM => {
272                if c == 0 {
273                    b
274                } else {
275                    (b as i64).wrapping_rem(c as i64) as u64
276                }
277            }
278            Opcode::REMU => {
279                if c == 0 {
280                    b
281                } else {
282                    b % c
283                }
284            }
285            // RISCV-64 word operations
286            Opcode::ADDW => (Wrapping(b as i32) + Wrapping(c as i32)).0 as i64 as u64,
287            Opcode::SUBW => (Wrapping(b as i32) - Wrapping(c as i32)).0 as i64 as u64,
288            Opcode::MULW => (Wrapping(b as i32) * Wrapping(c as i32)).0 as i64 as u64,
289            Opcode::DIVW => {
290                if c as i32 == 0 {
291                    u64::MAX
292                } else {
293                    (b as i32).wrapping_div(c as i32) as i64 as u64
294                }
295            }
296            Opcode::DIVUW => {
297                if c as i32 == 0 {
298                    u64::MAX
299                } else {
300                    ((b as u32 / c as u32) as i32) as i64 as u64
301                }
302            }
303            Opcode::REMW => {
304                if c as i32 == 0 {
305                    (b as i32) as u64
306                } else {
307                    (b as i32).wrapping_rem(c as i32) as i64 as u64
308                }
309            }
310            Opcode::REMUW => {
311                if c as u32 == 0 {
312                    (b as i32) as u64
313                } else {
314                    (((b as u32) % (c as u32)) as i32) as i64 as u64
315                }
316            }
317            // RISCV-64 bit operations
318            Opcode::SLLW => (((b as i64) << (c & 0x1f)) as i32) as i64 as u64,
319            Opcode::SRLW => (((b as u32) >> ((c & 0x1f) as u32)) as i32) as u64,
320            Opcode::SRAW => {
321                (b as i32).wrapping_shr(((c as i64 & 0x1f) as i32) as u32) as i64 as u64
322            }
323            _ => unreachable!(),
324        };
325
326        let rw_record = self.rw(rd, a);
327
328        // SAFETY: We're writing to a valid pointer as we just created the pointer from the
329        // `result`.
330        unsafe { addr_of_mut!((*result_ptr).a).write(a) };
331        unsafe { addr_of_mut!((*result_ptr).b).write(b) };
332        unsafe { addr_of_mut!((*result_ptr).c).write(c) };
333        unsafe { addr_of_mut!((*result_ptr).rd).write(rd) };
334        unsafe { addr_of_mut!((*result_ptr).rw_record).write(rw_record) };
335
336        // SAFETY: All fields have been initialized by this point.
337        unsafe { result.assume_init() }
338    }
339
340    /// Execute a jump instruction.
341    pub fn execute_jump(&mut self, instruction: &Instruction) -> JumpResult {
342        match instruction.opcode {
343            Opcode::JAL => {
344                let (rd, imm) = instruction.j_type();
345                let imm_se = sign_extend_imm(imm, 21);
346                let a = self.pc.wrapping_add(4);
347                let rd_record = self.rw(rd, a);
348
349                let next_pc = ((self.pc as i64).wrapping_add(imm_se)) as u64;
350                let b = imm_se as u64;
351                let c = 0;
352
353                self.next_pc = next_pc;
354
355                JumpResult { a, b, c, rd, rd_record, rs1: MaybeImmediate::Immediate(b) }
356            }
357            Opcode::JALR => {
358                let (rd, rs1, c) = instruction.i_type();
359                let imm_se = sign_extend_imm(c, 12);
360                let b_record = self.rr(rs1, MemoryAccessPosition::B);
361                let a = self.pc.wrapping_add(4);
362
363                // Calculate next PC: (rs1 + imm) & ~1
364                let next_pc = ((b_record.value as i64).wrapping_add(imm_se) as u64) & !1_u64;
365                let rd_record = self.rw(rd, a);
366
367                self.next_pc = next_pc;
368
369                JumpResult {
370                    a,
371                    b: b_record.value,
372                    c,
373                    rd,
374                    rd_record,
375                    rs1: MaybeImmediate::Register(rs1, b_record),
376                }
377            }
378            _ => unreachable!("Invalid opcode for `execute_jump`: {:?}", instruction.opcode),
379        }
380    }
381
382    /// Execute a branch instruction.
383    pub fn execute_branch(&mut self, instruction: &Instruction) -> BranchResult {
384        let (rs1, rs2, imm) = instruction.b_type();
385
386        let c = imm;
387        let b_record = self.rr(rs2, MemoryAccessPosition::B);
388        let a_record = self.rr(rs1, MemoryAccessPosition::A);
389
390        let a = a_record.value;
391        let b = b_record.value;
392
393        let branch = match instruction.opcode {
394            Opcode::BEQ => a == b,
395            Opcode::BNE => a != b,
396            Opcode::BLT => (a as i64) < (b as i64),
397            Opcode::BGE => (a as i64) >= (b as i64),
398            Opcode::BLTU => a < b,
399            Opcode::BGEU => a >= b,
400            _ => {
401                unreachable!()
402            }
403        };
404
405        if branch {
406            self.next_pc = self.pc.wrapping_add(c);
407        }
408
409        BranchResult { a, rs1, a_record, b, rs2, b_record, c }
410    }
411
412    /// Execute a U-type instruction.
413    #[inline]
414    pub fn execute_utype(&mut self, instruction: &Instruction) -> UTypeResult {
415        let (rd, imm) = instruction.u_type();
416        let (b, c) = (imm, imm);
417        let a = if instruction.opcode == Opcode::AUIPC { self.pc.wrapping_add(imm) } else { imm };
418        let a_record = self.rw(rd, a);
419
420        UTypeResult { a, b, c, rd, rw_record: a_record }
421    }
422
423    #[inline]
424    /// Execute an ecall instruction.
425    ///
426    /// # WARNING:
427    ///
428    /// Its up to the syscall handler to update the shape checker abouut sent/internal ecalls.
429    pub fn execute_ecall<RT>(
430        rt: &mut RT,
431        instruction: &Instruction,
432        _code: SyscallCode,
433    ) -> Result<EcallResult, ExecutionError>
434    where
435        RT: SyscallRuntime<'a, M>,
436    {
437        if !instruction.is_ecall_instruction() {
438            unreachable!("Invalid opcode for `execute_ecall`: {:?}", instruction.opcode);
439        }
440
441        let core = rt.core_mut();
442
443        // We peek at register x5 to get the syscall id. The reason we don't `self.rr` this
444        // register is that we write to it later.
445        let t0 = Register::X5;
446        // Peek at the register, we dont care about the read here.
447        let syscall_id = core.registers[t0 as usize].value;
448        let code = SyscallCode::from_u32(syscall_id as u32);
449
450        let c_record = core.rr(Register::X11, MemoryAccessPosition::C);
451        let b_record = core.rr(Register::X10, MemoryAccessPosition::B);
452        let c = c_record.value;
453        let b = b_record.value;
454
455        let is_sigreturn = code == SyscallCode::SIG_RETURN;
456
457        let mut a_record: MemoryWriteRecord = MemoryWriteRecord::default();
458        if is_sigreturn {
459            a_record = core.rw(Register::X5, syscall_id);
460        }
461
462        let sig_return_pc_record = if is_sigreturn {
463            let record = core.mr_without_prot(b);
464            core.set_next_pc(record.value);
465            Some(record)
466        } else {
467            None
468        };
469
470        // The only way unconstrained mode interacts with the parts of the program that proven is
471        // via hints, this means during tracing and splicing, we can just "skip" the whole
472        // set of unconstrained cycles, and rely on the fact that the hints are already
473        // apart of the minimal trace.
474        let (a, error) = if code == SyscallCode::ENTER_UNCONSTRAINED {
475            (0, None)
476        } else {
477            let result = sp1_ecall_handler(rt, code, b, c);
478            if is_sigreturn {
479                (code as u64, None)
480            } else if let Ok(ret) = result {
481                (ret.unwrap_or(code as u64), None)
482            } else {
483                (code as u64, result.err())
484            }
485        };
486
487        // Bad borrow checker!
488        let core = rt.core_mut();
489
490        if !is_sigreturn {
491            a_record = core.rw(Register::X5, a);
492        }
493
494        // Add 256 to the next clock to account for the ecall.
495        core.set_next_clk(core.next_clk() + 256);
496
497        Ok(EcallResult { a, a_record, b, b_record, c, c_record, error, sig_return_pc_record })
498    }
499
500    /// Peek to get the code from the x5 register.
501    #[must_use]
502    pub fn read_code(&self) -> SyscallCode {
503        // We peek at register x5 to get the syscall id. The reason we don't `self.rr` this
504        // register is that we write to it later.
505        let t0 = Register::X5;
506
507        // Peek at the register, we dont care about the read here.
508        let syscall_id = self.registers[t0 as usize].value;
509
510        // Convert the raw value to a SyscallCode.
511        SyscallCode::from_u32(syscall_id as u32)
512    }
513
514    /// Compute the value to load based on opcode, address, and memory word.
515    #[allow(clippy::inline_always)]
516    #[inline(always)]
517    fn compute_load_value(opcode: Opcode, addr: u64, word: u64) -> Result<u64, ExecutionError> {
518        match opcode {
519            Opcode::LB => Ok(((word >> ((addr % 8) * 8)) & 0xFF) as i8 as i64 as u64),
520            Opcode::LH => {
521                if !addr.is_multiple_of(2) {
522                    return Err(ExecutionError::InvalidMemoryAccess(Opcode::LH, addr));
523                }
524                Ok(((word >> (((addr / 2) % 4) * 16)) & 0xFFFF) as i16 as i64 as u64)
525            }
526            Opcode::LW => {
527                if !addr.is_multiple_of(4) {
528                    return Err(ExecutionError::InvalidMemoryAccess(Opcode::LW, addr));
529                }
530                Ok(((word >> (((addr / 4) % 2) * 32)) & 0xFFFFFFFF) as i32 as u64)
531            }
532            Opcode::LBU => Ok(((word >> ((addr % 8) * 8)) & 0xFF) as u8 as u64),
533            Opcode::LHU => {
534                if !addr.is_multiple_of(2) {
535                    return Err(ExecutionError::InvalidMemoryAccess(Opcode::LHU, addr));
536                }
537                Ok(((word >> (((addr / 2) % 4) * 16)) & 0xFFFF) as u16 as u64)
538            }
539            Opcode::LWU => {
540                if !addr.is_multiple_of(4) {
541                    return Err(ExecutionError::InvalidMemoryAccess(Opcode::LWU, addr));
542                }
543                Ok((word >> (((addr / 4) % 2) * 32)) & 0xFFFFFFFF)
544            }
545            Opcode::LD => {
546                if !addr.is_multiple_of(8) {
547                    return Err(ExecutionError::InvalidMemoryAccess(Opcode::LD, addr));
548                }
549                Ok(word)
550            }
551            _ => unreachable!("Invalid opcode for `compute_load_value`: {:?}", opcode),
552        }
553    }
554
555    /// Compute the value to store based on opcode, source value, address, and current memory word.
556    #[allow(clippy::inline_always)]
557    #[inline(always)]
558    fn compute_store_value(
559        opcode: Opcode,
560        src: u64,
561        addr: u64,
562        word: u64,
563    ) -> Result<u64, ExecutionError> {
564        match opcode {
565            Opcode::SB => {
566                let shift = (addr % 8) * 8;
567                Ok(((src & 0xFF) << shift) | (word & !(0xFF << shift)))
568            }
569            Opcode::SH => {
570                if !addr.is_multiple_of(2) {
571                    return Err(ExecutionError::InvalidMemoryAccess(Opcode::SH, addr));
572                }
573                let shift = ((addr / 2) % 4) * 16;
574                Ok(((src & 0xFFFF) << shift) | (word & !(0xFFFF << shift)))
575            }
576            Opcode::SW => {
577                if !addr.is_multiple_of(4) {
578                    return Err(ExecutionError::InvalidMemoryAccess(Opcode::SW, addr));
579                }
580                let shift = ((addr / 4) % 2) * 32;
581                Ok(((src & 0xFFFFFFFF) << shift) | (word & !(0xFFFFFFFF << shift)))
582            }
583            Opcode::SD => {
584                if !addr.is_multiple_of(8) {
585                    return Err(ExecutionError::InvalidMemoryAccess(Opcode::SD, addr));
586                }
587                Ok(src)
588            }
589            _ => unreachable!("Invalid opcode for `compute_store_value`: {:?}", opcode),
590        }
591    }
592}
593
594impl CoreVM<'_, SupervisorMode> {
595    /// Fetch the next instruction from the program.
596    #[inline]
597    pub fn fetch(&mut self) -> Instruction {
598        *self.program.fetch(self.pc).unwrap()
599    }
600
601    #[allow(clippy::inline_always)]
602    #[inline(always)]
603    fn mr(&mut self, addr: u64) -> MemoryReadRecord {
604        #[allow(clippy::manual_let_else)]
605        let record = match self.mem_reads.next() {
606            Some(next) => next,
607            None => {
608                unreachable!("memory reads unexpectedly exhausted at {addr}, clk {}", self.clk());
609            }
610        };
611
612        MemoryReadRecord {
613            value: record.value,
614            timestamp: self.timestamp(MemoryAccessPosition::Memory),
615            prev_timestamp: record.clk,
616            prev_page_prot_record: None,
617        }
618    }
619
620    #[allow(clippy::inline_always)]
621    #[inline(always)]
622    fn mw(&mut self, read_record: MemoryReadRecord, value: u64) -> MemoryWriteRecord {
623        MemoryWriteRecord {
624            prev_timestamp: read_record.prev_timestamp,
625            prev_value: read_record.value,
626            timestamp: self.timestamp(MemoryAccessPosition::Memory),
627            value,
628            prev_page_prot_record: None,
629        }
630    }
631
632    /// Execute a load instruction.
633    #[inline]
634    pub fn execute_load(
635        &mut self,
636        instruction: &Instruction,
637    ) -> Result<LoadResultSupervisor, ExecutionError> {
638        let (rd, rs1, imm) = instruction.i_type();
639
640        let rr_record = self.rr(rs1, MemoryAccessPosition::B);
641        let b = rr_record.value;
642
643        // Compute the address.
644        let addr = b.wrapping_add(imm);
645        let mr_record = self.mr(addr);
646        let word = mr_record.value;
647
648        let a = Self::compute_load_value(instruction.opcode, addr, word)?;
649        let rw_record = self.rw(rd, a);
650
651        Ok(LoadResultSupervisor { a, b, c: imm, addr, rs1, rd, rr_record, rw_record, mr_record })
652    }
653
654    /// Execute a store instruction.
655    #[inline]
656    pub fn execute_store(
657        &mut self,
658        instruction: &Instruction,
659    ) -> Result<StoreResultSupervisor, ExecutionError> {
660        let (rs1, rs2, imm) = instruction.s_type();
661
662        let c = imm;
663        let rs2_record = self.rr(rs2, MemoryAccessPosition::B);
664        let rs1_record = self.rr(rs1, MemoryAccessPosition::A);
665
666        let b = rs2_record.value;
667        let a = rs1_record.value;
668        let addr = b.wrapping_add(c);
669        let mr_record = self.mr(addr);
670        let word = mr_record.value;
671
672        let memory_store_value = Self::compute_store_value(instruction.opcode, a, addr, word)?;
673        let mw_record = self.mw(mr_record, memory_store_value);
674
675        Ok(StoreResultSupervisor { a, b, c, addr, rs1, rs1_record, rs2, rs2_record, mw_record })
676    }
677}
678
679impl CoreVM<'_, UserMode> {
680    /// Fetch the next instruction from the program.
681    #[inline]
682    pub fn fetch(&mut self) -> Result<FetchResult, ExecutionError> {
683        if let Some(instruction) = self.program.fetch(self.pc) {
684            Ok(FetchResult {
685                pc: self.pc,
686                instruction: Some(*instruction),
687                mr_record: None,
688                error: None,
689            })
690        } else {
691            let aligned_pc = self.pc & !0b111;
692            let (mr_record, error) = self.mr_instr(
693                aligned_pc,
694                PROT_READ | PROT_EXEC,
695                MemoryAccessPosition::UntrustedInstruction,
696            );
697            if error.is_some() {
698                return Ok(FetchResult {
699                    pc: self.pc,
700                    instruction: None,
701                    mr_record: Some(mr_record),
702                    error,
703                });
704            }
705            let word = mr_record.value;
706            if !self.pc.is_multiple_of(4) {
707                return Err(ExecutionError::InvalidMemoryAccessUntrustedProgram(self.pc));
708            }
709            let aligned_offset = self.pc - aligned_pc;
710            let instruction_value: u32 =
711                (word >> (aligned_offset * 8) & 0xffffffff).try_into().unwrap();
712            let instruction = if let Some(cached_instruction) =
713                self.decoded_instruction_cache.get(&instruction_value)
714            {
715                *cached_instruction
716            } else {
717                let instruction =
718                    process_instruction(&mut self.transpiler, instruction_value).unwrap();
719                self.decoded_instruction_cache.insert(instruction_value, instruction);
720                instruction
721            };
722            Ok(FetchResult {
723                pc: self.pc,
724                instruction: Some(instruction),
725                mr_record: Some(mr_record),
726                error: None,
727            })
728        }
729    }
730
731    /// Execute a load instruction.
732    #[inline]
733    pub fn execute_load(
734        &mut self,
735        instruction: &Instruction,
736    ) -> Result<LoadResult, ExecutionError> {
737        let (rd, rs1, imm) = instruction.i_type();
738
739        let rr_record = self.rr(rs1, MemoryAccessPosition::B);
740        let b = rr_record.value;
741
742        // Compute the address.
743        let addr = b.wrapping_add(imm);
744        let (mr_record, error) = self.mr_instr(addr, PROT_READ, MemoryAccessPosition::Memory);
745        let word = mr_record.value;
746
747        let mut a = Self::compute_load_value(instruction.opcode, addr, word)?;
748
749        // If there is a trap, then the write to `op_a` is a no-op, so write the original value.
750        if error.is_some() {
751            a = self.registers[rd as usize].value;
752        }
753
754        let rw_record = self.rw(rd, a);
755
756        Ok(LoadResult { a, b, c: imm, addr, rs1, rd, rr_record, rw_record, mr_record, error })
757    }
758
759    /// Execute a store instruction.
760    #[inline]
761    pub fn execute_store(
762        &mut self,
763        instruction: &Instruction,
764    ) -> Result<StoreResult, ExecutionError> {
765        let (rs1, rs2, imm) = instruction.s_type();
766
767        let c = imm;
768        let rs2_record = self.rr(rs2, MemoryAccessPosition::B);
769        let rs1_record = self.rr(rs1, MemoryAccessPosition::A);
770
771        let b = rs2_record.value;
772        let a = rs1_record.value;
773        let addr = b.wrapping_add(c);
774        let (mut mw_record, error) = self.mw_instr(addr, MemoryAccessPosition::Memory);
775        let word = mw_record.prev_value;
776
777        let memory_store_value = Self::compute_store_value(instruction.opcode, a, addr, word)?;
778        mw_record.value = memory_store_value;
779
780        Ok(StoreResult { a, b, c, addr, rs1, rs1_record, rs2, rs2_record, mw_record, error })
781    }
782
783    #[inline]
784    fn mr_instr(
785        &mut self,
786        addr: u64,
787        page_prot_bitmap: u8,
788        position: MemoryAccessPosition,
789    ) -> (MemoryReadRecord, Option<TrapError>) {
790        let (prev_page_prot_record, error) =
791            self.page_prot_check(addr >> LOG_PAGE_SIZE, page_prot_bitmap);
792
793        if error.is_some() {
794            return (
795                MemoryReadRecord {
796                    value: 0,
797                    timestamp: self.timestamp(position),
798                    prev_timestamp: 0,
799                    prev_page_prot_record,
800                },
801                error,
802            );
803        }
804
805        #[allow(clippy::manual_let_else)]
806        let record = match self.mem_reads.next() {
807            Some(next) => next,
808            None => {
809                unreachable!("memory reads unexpectdely exhausted at {addr}, clk {}", self.clk());
810            }
811        };
812
813        (
814            MemoryReadRecord {
815                value: record.value,
816                timestamp: self.timestamp(position),
817                prev_timestamp: record.clk,
818                prev_page_prot_record,
819            },
820            None,
821        )
822    }
823
824    #[inline]
825    fn mw_instr(
826        &mut self,
827        addr: u64,
828        position: MemoryAccessPosition,
829    ) -> (MemoryWriteRecord, Option<TrapError>) {
830        let (prev_page_prot_record, error) =
831            self.page_prot_check(addr >> LOG_PAGE_SIZE, PROT_WRITE);
832
833        if error.is_some() {
834            return (
835                MemoryWriteRecord {
836                    prev_timestamp: 0,
837                    prev_value: 0,
838                    timestamp: self.timestamp(position),
839                    value: 0,
840                    prev_page_prot_record,
841                },
842                error,
843            );
844        }
845
846        let mem_writes = self.core_mut().mem_reads();
847        let old = mem_writes.next().expect("Precompile memory read out of bounds");
848
849        (
850            MemoryWriteRecord {
851                prev_timestamp: old.clk,
852                prev_value: old.value,
853                timestamp: self.timestamp(position),
854                // This will be updated in execute_store
855                value: 0,
856                prev_page_prot_record,
857            },
858            None,
859        )
860    }
861}
862
863impl<M: ExecutionMode> CoreVM<'_, M> {
864    /// Read a value from a register, updating the register entry and returning the record.
865    #[inline]
866    fn rr(&mut self, register: Register, position: MemoryAccessPosition) -> MemoryReadRecord {
867        let prev_record = self.registers[register as usize];
868        let new_record =
869            MemoryRecord { timestamp: self.timestamp(position), value: prev_record.value };
870
871        self.registers[register as usize] = new_record;
872
873        MemoryReadRecord {
874            value: new_record.value,
875            timestamp: new_record.timestamp,
876            prev_timestamp: prev_record.timestamp,
877            prev_page_prot_record: None,
878        }
879    }
880
881    /// Read a value from a register, updating the register entry and returning the record.
882    #[inline]
883    #[must_use]
884    pub fn rr_peek(&self, register: Register, position: MemoryAccessPosition) -> MemoryReadRecord {
885        let prev_record = self.registers[register as usize];
886        let new_record =
887            MemoryRecord { timestamp: self.timestamp(position), value: prev_record.value };
888
889        MemoryReadRecord {
890            value: new_record.value,
891            timestamp: new_record.timestamp,
892            prev_timestamp: prev_record.timestamp,
893            prev_page_prot_record: None,
894        }
895    }
896
897    /// Touch all the registers in the VM, bumping their clock to `self.clk - 1`.
898    pub fn register_refresh(&mut self) -> [MemoryReadRecord; 32] {
899        #[inline]
900        fn bump_register<N: ExecutionMode>(
901            vm: &mut CoreVM<N>,
902            register: usize,
903        ) -> MemoryReadRecord {
904            let prev_record = vm.registers[register];
905            let new_record = MemoryRecord { timestamp: vm.clk - 1, value: prev_record.value };
906
907            vm.registers[register] = new_record;
908
909            MemoryReadRecord {
910                value: new_record.value,
911                timestamp: new_record.timestamp,
912                prev_timestamp: prev_record.timestamp,
913                prev_page_prot_record: None,
914            }
915        }
916
917        tracing::trace!("register refresh to: {}", self.clk - 1);
918
919        let mut out = [MaybeUninit::uninit(); 32];
920        for (i, record) in out.iter_mut().enumerate() {
921            *record = MaybeUninit::new(bump_register(self, i));
922        }
923
924        // SAFETY: We're transmuting a [MaybeUninit<MemoryReadRecord>; 32] to a [MemoryReadRecord;
925        // 32], which we just initialized.
926        //
927        // These types are guaranteed to have the same representation.
928        unsafe { std::mem::transmute(out) }
929    }
930
931    /// Write a value to a register, updating the register entry and returning the record.
932    #[inline]
933    fn rw(&mut self, register: Register, value: u64) -> MemoryWriteRecord {
934        let value = if register == Register::X0 { 0 } else { value };
935
936        let prev_record = self.registers[register as usize];
937        let new_record = MemoryRecord { timestamp: self.timestamp(MemoryAccessPosition::A), value };
938
939        self.registers[register as usize] = new_record;
940
941        MemoryWriteRecord {
942            value: new_record.value,
943            timestamp: new_record.timestamp,
944            prev_timestamp: prev_record.timestamp,
945            prev_value: prev_record.value,
946            prev_page_prot_record: None,
947        }
948    }
949}
950
951impl<M: ExecutionMode> CoreVM<'_, M> {
952    /// Get the current timestamp for a given memory access position.
953    #[inline]
954    #[must_use]
955    pub const fn timestamp(&self, position: MemoryAccessPosition) -> u64 {
956        self.clk + position as u64
957    }
958
959    /// Check if the top 24 bits have changed, which imply a `state bump` event needs to be emitted.
960    #[inline]
961    #[must_use]
962    pub const fn needs_bump_clk_high(&self) -> bool {
963        (self.next_clk() >> 24) ^ (self.clk() >> 24) > 0
964    }
965
966    /// Check if the state needs to be bumped, which implies a `state bump` event needs to be
967    /// emitted.
968    #[inline]
969    #[must_use]
970    pub const fn needs_state_bump(&self, instruction: &Instruction) -> bool {
971        let next_pc = self.next_pc();
972        let increment = self.next_clk() + 8 - self.clk();
973
974        let bump1 = self.clk() % (1 << 24) + increment >= (1 << 24);
975        let bump2 = !instruction.is_with_correct_next_pc()
976            && next_pc == self.pc().wrapping_add(4)
977            && (next_pc >> 16) != (self.pc() >> 16);
978
979        bump1 || bump2
980    }
981}
982
983impl<'a, M: ExecutionMode> CoreVM<'a, M> {
984    #[inline]
985    #[must_use]
986    /// Get the current clock, this clock is incremented by [`CLK_INC`] each cycle.
987    pub const fn clk(&self) -> u64 {
988        self.clk
989    }
990
991    #[inline]
992    /// Set the current clock.
993    pub fn set_clk(&mut self, new_clk: u64) {
994        self.clk = new_clk;
995    }
996
997    #[inline]
998    /// Set the next clock.
999    pub fn set_next_clk(&mut self, clk: u64) {
1000        self.next_clk = clk;
1001    }
1002
1003    #[inline]
1004    #[must_use]
1005    /// Get the global clock, this clock is incremented by 1 each cycle.
1006    pub fn global_clk(&self) -> u64 {
1007        self.global_clk
1008    }
1009
1010    #[inline]
1011    #[must_use]
1012    /// Get the current program counter.
1013    pub const fn pc(&self) -> u64 {
1014        self.pc
1015    }
1016
1017    #[inline]
1018    #[must_use]
1019    /// Get the next program counter that will be set in [`CoreVM::advance`].
1020    pub const fn next_pc(&self) -> u64 {
1021        self.next_pc
1022    }
1023
1024    #[inline]
1025    /// Set the next program counter.
1026    pub fn set_next_pc(&mut self, pc: u64) {
1027        self.next_pc = pc;
1028    }
1029
1030    #[inline]
1031    #[must_use]
1032    /// Get the exit code.
1033    pub fn exit_code(&self) -> u32 {
1034        self.exit_code
1035    }
1036
1037    #[inline]
1038    /// Set the exit code.
1039    pub fn set_exit_code(&mut self, exit_code: u32) {
1040        self.exit_code = exit_code;
1041    }
1042
1043    #[inline]
1044    /// Set the program counter.
1045    pub fn set_pc(&mut self, pc: u64) {
1046        self.pc = pc;
1047    }
1048
1049    #[inline]
1050    /// Set the global clock.
1051    pub fn set_global_clk(&mut self, global_clk: u64) {
1052        self.global_clk = global_clk;
1053    }
1054
1055    #[inline]
1056    #[must_use]
1057    /// Get the next clock that will be set in [`CoreVM::advance`].
1058    pub const fn next_clk(&self) -> u64 {
1059        self.next_clk
1060    }
1061
1062    #[inline]
1063    #[must_use]
1064    /// Get the current registers (immutable).
1065    pub fn registers(&self) -> &[MemoryRecord; 32] {
1066        &self.registers
1067    }
1068
1069    #[inline]
1070    #[must_use]
1071    /// Get the current registers (mutable).
1072    pub fn registers_mut(&mut self) -> &mut [MemoryRecord; 32] {
1073        &mut self.registers
1074    }
1075
1076    #[inline]
1077    /// Get the memory reads iterator.
1078    pub fn mem_reads(&mut self) -> &mut MemReads<'a> {
1079        &mut self.mem_reads
1080    }
1081
1082    /// Check if the syscall is retained.
1083    #[inline]
1084    #[must_use]
1085    pub fn is_retained_syscall(&self, syscall_code: SyscallCode) -> bool {
1086        self.retained_syscall_codes.contains(&syscall_code)
1087    }
1088
1089    /// Check if the trace has ended.
1090    #[inline]
1091    #[must_use]
1092    pub const fn is_trace_end(&self) -> bool {
1093        self.clk_end == self.clk()
1094    }
1095
1096    /// Check if the program has halted.
1097    #[inline]
1098    #[must_use]
1099    pub const fn is_done(&self) -> bool {
1100        self.pc() == HALT_PC
1101    }
1102}
1103
/// Sign-extend the low `bits` bits of `value` to a 64-bit signed integer.
///
/// Bit `bits - 1` of `value` is treated as the sign bit; any bits of `value` above it are
/// ignored.
///
/// Boundary widths are handled explicitly so the shift amount never reaches 64:
/// - `bits == 0` describes an empty field and yields `0`;
/// - `bits >= 64` covers the whole word, so `value` is reinterpreted as an `i64`.
fn sign_extend_imm(value: u64, bits: u8) -> i64 {
    match bits {
        0 => 0,
        1..=63 => {
            let shift = 64 - u32::from(bits);
            // Move the field's sign bit to bit 63, then shift back arithmetically so it is
            // replicated into the upper bits.
            ((value as i64) << shift) >> shift
        }
        _ => value as i64,
    }
}