Skip to main content

sp1_core_executor/
splicing.rs

1use std::sync::Arc;
2
3use serde::Serialize;
4use sp1_hypercube::air::PROOF_NONCE_NUM_WORDS;
5use sp1_jit::{MemReads, MemValue, MinimalTrace, TraceChunk};
6
7use crate::{
8    events::{MemoryReadRecord, MemoryWriteRecord},
9    vm::{
10        memory::CompressedMemory,
11        results::{CycleResult, LoadResult, StoreResult},
12        shapes::{ShapeChecker, HALT_AREA, HALT_HEIGHT},
13        syscall::SyscallRuntime,
14        CoreVM,
15    },
16    ExecutionError, Instruction, Opcode, Program, SP1CoreOpts, ShardingThreshold, SyscallCode,
17};
18
/// A RISC-V VM that uses a [`MinimalTrace`] to create multiple [`SplicedMinimalTrace`]s.
///
/// Each new [`SplicedMinimalTrace`] corresponds to exactly one execution shard to be proved.
///
/// Note that this is the only time we account for trace area throughout the execution pipeline.
pub struct SplicingVM<'a> {
    /// The core VM.
    pub core: CoreVM<'a>,
    /// The shape checker, responsible for cutting the execution when a shard limit is reached.
    pub shape_checker: ShapeChecker,
    /// The addresses that have been touched.
    pub touched_addresses: &'a mut CompressedMemory,
}
32
33impl SplicingVM<'_> {
34    /// Execute the program until it halts.
35    pub fn execute(&mut self) -> Result<CycleResult, ExecutionError> {
36        if self.core.is_done() {
37            return Ok(CycleResult::Done(true));
38        }
39
40        loop {
41            let mut result = self.execute_instruction()?;
42
43            // If were not already done, ensure that we dont have a shard boundary.
44            if !result.is_done() && self.shape_checker.check_shard_limit() {
45                result = CycleResult::ShardBoundary;
46            }
47
48            match result {
49                CycleResult::Done(false) => {}
50                CycleResult::ShardBoundary | CycleResult::TraceEnd => {
51                    self.start_new_shard();
52                    return Ok(CycleResult::ShardBoundary);
53                }
54                CycleResult::Done(true) => {
55                    return Ok(CycleResult::Done(true));
56                }
57            }
58        }
59    }
60
61    /// Execute the next instruction at the current PC.
62    pub fn execute_instruction(&mut self) -> Result<CycleResult, ExecutionError> {
63        let instruction = self.core.fetch();
64        if instruction.is_none() {
65            unreachable!("Fetching the next instruction failed");
66        }
67
68        // SAFETY: The instruction is guaranteed to be valid as we checked for `is_none` above.
69        let instruction = unsafe { *instruction.unwrap_unchecked() };
70
71        match &instruction.opcode {
72            Opcode::ADD
73            | Opcode::ADDI
74            | Opcode::SUB
75            | Opcode::XOR
76            | Opcode::OR
77            | Opcode::AND
78            | Opcode::SLL
79            | Opcode::SLLW
80            | Opcode::SRL
81            | Opcode::SRA
82            | Opcode::SRLW
83            | Opcode::SRAW
84            | Opcode::SLT
85            | Opcode::SLTU
86            | Opcode::MUL
87            | Opcode::MULHU
88            | Opcode::MULHSU
89            | Opcode::MULH
90            | Opcode::MULW
91            | Opcode::DIVU
92            | Opcode::REMU
93            | Opcode::DIV
94            | Opcode::REM
95            | Opcode::DIVW
96            | Opcode::ADDW
97            | Opcode::SUBW
98            | Opcode::DIVUW
99            | Opcode::REMUW
100            | Opcode::REMW => {
101                self.execute_alu(&instruction);
102            }
103            Opcode::LB
104            | Opcode::LBU
105            | Opcode::LH
106            | Opcode::LHU
107            | Opcode::LW
108            | Opcode::LWU
109            | Opcode::LD => self.execute_load(&instruction)?,
110            Opcode::SB | Opcode::SH | Opcode::SW | Opcode::SD => {
111                self.execute_store(&instruction)?;
112            }
113            Opcode::JAL | Opcode::JALR => {
114                self.execute_jump(&instruction);
115            }
116            Opcode::BEQ | Opcode::BNE | Opcode::BLT | Opcode::BGE | Opcode::BLTU | Opcode::BGEU => {
117                self.execute_branch(&instruction);
118            }
119            Opcode::LUI | Opcode::AUIPC => {
120                self.execute_utype(&instruction);
121            }
122            Opcode::ECALL => self.execute_ecall(&instruction)?,
123            Opcode::EBREAK | Opcode::UNIMP => {
124                unreachable!("Invalid opcode for `execute_instruction`: {:?}", instruction.opcode)
125            }
126        }
127
128        self.shape_checker.handle_instruction(
129            &instruction,
130            self.core.needs_bump_clk_high(),
131            instruction.is_memory_load_instruction() && instruction.op_a == 0,
132            self.core.needs_state_bump(&instruction),
133        );
134
135        Ok(self.core.advance())
136    }
137
138    /// Splice a minimal trace, outputting a minimal trace for the NEXT shard.
139    pub fn splice<T: MinimalTrace>(&self, trace: T) -> Option<SplicedMinimalTrace<T>> {
140        // If the trace has been exhausted, then the last splice is all thats needed.
141        if self.core.is_trace_end() || self.core.is_done() {
142            return None;
143        }
144
145        let total_mem_reads = trace.num_mem_reads();
146
147        Some(SplicedMinimalTrace::new(
148            trace,
149            self.core.registers().iter().map(|v| v.value).collect::<Vec<_>>().try_into().unwrap(),
150            self.core.pc(),
151            self.core.clk(),
152            total_mem_reads as usize - self.core.mem_reads.len(),
153        ))
154    }
155
156    // Indicate that a new shard is starting.
157    fn start_new_shard(&mut self) {
158        self.shape_checker.reset(self.core.clk());
159        self.core.register_refresh();
160    }
161}
162
163impl<'a> SplicingVM<'a> {
164    /// Create a new full-tracing VM from a minimal trace.
165    pub fn new<T: MinimalTrace>(
166        trace: &'a T,
167        program: Arc<Program>,
168        touched_addresses: &'a mut CompressedMemory,
169        proof_nonce: [u32; PROOF_NONCE_NUM_WORDS],
170        opts: SP1CoreOpts,
171    ) -> Self {
172        let program_len = program.instructions.len() as u64;
173        let ShardingThreshold { element_threshold, height_threshold } = opts.sharding_threshold;
174        assert!(
175            element_threshold >= HALT_AREA && height_threshold >= HALT_HEIGHT,
176            "invalid sharding threshold"
177        );
178        Self {
179            core: CoreVM::new(trace, program, opts, proof_nonce),
180            touched_addresses,
181            shape_checker: ShapeChecker::new(
182                program_len,
183                trace.clk_start(),
184                ShardingThreshold {
185                    element_threshold: element_threshold - HALT_AREA,
186                    height_threshold: height_threshold - HALT_HEIGHT,
187                },
188            ),
189        }
190    }
191
192    /// Execute a load instruction.
193    ///
194    /// This method will update the local memory access for the memory read, the register read,
195    /// and the register write.
196    ///
197    /// It will also emit the memory instruction event and the events for the load instruction.
198    #[inline]
199    pub fn execute_load(&mut self, instruction: &Instruction) -> Result<(), ExecutionError> {
200        let LoadResult { addr, mr_record, .. } = self.core.execute_load(instruction)?;
201
202        // Ensure the address is aligned to 8 bytes.
203        self.touched_addresses.insert(addr & !0b111, true);
204
205        self.shape_checker.handle_mem_event(addr, mr_record.prev_timestamp);
206
207        Ok(())
208    }
209
210    /// Execute a store instruction.
211    ///
212    /// This method will update the local memory access for the memory read, the register read,
213    /// and the register write.
214    ///
215    /// It will also emit the memory instruction event and the events for the store instruction.
216    #[inline]
217    pub fn execute_store(&mut self, instruction: &Instruction) -> Result<(), ExecutionError> {
218        let StoreResult { addr, mw_record, .. } = self.core.execute_store(instruction)?;
219
220        // Ensure the address is aligned to 8 bytes.
221        self.touched_addresses.insert(addr & !0b111, true);
222
223        self.shape_checker.handle_mem_event(addr, mw_record.prev_timestamp);
224
225        Ok(())
226    }
227
228    /// Execute an ALU instruction and emit the events.
229    #[inline]
230    pub fn execute_alu(&mut self, instruction: &Instruction) {
231        let _ = self.core.execute_alu(instruction);
232    }
233
234    /// Execute a jump instruction and emit the events.
235    #[inline]
236    pub fn execute_jump(&mut self, instruction: &Instruction) {
237        let _ = self.core.execute_jump(instruction);
238    }
239
240    /// Execute a branch instruction and emit the events.
241    #[inline]
242    pub fn execute_branch(&mut self, instruction: &Instruction) {
243        let _ = self.core.execute_branch(instruction);
244    }
245
246    /// Execute a U-type instruction and emit the events.   
247    #[inline]
248    pub fn execute_utype(&mut self, instruction: &Instruction) {
249        let _ = self.core.execute_utype(instruction);
250    }
251
252    /// Execute an ecall instruction and emit the events.
253    #[inline]
254    pub fn execute_ecall(&mut self, instruction: &Instruction) -> Result<(), ExecutionError> {
255        let code = self.core.read_code();
256
257        if code.should_send() == 1 {
258            if self.core.is_retained_syscall(code) {
259                self.shape_checker.handle_retained_syscall(code);
260            } else {
261                self.shape_checker.syscall_sent();
262            }
263        }
264
265        if code == SyscallCode::COMMIT || code == SyscallCode::COMMIT_DEFERRED_PROOFS {
266            self.shape_checker.handle_commit();
267        }
268
269        let _ = CoreVM::execute_ecall(self, instruction, code)?;
270
271        Ok(())
272    }
273}
274
275impl<'a> SyscallRuntime<'a> for SplicingVM<'a> {
276    const TRACING: bool = false;
277
278    fn core(&self) -> &CoreVM<'a> {
279        &self.core
280    }
281
282    fn core_mut(&mut self) -> &mut CoreVM<'a> {
283        &mut self.core
284    }
285
286    fn rr(&mut self, register: usize) -> MemoryReadRecord {
287        let record = SyscallRuntime::rr(self.core_mut(), register);
288
289        record
290    }
291
292    fn mw(&mut self, addr: u64) -> MemoryWriteRecord {
293        let record = SyscallRuntime::mw(self.core_mut(), addr);
294
295        self.shape_checker.handle_mem_event(addr, record.prev_timestamp);
296
297        record
298    }
299
300    fn mr(&mut self, addr: u64) -> MemoryReadRecord {
301        let record = SyscallRuntime::mr(self.core_mut(), addr);
302
303        self.shape_checker.handle_mem_event(addr, record.prev_timestamp);
304
305        record
306    }
307
308    fn mr_slice(&mut self, addr: u64, len: usize) -> Vec<MemoryReadRecord> {
309        let records = SyscallRuntime::mr_slice(self.core_mut(), addr, len);
310
311        for (i, record) in records.iter().enumerate() {
312            self.shape_checker.handle_mem_event(addr + i as u64 * 8, record.prev_timestamp);
313        }
314
315        records
316    }
317
318    fn mw_slice(&mut self, addr: u64, len: usize) -> Vec<MemoryWriteRecord> {
319        let records = SyscallRuntime::mw_slice(self.core_mut(), addr, len);
320
321        for (i, record) in records.iter().enumerate() {
322            self.shape_checker.handle_mem_event(addr + i as u64 * 8, record.prev_timestamp);
323        }
324
325        records
326    }
327}
328
/// A minimal trace implementation that starts at a different point in the trace,
/// but reuses the same memory reads and hint lens.
///
/// Note: This type implements [`Serialize`] but it is serialized as a [`TraceChunk`].
///
/// In order to deserialize this type, you must use the [`TraceChunk`] type.
#[derive(Debug, Clone)]
pub struct SplicedMinimalTrace<T: MinimalTrace> {
    /// The underlying trace whose memory reads are reused.
    inner: T,
    /// The register values reported as the start registers of this trace.
    start_registers: [u64; 32],
    /// The PC reported as the start PC of this trace.
    start_pc: u64,
    /// The clock reported as the start clock of this trace.
    start_clk: u64,
    /// How many of the inner trace's memory reads are skipped before this trace begins.
    memory_reads_idx: usize,
    /// The clock reported as the end clock of this trace; zero until set.
    last_clk: u64,
    // Normally unused but can be set for the cluster.
    // When serializing, this is the end index (exclusive) of the memory reads to include.
    last_mem_reads_idx: usize,
}
346
347impl<T: MinimalTrace> SplicedMinimalTrace<T> {
348    /// Create a new spliced minimal trace.
349    #[tracing::instrument(name = "SplicedMinimalTrace::new", skip(inner), level = "trace")]
350    pub fn new(
351        inner: T,
352        start_registers: [u64; 32],
353        start_pc: u64,
354        start_clk: u64,
355        memory_reads_idx: usize,
356    ) -> Self {
357        Self {
358            inner,
359            start_registers,
360            start_pc,
361            start_clk,
362            memory_reads_idx,
363            last_clk: 0,
364            last_mem_reads_idx: 0,
365        }
366    }
367
368    /// Create a new spliced minimal trace from a minimal trace without any splicing.
369    #[tracing::instrument(
370        name = "SplicedMinimalTrace::new_full_trace",
371        skip(trace),
372        level = "trace"
373    )]
374    pub fn new_full_trace(trace: T) -> Self {
375        let start_registers = trace.start_registers();
376        let start_pc = trace.pc_start();
377        let start_clk = trace.clk_start();
378
379        tracing::trace!("start_pc: {}", start_pc);
380        tracing::trace!("start_clk: {}", start_clk);
381        tracing::trace!("trace.num_mem_reads(): {}", trace.num_mem_reads());
382
383        Self::new(trace, start_registers, start_pc, start_clk, 0)
384    }
385
386    /// Set the last clock of the spliced minimal trace.
387    pub fn set_last_clk(&mut self, clk: u64) {
388        self.last_clk = clk;
389    }
390
391    /// Set the last memory reads index of the spliced minimal trace.
392    pub fn set_last_mem_reads_idx(&mut self, mem_reads_idx: usize) {
393        self.last_mem_reads_idx = mem_reads_idx;
394    }
395}
396
397impl<T: MinimalTrace> MinimalTrace for SplicedMinimalTrace<T> {
398    fn start_registers(&self) -> [u64; 32] {
399        self.start_registers
400    }
401
402    fn pc_start(&self) -> u64 {
403        self.start_pc
404    }
405
406    fn clk_start(&self) -> u64 {
407        self.start_clk
408    }
409
410    fn clk_end(&self) -> u64 {
411        self.last_clk
412    }
413
414    fn num_mem_reads(&self) -> u64 {
415        self.inner.num_mem_reads() - self.memory_reads_idx as u64
416    }
417
418    fn mem_reads(&self) -> MemReads<'_> {
419        let mut reads = self.inner.mem_reads();
420        reads.advance(self.memory_reads_idx);
421
422        reads
423    }
424}
425
426impl<T: MinimalTrace> Serialize for SplicedMinimalTrace<T> {
427    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
428    where
429        S: serde::Serializer,
430    {
431        let len = self.last_mem_reads_idx - self.memory_reads_idx;
432        let mem_reads = unsafe {
433            let mem_reads_buf = Arc::new_uninit_slice(len);
434            let start_mem_reads = self.mem_reads();
435            let src_ptr = start_mem_reads.head_raw();
436            std::ptr::copy_nonoverlapping(src_ptr, mem_reads_buf.as_ptr() as *mut MemValue, len);
437            mem_reads_buf.assume_init()
438        };
439
440        let trace = TraceChunk {
441            start_registers: self.start_registers,
442            pc_start: self.start_pc,
443            clk_start: self.start_clk,
444            clk_end: self.last_clk,
445            mem_reads,
446        };
447
448        trace.serialize(serializer)
449    }
450}
451
#[cfg(test)]
mod tests {
    use sp1_jit::MemValue;

    use super::*;

    /// A spliced trace must serialize to a [`TraceChunk`] that carries the spliced start state
    /// and only the memory reads between the start and last memory read indices.
    #[test]
    fn test_serialize_spliced_minimal_trace() {
        // The inner trace holds two memory reads; the splice below skips the first one.
        let inner = TraceChunk {
            start_registers: [1; 32],
            pc_start: 2,
            clk_start: 3,
            clk_end: 4,
            mem_reads: Arc::new([MemValue { clk: 8, value: 9 }, MemValue { clk: 10, value: 11 }]),
        };

        let mut spliced = SplicedMinimalTrace::new(inner, [2; 32], 2, 3, 1);
        spliced.set_last_mem_reads_idx(2);
        spliced.set_last_clk(2);

        // Round-trip through bincode, deserializing as a plain `TraceChunk`.
        let bytes = bincode::serialize(&spliced).unwrap();
        let deserialized: TraceChunk = bincode::deserialize(&bytes).unwrap();

        let expected = TraceChunk {
            start_registers: [2; 32],
            pc_start: 2,
            clk_start: 3,
            clk_end: 2,
            mem_reads: Arc::new([MemValue { clk: 10, value: 11 }]),
        };

        assert_eq!(deserialized, expected);
    }
}