Skip to main content

sp1_core_executor/
splicing.rs

1use std::sync::Arc;
2
3use serde::Serialize;
4use sp1_hypercube::air::PROOF_NONCE_NUM_WORDS;
5use sp1_jit::{MemReads, MemValue, MinimalTrace, TraceChunk};
6
7use crate::{
8    events::{MemoryReadRecord, MemoryWriteRecord},
9    vm::{
10        memory::CompressedMemory,
11        results::{CycleResult, LoadResult, StoreResult},
12        shapes::ShapeChecker,
13        syscall::SyscallRuntime,
14        CoreVM,
15    },
16    ExecutionError, Instruction, Opcode, Program, SP1CoreOpts, SyscallCode,
17};
18
19/// A RISC-V VM that uses a [`MinimalTrace`] to create multiple [`SplicedMinimalTrace`]s.
20///
21/// These new [`SplicedMinimalTrace`]s correspond to exactly 1 execuction shard to be proved.
22///
23/// Note that this is the only time we account for trace area throught the execution pipeline.
24pub struct SplicingVM<'a> {
25    /// The core VM.
26    pub core: CoreVM<'a>,
27    /// The shape checker, responsible for cutting the execution when a shard limit is reached.
28    pub shape_checker: ShapeChecker,
29    /// The addresses that have been touched.
30    pub touched_addresses: &'a mut CompressedMemory,
31    /// The index of the hint lens the next shard will use.
32    pub hint_lens_idx: usize,
33}
34
35impl SplicingVM<'_> {
36    /// Execute the program until it halts.
37    pub fn execute(&mut self) -> Result<CycleResult, ExecutionError> {
38        if self.core.is_done() {
39            return Ok(CycleResult::Done(true));
40        }
41
42        loop {
43            let mut result = self.execute_instruction()?;
44
45            // If were not already done, ensure that we dont have a shard boundary.
46            if !result.is_done() && self.shape_checker.check_shard_limit() {
47                result = CycleResult::ShardBoundary;
48            }
49
50            match result {
51                CycleResult::Done(false) => {}
52                CycleResult::ShardBoundary | CycleResult::TraceEnd => {
53                    self.start_new_shard();
54                    return Ok(CycleResult::ShardBoundary);
55                }
56                CycleResult::Done(true) => {
57                    return Ok(CycleResult::Done(true));
58                }
59            }
60        }
61    }
62
63    /// Execute the next instruction at the current PC.
64    pub fn execute_instruction(&mut self) -> Result<CycleResult, ExecutionError> {
65        let instruction = self.core.fetch();
66        if instruction.is_none() {
67            unreachable!("Fetching the next instruction failed");
68        }
69
70        // SAFETY: The instruction is guaranteed to be valid as we checked for `is_none` above.
71        let instruction = unsafe { *instruction.unwrap_unchecked() };
72
73        match &instruction.opcode {
74            Opcode::ADD
75            | Opcode::ADDI
76            | Opcode::SUB
77            | Opcode::XOR
78            | Opcode::OR
79            | Opcode::AND
80            | Opcode::SLL
81            | Opcode::SLLW
82            | Opcode::SRL
83            | Opcode::SRA
84            | Opcode::SRLW
85            | Opcode::SRAW
86            | Opcode::SLT
87            | Opcode::SLTU
88            | Opcode::MUL
89            | Opcode::MULHU
90            | Opcode::MULHSU
91            | Opcode::MULH
92            | Opcode::MULW
93            | Opcode::DIVU
94            | Opcode::REMU
95            | Opcode::DIV
96            | Opcode::REM
97            | Opcode::DIVW
98            | Opcode::ADDW
99            | Opcode::SUBW
100            | Opcode::DIVUW
101            | Opcode::REMUW
102            | Opcode::REMW => {
103                self.execute_alu(&instruction);
104            }
105            Opcode::LB
106            | Opcode::LBU
107            | Opcode::LH
108            | Opcode::LHU
109            | Opcode::LW
110            | Opcode::LWU
111            | Opcode::LD => self.execute_load(&instruction)?,
112            Opcode::SB | Opcode::SH | Opcode::SW | Opcode::SD => {
113                self.execute_store(&instruction)?;
114            }
115            Opcode::JAL | Opcode::JALR => {
116                self.execute_jump(&instruction);
117            }
118            Opcode::BEQ | Opcode::BNE | Opcode::BLT | Opcode::BGE | Opcode::BLTU | Opcode::BGEU => {
119                self.execute_branch(&instruction);
120            }
121            Opcode::LUI | Opcode::AUIPC => {
122                self.execute_utype(&instruction);
123            }
124            Opcode::ECALL => self.execute_ecall(&instruction)?,
125            Opcode::EBREAK | Opcode::UNIMP => {
126                unreachable!("Invalid opcode for `execute_instruction`: {:?}", instruction.opcode)
127            }
128        }
129
130        self.shape_checker.handle_instruction(
131            &instruction,
132            self.core.needs_bump_clk_high(),
133            instruction.is_memory_load_instruction() && instruction.op_a == 0,
134            self.core.needs_state_bump(&instruction),
135        );
136
137        Ok(self.core.advance())
138    }
139
140    /// Splice a minimal trace, outputting a minimal trace for the NEXT shard.
141    pub fn splice<T: MinimalTrace>(&self, trace: T) -> Option<SplicedMinimalTrace<T>> {
142        // If the trace has been exhausted, then the last splice is all thats needed.
143        if self.core.is_trace_end() || self.core.is_done() {
144            return None;
145        }
146
147        let total_mem_reads = trace.num_mem_reads();
148
149        Some(SplicedMinimalTrace::new(
150            trace,
151            self.core.registers().iter().map(|v| v.value).collect::<Vec<_>>().try_into().unwrap(),
152            self.core.pc(),
153            self.core.clk(),
154            total_mem_reads as usize - self.core.mem_reads.len(),
155            self.hint_lens_idx,
156        ))
157    }
158
159    // Indicate that a new shard is starting.
160    fn start_new_shard(&mut self) {
161        self.shape_checker.reset(self.core.clk());
162        self.core.register_refresh();
163    }
164}
165
166impl<'a> SplicingVM<'a> {
167    /// Create a new full-tracing VM from a minimal trace.
168    #[tracing::instrument(name = "SplicingVM::new", skip_all)]
169    pub fn new<T: MinimalTrace>(
170        trace: &'a T,
171        program: Arc<Program>,
172        touched_addresses: &'a mut CompressedMemory,
173        proof_nonce: [u32; PROOF_NONCE_NUM_WORDS],
174        opts: SP1CoreOpts,
175    ) -> Self {
176        let program_len = program.instructions.len() as u64;
177        let sharding_threshold = opts.sharding_threshold;
178        Self {
179            core: CoreVM::new(trace, program, opts, proof_nonce),
180            touched_addresses,
181            hint_lens_idx: 0,
182            shape_checker: ShapeChecker::new(program_len, trace.clk_start(), sharding_threshold),
183        }
184    }
185
186    /// Execute a load instruction.
187    ///
188    /// This method will update the local memory access for the memory read, the register read,
189    /// and the register write.
190    ///
191    /// It will also emit the memory instruction event and the events for the load instruction.
192    #[inline]
193    pub fn execute_load(&mut self, instruction: &Instruction) -> Result<(), ExecutionError> {
194        let LoadResult { addr, mr_record, .. } = self.core.execute_load(instruction)?;
195
196        // Ensure the address is aligned to 8 bytes.
197        self.touched_addresses.insert(addr & !0b111, true);
198
199        self.shape_checker.handle_mem_event(addr, mr_record.prev_timestamp);
200
201        Ok(())
202    }
203
204    /// Execute a store instruction.
205    ///
206    /// This method will update the local memory access for the memory read, the register read,
207    /// and the register write.
208    ///
209    /// It will also emit the memory instruction event and the events for the store instruction.
210    #[inline]
211    pub fn execute_store(&mut self, instruction: &Instruction) -> Result<(), ExecutionError> {
212        let StoreResult { addr, mw_record, .. } = self.core.execute_store(instruction)?;
213
214        // Ensure the address is aligned to 8 bytes.
215        self.touched_addresses.insert(addr & !0b111, true);
216
217        self.shape_checker.handle_mem_event(addr, mw_record.prev_timestamp);
218
219        Ok(())
220    }
221
222    /// Execute an ALU instruction and emit the events.
223    #[inline]
224    pub fn execute_alu(&mut self, instruction: &Instruction) {
225        let _ = self.core.execute_alu(instruction);
226    }
227
228    /// Execute a jump instruction and emit the events.
229    #[inline]
230    pub fn execute_jump(&mut self, instruction: &Instruction) {
231        let _ = self.core.execute_jump(instruction);
232    }
233
234    /// Execute a branch instruction and emit the events.
235    #[inline]
236    pub fn execute_branch(&mut self, instruction: &Instruction) {
237        let _ = self.core.execute_branch(instruction);
238    }
239
240    /// Execute a U-type instruction and emit the events.   
241    #[inline]
242    pub fn execute_utype(&mut self, instruction: &Instruction) {
243        let _ = self.core.execute_utype(instruction);
244    }
245
246    /// Execute an ecall instruction and emit the events.
247    #[inline]
248    pub fn execute_ecall(&mut self, instruction: &Instruction) -> Result<(), ExecutionError> {
249        let code = self.core.read_code();
250
251        if code.should_send() == 1 {
252            if self.core.is_retained_syscall(code) {
253                self.shape_checker.handle_retained_syscall(code);
254            } else {
255                self.shape_checker.syscall_sent();
256            }
257        }
258
259        let _ = CoreVM::execute_ecall(self, instruction, code)?;
260
261        if code == SyscallCode::HINT_LEN {
262            self.hint_lens_idx += 1;
263        }
264
265        Ok(())
266    }
267}
268
269impl<'a> SyscallRuntime<'a> for SplicingVM<'a> {
270    const TRACING: bool = false;
271
272    fn core(&self) -> &CoreVM<'a> {
273        &self.core
274    }
275
276    fn core_mut(&mut self) -> &mut CoreVM<'a> {
277        &mut self.core
278    }
279
280    fn rr(&mut self, register: usize) -> MemoryReadRecord {
281        let record = SyscallRuntime::rr(self.core_mut(), register);
282
283        record
284    }
285
286    fn mw(&mut self, addr: u64) -> MemoryWriteRecord {
287        let record = SyscallRuntime::mw(self.core_mut(), addr);
288
289        self.shape_checker.handle_mem_event(addr, record.prev_timestamp);
290
291        record
292    }
293
294    fn mr(&mut self, addr: u64) -> MemoryReadRecord {
295        let record = SyscallRuntime::mr(self.core_mut(), addr);
296
297        self.shape_checker.handle_mem_event(addr, record.prev_timestamp);
298
299        record
300    }
301
302    fn mr_slice(&mut self, addr: u64, len: usize) -> Vec<MemoryReadRecord> {
303        let records = SyscallRuntime::mr_slice(self.core_mut(), addr, len);
304
305        for (i, record) in records.iter().enumerate() {
306            self.shape_checker.handle_mem_event(addr + i as u64 * 8, record.prev_timestamp);
307        }
308
309        records
310    }
311
312    fn mw_slice(&mut self, addr: u64, len: usize) -> Vec<MemoryWriteRecord> {
313        let records = SyscallRuntime::mw_slice(self.core_mut(), addr, len);
314
315        for (i, record) in records.iter().enumerate() {
316            self.shape_checker.handle_mem_event(addr + i as u64 * 8, record.prev_timestamp);
317        }
318
319        records
320    }
321}
322
323/// A minimal trace implentation that starts at a different point in the trace,
324/// but reuses the same memory reads and hint lens.
325///
326/// Note: This type implements [`Serialize`] but it is serialized as a [`TraceChunk`].
327///
328/// In order to deserialize this type, you must use the [`TraceChunk`] type.
329#[derive(Debug, Clone)]
330pub struct SplicedMinimalTrace<T: MinimalTrace> {
331    inner: T,
332    start_registers: [u64; 32],
333    start_pc: u64,
334    start_clk: u64,
335    memory_reads_idx: usize,
336    hint_lens_idx: usize,
337    last_clk: u64,
338    // Normally unused but can be set for the cluster.
339    last_mem_reads_idx: usize,
340}
341
342impl<T: MinimalTrace> SplicedMinimalTrace<T> {
343    /// Create a new spliced minimal trace.
344    #[tracing::instrument(name = "SplicedMinimalTrace::new", skip(inner), level = "trace")]
345    pub fn new(
346        inner: T,
347        start_registers: [u64; 32],
348        start_pc: u64,
349        start_clk: u64,
350        memory_reads_idx: usize,
351        hint_lens_idx: usize,
352    ) -> Self {
353        Self {
354            inner,
355            start_registers,
356            start_pc,
357            start_clk,
358            memory_reads_idx,
359            hint_lens_idx,
360            last_clk: 0,
361            last_mem_reads_idx: 0,
362        }
363    }
364
365    /// Create a new spliced minimal trace from a minimal trace without any splicing.
366    #[tracing::instrument(
367        name = "SplicedMinimalTrace::new_full_trace",
368        skip(trace),
369        level = "trace"
370    )]
371    pub fn new_full_trace(trace: T) -> Self {
372        let start_registers = trace.start_registers();
373        let start_pc = trace.pc_start();
374        let start_clk = trace.clk_start();
375
376        tracing::trace!("start_pc: {}", start_pc);
377        tracing::trace!("start_clk: {}", start_clk);
378        tracing::trace!("trace.num_mem_reads(): {}", trace.num_mem_reads());
379        tracing::trace!("trace.hint_lens(): {:?}", trace.hint_lens().len());
380
381        Self::new(trace, start_registers, start_pc, start_clk, 0, 0)
382    }
383
384    /// Set the last clock of the spliced minimal trace.
385    pub fn set_last_clk(&mut self, clk: u64) {
386        self.last_clk = clk;
387    }
388
389    /// Set the last memory reads index of the spliced minimal trace.
390    pub fn set_last_mem_reads_idx(&mut self, mem_reads_idx: usize) {
391        self.last_mem_reads_idx = mem_reads_idx;
392    }
393}
394
395impl<T: MinimalTrace> MinimalTrace for SplicedMinimalTrace<T> {
396    fn start_registers(&self) -> [u64; 32] {
397        self.start_registers
398    }
399
400    fn pc_start(&self) -> u64 {
401        self.start_pc
402    }
403
404    fn clk_start(&self) -> u64 {
405        self.start_clk
406    }
407
408    fn clk_end(&self) -> u64 {
409        self.last_clk
410    }
411
412    fn num_mem_reads(&self) -> u64 {
413        self.inner.num_mem_reads() - self.memory_reads_idx as u64
414    }
415
416    fn mem_reads(&self) -> MemReads<'_> {
417        let mut reads = self.inner.mem_reads();
418        reads.advance(self.memory_reads_idx);
419
420        reads
421    }
422
423    fn hint_lens(&self) -> &[usize] {
424        let slice = self.inner.hint_lens();
425
426        &slice[self.hint_lens_idx..]
427    }
428}
429
430impl<T: MinimalTrace> Serialize for SplicedMinimalTrace<T> {
431    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
432    where
433        S: serde::Serializer,
434    {
435        let len = self.last_mem_reads_idx - self.memory_reads_idx;
436        let mem_reads = unsafe {
437            let mem_reads_buf = Arc::new_uninit_slice(len);
438            let start_mem_reads = self.mem_reads();
439            let src_ptr = start_mem_reads.head_raw();
440            std::ptr::copy_nonoverlapping(src_ptr, mem_reads_buf.as_ptr() as *mut MemValue, len);
441            mem_reads_buf.assume_init()
442        };
443
444        let trace = TraceChunk {
445            start_registers: self.start_registers,
446            pc_start: self.start_pc,
447            clk_start: self.start_clk,
448            clk_end: self.last_clk,
449            // Just copy the whole hint_lens buffer,
450            // its small enough that sending the whole buffer is fine (for now).
451            hint_lens: self.hint_lens().to_vec(),
452            mem_reads,
453        };
454
455        trace.serialize(serializer)
456    }
457}
458
#[cfg(test)]
mod tests {
    use sp1_jit::MemValue;

