Skip to main content

sp1_core_executor/
splicing.rs

1use std::{marker::PhantomData, sync::Arc};
2
3use serde::Serialize;
4use sp1_hypercube::air::PROOF_NONCE_NUM_WORDS;
5use sp1_jit::{MemReads, MemValue, MinimalTrace, TraceChunk};
6use sp1_primitives::consts::LOG_PAGE_SIZE;
7
8use crate::{
9    events::{MemoryReadRecord, MemoryRecord, MemoryWriteRecord, PageProtRecord},
10    vm::{
11        memory::{CompressedMemory, CompressedPages},
12        results::{
13            CycleResult, FetchResult, LoadResult, LoadResultSupervisor, StoreResult,
14            StoreResultSupervisor, TrapResult,
15        },
16        shapes::{ShapeChecker, HALT_AREA, HALT_HEIGHT},
17        syscall::SyscallRuntime,
18        CoreVM,
19    },
20    ExecutionError, ExecutionMode, Instruction, Opcode, Program, SP1CoreOpts, ShardingThreshold,
21    SupervisorMode, SyscallCode, TrapError, UserMode,
22};
23
/// A RISC-V VM that uses a [`MinimalTrace`] to create multiple [`SplicedMinimalTrace`]s.
///
/// These new [`SplicedMinimalTrace`]s correspond to exactly 1 execution shard to be proved.
///
/// Note that this is the only time we account for trace area throughout the execution pipeline.
///
/// The type parameter `M` determines whether page protection checks are enabled.
pub struct SplicingVM<'a, M: ExecutionMode> {
    /// The core VM.
    pub core: CoreVM<'a, M>,
    /// The shape checker, responsible for cutting the execution when a shard limit is reached.
    pub shape_checker: ShapeChecker<M>,
    /// The addresses that have been touched.
    ///
    /// Addresses are inserted with their low three bits cleared (`addr & !0b111`).
    pub touched_addresses: &'a mut CompressedMemory,
    /// The page indices that have been touched (for page protection tracking).
    pub touched_pages: &'a mut CompressedPages,
    /// Phantom data for the execution mode.
    _mode: PhantomData<M>,
}
43
44impl SplicingVM<'_, SupervisorMode> {
45    /// Execute the program until it halts.
46    pub fn execute(&mut self) -> Result<CycleResult, ExecutionError> {
47        if self.core.is_done() {
48            return Ok(CycleResult::Done(true));
49        }
50
51        loop {
52            let mut result = self.execute_instruction()?;
53
54            // If we're not already done, ensure that we don't have a shard boundary.
55            if !result.is_done() && self.shape_checker.check_shard_limit() {
56                result = CycleResult::ShardBoundary;
57            }
58
59            match result {
60                CycleResult::Done(false) => {}
61                CycleResult::ShardBoundary | CycleResult::TraceEnd => {
62                    self.start_new_shard();
63                    return Ok(CycleResult::ShardBoundary);
64                }
65                CycleResult::Done(true) => {
66                    return Ok(CycleResult::Done(true));
67                }
68            }
69        }
70    }
71
72    /// Execute the next instruction at the current PC.
73    pub fn execute_instruction(&mut self) -> Result<CycleResult, ExecutionError> {
74        let instruction = self.core.fetch();
75
76        match &instruction.opcode {
77            Opcode::ADD
78            | Opcode::ADDI
79            | Opcode::SUB
80            | Opcode::XOR
81            | Opcode::OR
82            | Opcode::AND
83            | Opcode::SLL
84            | Opcode::SLLW
85            | Opcode::SRL
86            | Opcode::SRA
87            | Opcode::SRLW
88            | Opcode::SRAW
89            | Opcode::SLT
90            | Opcode::SLTU
91            | Opcode::MUL
92            | Opcode::MULHU
93            | Opcode::MULHSU
94            | Opcode::MULH
95            | Opcode::MULW
96            | Opcode::DIVU
97            | Opcode::REMU
98            | Opcode::DIV
99            | Opcode::REM
100            | Opcode::DIVW
101            | Opcode::ADDW
102            | Opcode::SUBW
103            | Opcode::DIVUW
104            | Opcode::REMUW
105            | Opcode::REMW => {
106                self.execute_alu(&instruction);
107            }
108            Opcode::LB
109            | Opcode::LBU
110            | Opcode::LH
111            | Opcode::LHU
112            | Opcode::LW
113            | Opcode::LWU
114            | Opcode::LD => self.execute_load(&instruction)?,
115            Opcode::SB | Opcode::SH | Opcode::SW | Opcode::SD => {
116                self.execute_store(&instruction)?;
117            }
118            Opcode::JAL | Opcode::JALR => {
119                self.execute_jump(&instruction);
120            }
121            Opcode::BEQ | Opcode::BNE | Opcode::BLT | Opcode::BGE | Opcode::BLTU | Opcode::BGEU => {
122                self.execute_branch(&instruction);
123            }
124            Opcode::LUI | Opcode::AUIPC => {
125                self.execute_utype(&instruction);
126            }
127            Opcode::ECALL => self.execute_ecall(&instruction)?,
128            Opcode::EBREAK | Opcode::UNIMP => {
129                unreachable!("Invalid opcode for `execute_instruction`: {:?}", instruction.opcode)
130            }
131        }
132
133        self.shape_checker.handle_instruction(
134            &instruction,
135            self.core.needs_bump_clk_high(),
136            instruction.is_alu_instruction() && instruction.op_a == 0,
137            instruction.is_memory_load_instruction() && instruction.op_a == 0,
138            self.core.needs_state_bump(&instruction),
139        );
140
141        Ok(self.core.advance())
142    }
143}
144
145impl SplicingVM<'_, SupervisorMode> {
146    /// Execute a load instruction.
147    ///
148    /// This method will update the local memory access for the memory read, the register read,
149    /// and the register write.
150    ///
151    /// It will also emit the memory instruction event and the events for the load instruction.
152    #[inline]
153    pub fn execute_load(&mut self, instruction: &Instruction) -> Result<(), ExecutionError> {
154        let LoadResultSupervisor { addr, mr_record, .. } = self.core.execute_load(instruction)?;
155
156        self.touched_addresses.insert(addr & !0b111, true);
157        self.shape_checker.handle_mem_event(addr, mr_record.prev_timestamp);
158
159        Ok(())
160    }
161
162    /// Execute a store instruction.
163    ///
164    /// This method will update the local memory access for the memory read, the register read,
165    /// and the register write.
166    ///
167    /// It will also emit the memory instruction event and the events for the store instruction.
168    #[inline]
169    pub fn execute_store(&mut self, instruction: &Instruction) -> Result<(), ExecutionError> {
170        let StoreResultSupervisor { addr, mw_record, .. } = self.core.execute_store(instruction)?;
171
172        self.touched_addresses.insert(addr & !0b111, true);
173        self.shape_checker.handle_mem_event(addr, mw_record.prev_timestamp);
174
175        Ok(())
176    }
177}
178
179impl SplicingVM<'_, UserMode> {
180    /// Execute the program until it halts.
181    pub fn execute(&mut self) -> Result<CycleResult, ExecutionError> {
182        if self.core.is_done() {
183            return Ok(CycleResult::Done(true));
184        }
185
186        loop {
187            let mut result = self.execute_instruction()?;
188
189            // If we're not already done, ensure that we don't have a shard boundary.
190            if !result.is_done() && self.shape_checker.check_shard_limit() {
191                result = CycleResult::ShardBoundary;
192            }
193
194            match result {
195                CycleResult::Done(false) => {}
196                CycleResult::ShardBoundary | CycleResult::TraceEnd => {
197                    self.start_new_shard();
198                    return Ok(CycleResult::ShardBoundary);
199                }
200                CycleResult::Done(true) => {
201                    return Ok(CycleResult::Done(true));
202                }
203            }
204        }
205    }
206
207    /// Execute the next instruction at the current PC.
208    pub fn execute_instruction(&mut self) -> Result<CycleResult, ExecutionError> {
209        let FetchResult { instruction, mr_record, pc, error } = self.core.fetch()?;
210        let mut num_page_prot_accesses = 0;
211
212        if let Some(error) = error {
213            self.handle_error(error)?;
214            let page_idx = pc >> LOG_PAGE_SIZE;
215            self.shape_checker.handle_page_prot_event(
216                page_idx,
217                mr_record.unwrap().prev_page_prot_record.unwrap().timestamp,
218            );
219            self.touched_pages.insert(page_idx, true);
220            num_page_prot_accesses += 1;
221            self.shape_checker.handle_trap_exec_event();
222            self.shape_checker
223                .handle_trap_events(self.core().needs_bump_clk_high(), num_page_prot_accesses);
224            return Ok(self.core.advance());
225        }
226
227        if instruction.is_none() {
228            unreachable!("Fetching the next instruction failed");
229        }
230
231        if let Some(mr_record) = mr_record {
232            let instruction_value = (mr_record.value >> ((pc % 8) * 8)) as u32;
233            self.touched_addresses.insert(pc & !0b111, true);
234            self.shape_checker.handle_untrusted_instruction(instruction_value);
235            self.shape_checker.handle_mem_event(pc & !0b111, mr_record.prev_timestamp);
236            let page_idx = pc >> LOG_PAGE_SIZE;
237            self.shape_checker.handle_page_prot_event(
238                page_idx,
239                mr_record.prev_page_prot_record.unwrap().timestamp,
240            );
241            self.touched_pages.insert(page_idx, true);
242            num_page_prot_accesses += 1;
243        }
244
245        // SAFETY: The instruction is guaranteed to be valid as we checked for `is_none` above.
246        let instruction = unsafe { instruction.unwrap_unchecked() };
247
248        match &instruction.opcode {
249            Opcode::ADD
250            | Opcode::ADDI
251            | Opcode::SUB
252            | Opcode::XOR
253            | Opcode::OR
254            | Opcode::AND
255            | Opcode::SLL
256            | Opcode::SLLW
257            | Opcode::SRL
258            | Opcode::SRA
259            | Opcode::SRLW
260            | Opcode::SRAW
261            | Opcode::SLT
262            | Opcode::SLTU
263            | Opcode::MUL
264            | Opcode::MULHU
265            | Opcode::MULHSU
266            | Opcode::MULH
267            | Opcode::MULW
268            | Opcode::DIVU
269            | Opcode::REMU
270            | Opcode::DIV
271            | Opcode::REM
272            | Opcode::DIVW
273            | Opcode::ADDW
274            | Opcode::SUBW
275            | Opcode::DIVUW
276            | Opcode::REMUW
277            | Opcode::REMW => {
278                self.execute_alu(&instruction);
279            }
280            Opcode::LB
281            | Opcode::LBU
282            | Opcode::LH
283            | Opcode::LHU
284            | Opcode::LW
285            | Opcode::LWU
286            | Opcode::LD => self.execute_load(&instruction)?,
287            Opcode::SB | Opcode::SH | Opcode::SW | Opcode::SD => {
288                self.execute_store(&instruction)?;
289            }
290            Opcode::JAL | Opcode::JALR => {
291                self.execute_jump(&instruction);
292            }
293            Opcode::BEQ | Opcode::BNE | Opcode::BLT | Opcode::BGE | Opcode::BLTU | Opcode::BGEU => {
294                self.execute_branch(&instruction);
295            }
296            Opcode::LUI | Opcode::AUIPC => {
297                self.execute_utype(&instruction);
298            }
299            Opcode::ECALL => self.execute_ecall(&instruction)?,
300            Opcode::EBREAK | Opcode::UNIMP => {
301                unreachable!("Invalid opcode for `execute_instruction`: {:?}", instruction.opcode)
302            }
303        }
304
305        if instruction.is_memory_load_instruction() || instruction.is_memory_store_instruction() {
306            num_page_prot_accesses += 1;
307        }
308
309        self.shape_checker.handle_instruction(
310            &instruction,
311            self.core.needs_bump_clk_high(),
312            instruction.is_alu_instruction() && instruction.op_a == 0,
313            instruction.is_memory_load_instruction() && instruction.op_a == 0,
314            self.core.needs_state_bump(&instruction),
315            num_page_prot_accesses,
316        );
317
318        Ok(self.core.advance())
319    }
320}
321
322impl SplicingVM<'_, UserMode> {
323    /// Execute a load instruction.
324    ///
325    /// This method will update the local memory access for the memory read, the register read,
326    /// and the register write.
327    ///
328    /// It will also emit the memory instruction event and the events for the load instruction.
329    #[inline]
330    pub fn execute_load(&mut self, instruction: &Instruction) -> Result<(), ExecutionError> {
331        let LoadResult { addr, mr_record, error, .. } = self.core.execute_load(instruction)?;
332
333        if let Some(error) = error {
334            self.handle_error(error)?;
335            self.shape_checker.handle_trap_mem_event();
336        } else {
337            self.touched_addresses.insert(addr & !0b111, true);
338            self.shape_checker.handle_mem_event(addr, mr_record.prev_timestamp);
339        }
340
341        if let Some(record) = mr_record.prev_page_prot_record {
342            self.shape_checker.handle_page_prot_event(record.page_idx, record.timestamp);
343            self.touched_pages.insert(record.page_idx, true);
344        }
345
346        Ok(())
347    }
348
349    /// Execute a store instruction.
350    ///
351    /// This method will update the local memory access for the memory read, the register read,
352    /// and the register write.
353    ///
354    /// It will also emit the memory instruction event and the events for the store instruction.
355    #[inline]
356    pub fn execute_store(&mut self, instruction: &Instruction) -> Result<(), ExecutionError> {
357        let StoreResult { addr, mw_record, error, .. } = self.core.execute_store(instruction)?;
358
359        if let Some(error) = error {
360            self.handle_error(error)?;
361            self.shape_checker.handle_trap_mem_event();
362        } else {
363            self.touched_addresses.insert(addr & !0b111, true);
364            self.shape_checker.handle_mem_event(addr, mw_record.prev_timestamp);
365        }
366
367        if let Some(record) = mw_record.prev_page_prot_record {
368            self.shape_checker.handle_page_prot_event(record.page_idx, record.timestamp);
369            self.touched_pages.insert(record.page_idx, true);
370        }
371
372        Ok(())
373    }
374}
375
376impl<M: ExecutionMode> SplicingVM<'_, M> {
377    /// Splice a minimal trace, outputting a minimal trace for the NEXT shard.
378    pub fn splice<T: MinimalTrace>(&self, trace: T) -> Option<SplicedMinimalTrace<T>> {
379        // If the trace has been exhausted, then the last splice is all thats needed.
380        if self.core.is_trace_end() || self.core.is_done() {
381            return None;
382        }
383
384        let total_mem_reads = trace.num_mem_reads();
385
386        Some(SplicedMinimalTrace::new(
387            trace,
388            self.core.registers().iter().map(|v| v.value).collect::<Vec<_>>().try_into().unwrap(),
389            self.core.pc(),
390            self.core.clk(),
391            total_mem_reads as usize - self.core.mem_reads.len(),
392        ))
393    }
394
395    // Indicate that a new shard is starting.
396    fn start_new_shard(&mut self) {
397        self.shape_checker.reset(self.core.clk());
398        self.core.register_refresh();
399    }
400}
401
402impl<'a, M: ExecutionMode> SplicingVM<'a, M> {
403    /// Create a new full-tracing VM from a minimal trace.
404    pub fn new<T: MinimalTrace>(
405        trace: &'a T,
406        program: Arc<Program>,
407        touched_addresses: &'a mut CompressedMemory,
408        touched_pages: &'a mut CompressedPages,
409        proof_nonce: [u32; PROOF_NONCE_NUM_WORDS],
410        opts: SP1CoreOpts,
411    ) -> Self {
412        let program_len = program.instructions.len() as u64;
413        let ShardingThreshold { element_threshold, height_threshold } = opts.sharding_threshold;
414        assert!(
415            element_threshold >= HALT_AREA && height_threshold >= HALT_HEIGHT,
416            "invalid sharding threshold"
417        );
418
419        Self {
420            core: CoreVM::new(trace, program, opts, proof_nonce),
421            touched_addresses,
422            touched_pages,
423            shape_checker: ShapeChecker::new(
424                program_len,
425                trace.clk_start(),
426                ShardingThreshold {
427                    element_threshold: element_threshold - HALT_AREA,
428                    height_threshold: height_threshold - HALT_HEIGHT,
429                },
430            ),
431            _mode: PhantomData,
432        }
433    }
434
435    /// Handles recoverable errors such as traps.
436    pub fn handle_error(&mut self, e: TrapError) -> Result<(), ExecutionError> {
437        let TrapResult { context, code_record, pc_record, handler_record } =
438            self.core.handle_error(e)?;
439
440        self.touched_addresses.insert(context & !0b111, true);
441        self.touched_addresses.insert((context + 8) & !0b111, true);
442        self.touched_addresses.insert((context + 16) & !0b111, true);
443
444        self.shape_checker.handle_mem_event(context, handler_record.prev_timestamp);
445        self.shape_checker.handle_mem_event(context + 8, code_record.prev_timestamp);
446        self.shape_checker.handle_mem_event(context + 16, pc_record.prev_timestamp);
447
448        Ok(())
449    }
450
451    /// Execute an ALU instruction and emit the events.
452    #[inline]
453    pub fn execute_alu(&mut self, instruction: &Instruction) {
454        let _ = self.core.execute_alu(instruction);
455    }
456
457    /// Execute a jump instruction and emit the events.
458    #[inline]
459    pub fn execute_jump(&mut self, instruction: &Instruction) {
460        let _ = self.core.execute_jump(instruction);
461    }
462
463    /// Execute a branch instruction and emit the events.
464    #[inline]
465    pub fn execute_branch(&mut self, instruction: &Instruction) {
466        let _ = self.core.execute_branch(instruction);
467    }
468
469    /// Execute a U-type instruction and emit the events.   
470    #[inline]
471    pub fn execute_utype(&mut self, instruction: &Instruction) {
472        let _ = self.core.execute_utype(instruction);
473    }
474
475    /// Execute an ecall instruction and emit the events.
476    #[inline]
477    pub fn execute_ecall(&mut self, instruction: &Instruction) -> Result<(), ExecutionError> {
478        let code = self.core.read_code();
479
480        if code.should_send() == 1 {
481            if self.core.is_retained_syscall(code) {
482                self.shape_checker.handle_retained_syscall(code);
483            } else {
484                self.shape_checker.syscall_sent();
485            }
486        }
487
488        if code == SyscallCode::COMMIT || code == SyscallCode::COMMIT_DEFERRED_PROOFS {
489            self.shape_checker.handle_commit();
490        }
491
492        let result = CoreVM::execute_ecall(self, instruction, code)?;
493
494        let syscall_sent = self.shape_checker.get_syscall_sent();
495        self.shape_checker.set_syscall_sent(false);
496
497        if let Some(error) = result.error {
498            self.handle_error(error)?;
499        }
500
501        if let Some(record) = result.sig_return_pc_record {
502            self.shape_checker.handle_mem_event(result.b, record.prev_timestamp);
503        }
504        self.shape_checker.set_syscall_sent(syscall_sent);
505
506        Ok(())
507    }
508}
509
510impl<'a, M: ExecutionMode> SyscallRuntime<'a, M> for SplicingVM<'a, M> {
511    const TRACING: bool = false;
512
513    fn core(&self) -> &CoreVM<'a, M> {
514        &self.core
515    }
516
517    fn core_mut(&mut self) -> &mut CoreVM<'a, M> {
518        &mut self.core
519    }
520
521    fn rr(&mut self, register: usize) -> MemoryReadRecord {
522        let record = SyscallRuntime::rr(self.core_mut(), register);
523        self.shape_checker.local_mem_syscall_rr();
524        record
525    }
526
527    fn rw(&mut self, register: usize, value: u64) -> MemoryWriteRecord {
528        let record = SyscallRuntime::rw(self.core_mut(), register, value);
529        self.shape_checker.local_mem_syscall_rr();
530        record
531    }
532
533    fn page_prot_write(&mut self, page_idx: u64, prot: u8) -> PageProtRecord {
534        let prev_page_prot_record = self.core_mut().page_prot_write(page_idx, prot);
535        self.shape_checker.handle_page_prot_event(
536            prev_page_prot_record.page_idx,
537            prev_page_prot_record.timestamp,
538        );
539        self.touched_pages.insert(prev_page_prot_record.page_idx, true);
540        prev_page_prot_record
541    }
542
543    fn page_prot_range_check(
544        &mut self,
545        start_page_idx: u64,
546        end_page_idx: u64,
547        page_prot_bitmap: u8,
548    ) -> (Vec<PageProtRecord>, Option<TrapError>) {
549        let (page_prot_records, error) =
550            self.core_mut().page_prot_range_check(start_page_idx, end_page_idx, page_prot_bitmap);
551        for record in page_prot_records.iter() {
552            self.shape_checker.handle_page_prot_event(record.page_idx, record.timestamp);
553            self.touched_pages.insert(record.page_idx, true);
554        }
555        (page_prot_records, error)
556    }
557
558    fn mr_without_prot(&mut self, addr: u64) -> MemoryReadRecord {
559        let record = self.core_mut().mr_without_prot(addr);
560        self.shape_checker.handle_mem_event(addr, record.prev_timestamp);
561        record
562    }
563
564    fn mw_without_prot(&mut self, addr: u64) -> MemoryWriteRecord {
565        let record = self.core_mut().mw_without_prot(addr);
566        self.shape_checker.handle_mem_event(addr, record.prev_timestamp);
567        record
568    }
569
570    fn mr_slice_without_prot(&mut self, addr: u64, len: usize) -> Vec<MemoryReadRecord> {
571        let records = self.core_mut().mr_slice_without_prot(addr, len);
572        for (i, record) in records.iter().enumerate() {
573            self.shape_checker.handle_mem_event(addr + i as u64 * 8, record.prev_timestamp);
574        }
575
576        records
577    }
578
579    fn mw_slice_without_prot(&mut self, addr: u64, len: usize) -> Vec<MemoryWriteRecord> {
580        let records = self.core_mut().mw_slice_without_prot(addr, len);
581        for (i, record) in records.iter().enumerate() {
582            self.shape_checker.handle_mem_event(addr + i as u64 * 8, record.prev_timestamp);
583        }
584
585        records
586    }
587}
588
/// A minimal trace implementation that starts at a different point in the trace,
/// but reuses the same memory reads and hint lens.
///
/// Note: This type implements [`Serialize`] but it is serialized as a [`TraceChunk`].
///
/// In order to deserialize this type, you must use the [`TraceChunk`] type.
#[derive(Debug, Clone)]
pub struct SplicedMinimalTrace<T: MinimalTrace> {
    // The underlying trace whose memory reads are reused.
    inner: T,
    // Register values reported by `start_registers()`.
    start_registers: [u64; 32],
    // Program counter reported by `pc_start()`.
    start_pc: u64,
    // Clock reported by `clk_start()`.
    start_clk: u64,
    // Number of memory reads of `inner` that `mem_reads()` skips.
    memory_reads_idx: usize,
    // Clock reported by `clk_end()`; starts at 0 until `set_last_clk` is called.
    last_clk: u64,
    // Normally unused but can be set for the cluster.
    // Serialization copies the reads in `memory_reads_idx..last_mem_reads_idx`.
    last_mem_reads_idx: usize,
}
606
607impl<T: MinimalTrace> SplicedMinimalTrace<T> {
608    /// Create a new spliced minimal trace.
609    #[tracing::instrument(name = "SplicedMinimalTrace::new", skip(inner), level = "trace")]
610    pub fn new(
611        inner: T,
612        start_registers: [u64; 32],
613        start_pc: u64,
614        start_clk: u64,
615        memory_reads_idx: usize,
616    ) -> Self {
617        Self {
618            inner,
619            start_registers,
620            start_pc,
621            start_clk,
622            memory_reads_idx,
623            last_clk: 0,
624            last_mem_reads_idx: 0,
625        }
626    }
627
628    /// Create a new spliced minimal trace from a minimal trace without any splicing.
629    #[tracing::instrument(
630        name = "SplicedMinimalTrace::new_full_trace",
631        skip(trace),
632        level = "trace"
633    )]
634    pub fn new_full_trace(trace: T) -> Self {
635        let start_registers = trace.start_registers();
636        let start_pc = trace.pc_start();
637        let start_clk = trace.clk_start();
638
639        tracing::trace!("start_pc: {}", start_pc);
640        tracing::trace!("start_clk: {}", start_clk);
641        tracing::trace!("trace.num_mem_reads(): {}", trace.num_mem_reads());
642
643        Self::new(trace, start_registers, start_pc, start_clk, 0)
644    }
645
646    /// Set the last clock of the spliced minimal trace.
647    pub fn set_last_clk(&mut self, clk: u64) {
648        self.last_clk = clk;
649    }
650
651    /// Set the last memory reads index of the spliced minimal trace.
652    pub fn set_last_mem_reads_idx(&mut self, mem_reads_idx: usize) {
653        self.last_mem_reads_idx = mem_reads_idx;
654    }
655}
656
657impl<T: MinimalTrace> MinimalTrace for SplicedMinimalTrace<T> {
658    fn start_registers(&self) -> [u64; 32] {
659        self.start_registers
660    }
661
662    fn pc_start(&self) -> u64 {
663        self.start_pc
664    }
665
666    fn clk_start(&self) -> u64 {
667        self.start_clk
668    }
669
670    fn clk_end(&self) -> u64 {
671        self.last_clk
672    }
673
674    fn num_mem_reads(&self) -> u64 {
675        self.inner.num_mem_reads() - self.memory_reads_idx as u64
676    }
677
678    fn mem_reads(&self) -> MemReads<'_> {
679        let mut reads = self.inner.mem_reads();
680        reads.advance(self.memory_reads_idx);
681
682        reads
683    }
684}
685
686impl<T: MinimalTrace> Serialize for SplicedMinimalTrace<T> {
687    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
688    where
689        S: serde::Serializer,
690    {
691        let len = self.last_mem_reads_idx - self.memory_reads_idx;
692        let mem_reads = unsafe {
693            let mem_reads_buf = Arc::new_uninit_slice(len);
694            let start_mem_reads = self.mem_reads();
695            let src_ptr = start_mem_reads.head_raw();
696            std::ptr::copy_nonoverlapping(src_ptr, mem_reads_buf.as_ptr() as *mut MemValue, len);
697            mem_reads_buf.assume_init()
698        };
699
700        let trace = TraceChunk {
701            start_registers: self.start_registers,
702            pc_start: self.start_pc,
703            clk_start: self.start_clk,
704            clk_end: self.last_clk,
705            mem_reads,
706        };
707
708        trace.serialize(serializer)
709    }
710}
711
/// Wrapper enum to handle `SplicingVM` with different execution modes at runtime.
///
/// The variant is chosen in `SplicingVMEnum::new` from the program's
/// `enable_untrusted_programs` flag.
pub enum SplicingVMEnum<'a> {
    /// `SplicingVM` for `SupervisorMode`.
    Supervisor(SplicingVM<'a, SupervisorMode>),
    /// `SplicingVM` for `UserMode`.
    User(SplicingVM<'a, UserMode>),
}
719
720impl<'a> SplicingVMEnum<'a> {
721    /// Create a new `SplicingVMEnum` based on program's `enable_untrusted_programs` flag.
722    pub fn new<T: MinimalTrace>(
723        trace: &'a T,
724        program: Arc<Program>,
725        touched_addresses: &'a mut CompressedMemory,
726        touched_pages: &'a mut CompressedPages,
727        proof_nonce: [u32; PROOF_NONCE_NUM_WORDS],
728        opts: SP1CoreOpts,
729    ) -> Self {
730        if program.enable_untrusted_programs {
731            Self::User(SplicingVM::<UserMode>::new(
732                trace,
733                program,
734                touched_addresses,
735                touched_pages,
736                proof_nonce,
737                opts,
738            ))
739        } else {
740            Self::Supervisor(SplicingVM::<SupervisorMode>::new(
741                trace,
742                program,
743                touched_addresses,
744                touched_pages,
745                proof_nonce,
746                opts,
747            ))
748        }
749    }
750
751    /// Execute the program until it halts or reaches a shard boundary.
752    pub fn execute(&mut self) -> Result<CycleResult, ExecutionError> {
753        match self {
754            Self::Supervisor(vm) => vm.execute(),
755            Self::User(vm) => vm.execute(),
756        }
757    }
758
759    /// Splice a minimal trace, outputting a minimal trace for the NEXT shard.
760    pub fn splice<T: MinimalTrace>(&self, trace: T) -> Option<SplicedMinimalTrace<T>> {
761        match self {
762            Self::Supervisor(vm) => vm.splice(trace),
763            Self::User(vm) => vm.splice(trace),
764        }
765    }
766
767    /// Get the current clock.
768    #[must_use]
769    pub fn clk(&self) -> u64 {
770        match self {
771            Self::Supervisor(vm) => vm.core.clk(),
772            Self::User(vm) => vm.core.clk(),
773        }
774    }
775
776    /// Get the global clock.
777    #[must_use]
778    pub fn global_clk(&self) -> u64 {
779        match self {
780            Self::Supervisor(vm) => vm.core.global_clk(),
781            Self::User(vm) => vm.core.global_clk(),
782        }
783    }
784
785    /// Get the current PC.
786    #[must_use]
787    pub fn pc(&self) -> u64 {
788        match self {
789            Self::Supervisor(vm) => vm.core.pc(),
790            Self::User(vm) => vm.core.pc(),
791        }
792    }
793
794    /// Get the number of remaining memory reads.
795    #[must_use]
796    pub fn mem_reads_len(&self) -> usize {
797        match self {
798            Self::Supervisor(vm) => vm.core.mem_reads.len(),
799            Self::User(vm) => vm.core.mem_reads.len(),
800        }
801    }
802
803    /// Get the registers.
804    #[must_use]
805    pub fn registers(&self) -> [MemoryRecord; 32] {
806        match self {
807            Self::Supervisor(vm) => *vm.core.registers(),
808            Self::User(vm) => *vm.core.registers(),
809        }
810    }
811
812    /// Get the exit code.
813    #[must_use]
814    pub fn exit_code(&self) -> u32 {
815        match self {
816            Self::Supervisor(vm) => vm.core.exit_code(),
817            Self::User(vm) => vm.core.exit_code(),
818        }
819    }
820
821    /// Check if done.
822    #[must_use]
823    pub fn is_done(&self) -> bool {
824        match self {
825            Self::Supervisor(vm) => vm.core.is_done(),
826            Self::User(vm) => vm.core.is_done(),
827        }
828    }
829
830    /// Get the public value digest.
831    #[must_use]
832    pub fn public_value_digest(&self) -> [u32; sp1_hypercube::air::PV_DIGEST_NUM_WORDS] {
833        match self {
834            Self::Supervisor(vm) => vm.core.public_value_digest,
835            Self::User(vm) => vm.core.public_value_digest,
836        }
837    }
838
839    /// Get the proof nonce.
840    #[must_use]
841    pub fn proof_nonce(&self) -> [u32; sp1_hypercube::air::PROOF_NONCE_NUM_WORDS] {
842        match self {
843            Self::Supervisor(vm) => vm.core.proof_nonce,
844            Self::User(vm) => vm.core.proof_nonce,
845        }
846    }
847}
848
#[cfg(test)]
mod tests {
    use sp1_jit::MemValue;

    use super::*;

    /// A spliced trace must round-trip through bincode as a [`TraceChunk`] that carries the
    /// spliced start state and only the memory reads inside the spliced range.
    #[test]
    fn test_serialize_spliced_minimal_trace() {
        let inner = TraceChunk {
            start_registers: [1; 32],
            pc_start: 2,
            clk_start: 3,
            clk_end: 4,
            mem_reads: Arc::new([MemValue { clk: 8, value: 9 }, MemValue { clk: 10, value: 11 }]),
        };

        // Skip the first memory read and end after the second one.
        let mut spliced = SplicedMinimalTrace::new(inner, [2; 32], 2, 3, 1);
        spliced.set_last_mem_reads_idx(2);
        spliced.set_last_clk(2);

        let bytes = bincode::serialize(&spliced).unwrap();
        let round_tripped: TraceChunk = bincode::deserialize(&bytes).unwrap();

        let want = TraceChunk {
            start_registers: [2; 32],
            pc_start: 2,
            clk_start: 3,
            clk_end: 2,
            mem_reads: Arc::new([MemValue { clk: 10, value: 11 }]),
        };

        assert_eq!(round_tripped, want);
    }
}