Skip to main content

sp1_core_executor/
splicing.rs

1use std::sync::Arc;
2
3use serde::Serialize;
4use sp1_hypercube::air::PROOF_NONCE_NUM_WORDS;
5use sp1_jit::{MemReads, MemValue, MinimalTrace, TraceChunk};
6
7use crate::{
8    events::{MemoryReadRecord, MemoryWriteRecord},
9    vm::{
10        memory::CompressedMemory,
11        results::{CycleResult, LoadResult, StoreResult},
12        shapes::{ShapeChecker, HALT_AREA, HALT_HEIGHT},
13        syscall::SyscallRuntime,
14        CoreVM,
15    },
16    ExecutionError, Instruction, Opcode, Program, SP1CoreOpts, ShardingThreshold, SyscallCode,
17};
18
19/// A RISC-V VM that uses a [`MinimalTrace`] to create multiple [`SplicedMinimalTrace`]s.
20///
21/// These new [`SplicedMinimalTrace`]s correspond to exactly 1 execuction shard to be proved.
22///
23/// Note that this is the only time we account for trace area throught the execution pipeline.
24pub struct SplicingVM<'a> {
25    /// The core VM.
26    pub core: CoreVM<'a>,
27    /// The shape checker, responsible for cutting the execution when a shard limit is reached.
28    pub shape_checker: ShapeChecker,
29    /// The addresses that have been touched.
30    pub touched_addresses: &'a mut CompressedMemory,
31}
32
33impl SplicingVM<'_> {
34    /// Execute the program until it halts.
35    pub fn execute(&mut self) -> Result<CycleResult, ExecutionError> {
36        if self.core.is_done() {
37            return Ok(CycleResult::Done(true));
38        }
39
40        loop {
41            let mut result = self.execute_instruction()?;
42
43            // If were not already done, ensure that we dont have a shard boundary.
44            if !result.is_done() && self.shape_checker.check_shard_limit() {
45                result = CycleResult::ShardBoundary;
46            }
47
48            match result {
49                CycleResult::Done(false) => {}
50                CycleResult::ShardBoundary | CycleResult::TraceEnd => {
51                    self.start_new_shard();
52                    return Ok(CycleResult::ShardBoundary);
53                }
54                CycleResult::Done(true) => {
55                    return Ok(CycleResult::Done(true));
56                }
57            }
58        }
59    }
60
61    /// Execute the next instruction at the current PC.
62    pub fn execute_instruction(&mut self) -> Result<CycleResult, ExecutionError> {
63        let instruction = self.core.fetch();
64        if instruction.is_none() {
65            unreachable!("Fetching the next instruction failed");
66        }
67
68        // SAFETY: The instruction is guaranteed to be valid as we checked for `is_none` above.
69        let instruction = unsafe { *instruction.unwrap_unchecked() };
70
71        match &instruction.opcode {
72            Opcode::ADD
73            | Opcode::ADDI
74            | Opcode::SUB
75            | Opcode::XOR
76            | Opcode::OR
77            | Opcode::AND
78            | Opcode::SLL
79            | Opcode::SLLW
80            | Opcode::SRL
81            | Opcode::SRA
82            | Opcode::SRLW
83            | Opcode::SRAW
84            | Opcode::SLT
85            | Opcode::SLTU
86            | Opcode::MUL
87            | Opcode::MULHU
88            | Opcode::MULHSU
89            | Opcode::MULH
90            | Opcode::MULW
91            | Opcode::DIVU
92            | Opcode::REMU
93            | Opcode::DIV
94            | Opcode::REM
95            | Opcode::DIVW
96            | Opcode::ADDW
97            | Opcode::SUBW
98            | Opcode::DIVUW
99            | Opcode::REMUW
100            | Opcode::REMW => {
101                self.execute_alu(&instruction);
102            }
103            Opcode::LB
104            | Opcode::LBU
105            | Opcode::LH
106            | Opcode::LHU
107            | Opcode::LW
108            | Opcode::LWU
109            | Opcode::LD => self.execute_load(&instruction)?,
110            Opcode::SB | Opcode::SH | Opcode::SW | Opcode::SD => {
111                self.execute_store(&instruction)?;
112            }
113            Opcode::JAL | Opcode::JALR => {
114                self.execute_jump(&instruction);
115            }
116            Opcode::BEQ | Opcode::BNE | Opcode::BLT | Opcode::BGE | Opcode::BLTU | Opcode::BGEU => {
117                self.execute_branch(&instruction);
118            }
119            Opcode::LUI | Opcode::AUIPC => {
120                self.execute_utype(&instruction);
121            }
122            Opcode::ECALL => self.execute_ecall(&instruction)?,
123            Opcode::EBREAK | Opcode::UNIMP => {
124                unreachable!("Invalid opcode for `execute_instruction`: {:?}", instruction.opcode)
125            }
126        }
127
128        self.shape_checker.handle_instruction(
129            &instruction,
130            self.core.needs_bump_clk_high(),
131            instruction.is_alu_instruction() && instruction.op_a == 0,
132            instruction.is_memory_load_instruction() && instruction.op_a == 0,
133            self.core.needs_state_bump(&instruction),
134        );
135
136        Ok(self.core.advance())
137    }
138
139    /// Splice a minimal trace, outputting a minimal trace for the NEXT shard.
140    pub fn splice<T: MinimalTrace>(&self, trace: T) -> Option<SplicedMinimalTrace<T>> {
141        // If the trace has been exhausted, then the last splice is all thats needed.
142        if self.core.is_trace_end() || self.core.is_done() {
143            return None;
144        }
145
146        let total_mem_reads = trace.num_mem_reads();
147
148        Some(SplicedMinimalTrace::new(
149            trace,
150            self.core.registers().iter().map(|v| v.value).collect::<Vec<_>>().try_into().unwrap(),
151            self.core.pc(),
152            self.core.clk(),
153            total_mem_reads as usize - self.core.mem_reads.len(),
154        ))
155    }
156
157    // Indicate that a new shard is starting.
158    fn start_new_shard(&mut self) {
159        self.shape_checker.reset(self.core.clk());
160        self.core.register_refresh();
161    }
162}
163
164impl<'a> SplicingVM<'a> {
165    /// Create a new full-tracing VM from a minimal trace.
166    pub fn new<T: MinimalTrace>(
167        trace: &'a T,
168        program: Arc<Program>,
169        touched_addresses: &'a mut CompressedMemory,
170        proof_nonce: [u32; PROOF_NONCE_NUM_WORDS],
171        opts: SP1CoreOpts,
172    ) -> Self {
173        let program_len = program.instructions.len() as u64;
174        let ShardingThreshold { element_threshold, height_threshold } = opts.sharding_threshold;
175        assert!(
176            element_threshold >= HALT_AREA && height_threshold >= HALT_HEIGHT,
177            "invalid sharding threshold"
178        );
179        Self {
180            core: CoreVM::new(trace, program, opts, proof_nonce),
181            touched_addresses,
182            shape_checker: ShapeChecker::new(
183                program_len,
184                trace.clk_start(),
185                ShardingThreshold {
186                    element_threshold: element_threshold - HALT_AREA,
187                    height_threshold: height_threshold - HALT_HEIGHT,
188                },
189            ),
190        }
191    }
192
193    /// Execute a load instruction.
194    ///
195    /// This method will update the local memory access for the memory read, the register read,
196    /// and the register write.
197    ///
198    /// It will also emit the memory instruction event and the events for the load instruction.
199    #[inline]
200    pub fn execute_load(&mut self, instruction: &Instruction) -> Result<(), ExecutionError> {
201        let LoadResult { addr, mr_record, .. } = self.core.execute_load(instruction)?;
202
203        // Ensure the address is aligned to 8 bytes.
204        self.touched_addresses.insert(addr & !0b111, true);
205
206        self.shape_checker.handle_mem_event(addr, mr_record.prev_timestamp);
207
208        Ok(())
209    }
210
211    /// Execute a store instruction.
212    ///
213    /// This method will update the local memory access for the memory read, the register read,
214    /// and the register write.
215    ///
216    /// It will also emit the memory instruction event and the events for the store instruction.
217    #[inline]
218    pub fn execute_store(&mut self, instruction: &Instruction) -> Result<(), ExecutionError> {
219        let StoreResult { addr, mw_record, .. } = self.core.execute_store(instruction)?;
220
221        // Ensure the address is aligned to 8 bytes.
222        self.touched_addresses.insert(addr & !0b111, true);
223
224        self.shape_checker.handle_mem_event(addr, mw_record.prev_timestamp);
225
226        Ok(())
227    }
228
229    /// Execute an ALU instruction and emit the events.
230    #[inline]
231    pub fn execute_alu(&mut self, instruction: &Instruction) {
232        let _ = self.core.execute_alu(instruction);
233    }
234
235    /// Execute a jump instruction and emit the events.
236    #[inline]
237    pub fn execute_jump(&mut self, instruction: &Instruction) {
238        let _ = self.core.execute_jump(instruction);
239    }
240
241    /// Execute a branch instruction and emit the events.
242    #[inline]
243    pub fn execute_branch(&mut self, instruction: &Instruction) {
244        let _ = self.core.execute_branch(instruction);
245    }
246
247    /// Execute a U-type instruction and emit the events.   
248    #[inline]
249    pub fn execute_utype(&mut self, instruction: &Instruction) {
250        let _ = self.core.execute_utype(instruction);
251    }
252
253    /// Execute an ecall instruction and emit the events.
254    #[inline]
255    pub fn execute_ecall(&mut self, instruction: &Instruction) -> Result<(), ExecutionError> {
256        let code = self.core.read_code();
257
258        if code.should_send() == 1 {
259            if self.core.is_retained_syscall(code) {
260                self.shape_checker.handle_retained_syscall(code);
261            } else {
262                self.shape_checker.syscall_sent();
263            }
264        }
265
266        if code == SyscallCode::COMMIT || code == SyscallCode::COMMIT_DEFERRED_PROOFS {
267            self.shape_checker.handle_commit();
268        }
269
270        let _ = CoreVM::execute_ecall(self, instruction, code)?;
271
272        Ok(())
273    }
274}
275
276impl<'a> SyscallRuntime<'a> for SplicingVM<'a> {
277    const TRACING: bool = false;
278
279    fn core(&self) -> &CoreVM<'a> {
280        &self.core
281    }
282
283    fn core_mut(&mut self) -> &mut CoreVM<'a> {
284        &mut self.core
285    }
286
287    fn rr(&mut self, register: usize) -> MemoryReadRecord {
288        let record = SyscallRuntime::rr(self.core_mut(), register);
289
290        record
291    }
292
293    fn mw(&mut self, addr: u64) -> MemoryWriteRecord {
294        let record = SyscallRuntime::mw(self.core_mut(), addr);
295
296        self.shape_checker.handle_mem_event(addr, record.prev_timestamp);
297
298        record
299    }
300
301    fn mr(&mut self, addr: u64) -> MemoryReadRecord {
302        let record = SyscallRuntime::mr(self.core_mut(), addr);
303
304        self.shape_checker.handle_mem_event(addr, record.prev_timestamp);
305
306        record
307    }
308
309    fn mr_slice(&mut self, addr: u64, len: usize) -> Vec<MemoryReadRecord> {
310        let records = SyscallRuntime::mr_slice(self.core_mut(), addr, len);
311
312        for (i, record) in records.iter().enumerate() {
313            self.shape_checker.handle_mem_event(addr + i as u64 * 8, record.prev_timestamp);
314        }
315
316        records
317    }
318
319    fn mw_slice(&mut self, addr: u64, len: usize) -> Vec<MemoryWriteRecord> {
320        let records = SyscallRuntime::mw_slice(self.core_mut(), addr, len);
321
322        for (i, record) in records.iter().enumerate() {
323            self.shape_checker.handle_mem_event(addr + i as u64 * 8, record.prev_timestamp);
324        }
325
326        records
327    }
328}
329
330/// A minimal trace implentation that starts at a different point in the trace,
331/// but reuses the same memory reads and hint lens.
332///
333/// Note: This type implements [`Serialize`] but it is serialized as a [`TraceChunk`].
334///
335/// In order to deserialize this type, you must use the [`TraceChunk`] type.
336#[derive(Debug, Clone)]
337pub struct SplicedMinimalTrace<T: MinimalTrace> {
338    inner: T,
339    start_registers: [u64; 32],
340    start_pc: u64,
341    start_clk: u64,
342    memory_reads_idx: usize,
343    last_clk: u64,
344    // Normally unused but can be set for the cluster.
345    last_mem_reads_idx: usize,
346}
347
348impl<T: MinimalTrace> SplicedMinimalTrace<T> {
349    /// Create a new spliced minimal trace.
350    #[tracing::instrument(name = "SplicedMinimalTrace::new", skip(inner), level = "trace")]
351    pub fn new(
352        inner: T,
353        start_registers: [u64; 32],
354        start_pc: u64,
355        start_clk: u64,
356        memory_reads_idx: usize,
357    ) -> Self {
358        Self {
359            inner,
360            start_registers,
361            start_pc,
362            start_clk,
363            memory_reads_idx,
364            last_clk: 0,
365            last_mem_reads_idx: 0,
366        }
367    }
368
369    /// Create a new spliced minimal trace from a minimal trace without any splicing.
370    #[tracing::instrument(
371        name = "SplicedMinimalTrace::new_full_trace",
372        skip(trace),
373        level = "trace"
374    )]
375    pub fn new_full_trace(trace: T) -> Self {
376        let start_registers = trace.start_registers();
377        let start_pc = trace.pc_start();
378        let start_clk = trace.clk_start();
379
380        tracing::trace!("start_pc: {}", start_pc);
381        tracing::trace!("start_clk: {}", start_clk);
382        tracing::trace!("trace.num_mem_reads(): {}", trace.num_mem_reads());
383
384        Self::new(trace, start_registers, start_pc, start_clk, 0)
385    }
386
387    /// Set the last clock of the spliced minimal trace.
388    pub fn set_last_clk(&mut self, clk: u64) {
389        self.last_clk = clk;
390    }
391
392    /// Set the last memory reads index of the spliced minimal trace.
393    pub fn set_last_mem_reads_idx(&mut self, mem_reads_idx: usize) {
394        self.last_mem_reads_idx = mem_reads_idx;
395    }
396}
397
398impl<T: MinimalTrace> MinimalTrace for SplicedMinimalTrace<T> {
399    fn start_registers(&self) -> [u64; 32] {
400        self.start_registers
401    }
402
403    fn pc_start(&self) -> u64 {
404        self.start_pc
405    }
406
407    fn clk_start(&self) -> u64 {
408        self.start_clk
409    }
410
411    fn clk_end(&self) -> u64 {
412        self.last_clk
413    }
414
415    fn num_mem_reads(&self) -> u64 {
416        self.inner.num_mem_reads() - self.memory_reads_idx as u64
417    }
418
419    fn mem_reads(&self) -> MemReads<'_> {
420        let mut reads = self.inner.mem_reads();
421        reads.advance(self.memory_reads_idx);
422
423        reads
424    }
425}
426
427impl<T: MinimalTrace> Serialize for SplicedMinimalTrace<T> {
428    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
429    where
430        S: serde::Serializer,
431    {
432        let len = self.last_mem_reads_idx - self.memory_reads_idx;
433        let mem_reads = unsafe {
434            let mem_reads_buf = Arc::new_uninit_slice(len);
435            let start_mem_reads = self.mem_reads();
436            let src_ptr = start_mem_reads.head_raw();
437            std::ptr::copy_nonoverlapping(src_ptr, mem_reads_buf.as_ptr() as *mut MemValue, len);
438            mem_reads_buf.assume_init()
439        };
440
441        let trace = TraceChunk {
442            start_registers: self.start_registers,
443            pc_start: self.start_pc,
444            clk_start: self.start_clk,
445            clk_end: self.last_clk,
446            mem_reads,
447        };
448
449        trace.serialize(serializer)
450    }
451}
452
#[cfg(test)]
mod tests {
    use sp1_jit::MemValue;

    use super::*;

    #[test]
    fn test_serialize_spliced_minimal_trace() {
        // An inner trace with two memory reads; the splice below starts at the second one.
        let inner = TraceChunk {
            start_registers: [1; 32],
            pc_start: 2,
            clk_start: 3,
            clk_end: 4,
            mem_reads: Arc::new([MemValue { clk: 8, value: 9 }, MemValue { clk: 10, value: 11 }]),
        };

        let mut spliced = SplicedMinimalTrace::new(inner, [2; 32], 2, 3, 1);
        spliced.set_last_mem_reads_idx(2);
        spliced.set_last_clk(2);

        // Serializing must yield a `TraceChunk` with the splice's own start state and end clock,
        // and only the memory reads in `[1, 2)`.
        let expected = TraceChunk {
            start_registers: [2; 32],
            pc_start: 2,
            clk_start: 3,
            clk_end: 2,
            mem_reads: Arc::new([MemValue { clk: 10, value: 11 }]),
        };

        let bytes = bincode::serialize(&spliced).unwrap();
        let round_tripped: TraceChunk = bincode::deserialize(&bytes).unwrap();

        assert_eq!(round_tripped, expected);
    }
}