    use super::*;

    /// A spliced trace must serialize to a [`TraceChunk`] that holds the spliced start state and
    /// only the memory reads and hint lens remaining after the skipped prefix.
    #[test]
    fn test_serialize_spliced_minimal_trace() {
        let inner = TraceChunk {
            start_registers: [1; 32],
            pc_start: 2,
            clk_start: 3,
            clk_end: 4,
            hint_lens: vec![5, 6, 7],
            mem_reads: Arc::new([MemValue { clk: 8, value: 9 }, MemValue { clk: 10, value: 11 }]),
        };

        // Skip one memory read and one hint len, and end after the second memory read.
        let mut spliced = SplicedMinimalTrace::new(inner, [2; 32], 2, 3, 1, 1);
        spliced.set_last_mem_reads_idx(2);
        spliced.set_last_clk(2);

        let bytes = bincode::serialize(&spliced).unwrap();
        let round_tripped: TraceChunk = bincode::deserialize(&bytes).unwrap();

        let want = TraceChunk {
            start_registers: [2; 32],
            pc_start: 2,
            clk_start: 3,
            clk_end: 2,
            hint_lens: vec![6, 7],
            mem_reads: Arc::new([MemValue { clk: 10, value: 11 }]),
        };

        assert_eq!(round_tripped, want);
    }
}