Skip to main content

sp1_core_executor/
splicing.rs

1use std::sync::Arc;
2
3use serde::Serialize;
4use sp1_hypercube::air::PROOF_NONCE_NUM_WORDS;
5use sp1_jit::{MemReads, MemValue, MinimalTrace, TraceChunk};
6
7use crate::{
8    events::{MemoryReadRecord, MemoryWriteRecord},
9    vm::{
10        memory::CompressedMemory,
11        results::{CycleResult, LoadResult, StoreResult},
12        shapes::ShapeChecker,
13        syscall::SyscallRuntime,
14        CoreVM,
15    },
16    ExecutionError, Instruction, Opcode, Program, SP1CoreOpts, SyscallCode,
17};
18
19/// A RISC-V VM that uses a [`MinimalTrace`] to create multiple [`SplicedMinimalTrace`]s.
20///
21/// These new [`SplicedMinimalTrace`]s correspond to exactly 1 execuction shard to be proved.
22///
23/// Note that this is the only time we account for trace area throught the execution pipeline.
24pub struct SplicingVM<'a> {
25    /// The core VM.
26    pub core: CoreVM<'a>,
27    /// The shape checker, responsible for cutting the execution when a shard limit is reached.
28    pub shape_checker: ShapeChecker,
29    /// The addresses that have been touched.
30    pub touched_addresses: &'a mut CompressedMemory,
31    /// The index of the hint lens the next shard will use.
32    pub hint_lens_idx: usize,
33}
34
35impl SplicingVM<'_> {
36    /// Execute the program until it halts.
37    pub fn execute(&mut self) -> Result<CycleResult, ExecutionError> {
38        if self.core.is_done() {
39            return Ok(CycleResult::Done(true));
40        }
41
42        loop {
43            let mut result = self.execute_instruction()?;
44
45            // If were not already done, ensure that we dont have a shard boundary.
46            if !result.is_done() && self.shape_checker.check_shard_limit() {
47                result = CycleResult::ShardBoundary;
48            }
49
50            match result {
51                CycleResult::Done(false) => {}
52                CycleResult::ShardBoundary | CycleResult::TraceEnd => {
53                    self.start_new_shard();
54                    return Ok(CycleResult::ShardBoundary);
55                }
56                CycleResult::Done(true) => {
57                    return Ok(CycleResult::Done(true));
58                }
59            }
60        }
61    }
62
63    /// Execute the next instruction at the current PC.
64    pub fn execute_instruction(&mut self) -> Result<CycleResult, ExecutionError> {
65        let instruction = self.core.fetch();
66        if instruction.is_none() {
67            unreachable!("Fetching the next instruction failed");
68        }
69
70        // SAFETY: The instruction is guaranteed to be valid as we checked for `is_none` above.
71        let instruction = unsafe { *instruction.unwrap_unchecked() };
72
73        match &instruction.opcode {
74            Opcode::ADD
75            | Opcode::ADDI
76            | Opcode::SUB
77            | Opcode::XOR
78            | Opcode::OR
79            | Opcode::AND
80            | Opcode::SLL
81            | Opcode::SLLW
82            | Opcode::SRL
83            | Opcode::SRA
84            | Opcode::SRLW
85            | Opcode::SRAW
86            | Opcode::SLT
87            | Opcode::SLTU
88            | Opcode::MUL
89            | Opcode::MULHU
90            | Opcode::MULHSU
91            | Opcode::MULH
92            | Opcode::MULW
93            | Opcode::DIVU
94            | Opcode::REMU
95            | Opcode::DIV
96            | Opcode::REM
97            | Opcode::DIVW
98            | Opcode::ADDW
99            | Opcode::SUBW
100            | Opcode::DIVUW
101            | Opcode::REMUW
102            | Opcode::REMW => {
103                self.execute_alu(&instruction);
104            }
105            Opcode::LB
106            | Opcode::LBU
107            | Opcode::LH
108            | Opcode::LHU
109            | Opcode::LW
110            | Opcode::LWU
111            | Opcode::LD => self.execute_load(&instruction)?,
112            Opcode::SB | Opcode::SH | Opcode::SW | Opcode::SD => {
113                self.execute_store(&instruction)?;
114            }
115            Opcode::JAL | Opcode::JALR => {
116                self.execute_jump(&instruction);
117            }
118            Opcode::BEQ | Opcode::BNE | Opcode::BLT | Opcode::BGE | Opcode::BLTU | Opcode::BGEU => {
119                self.execute_branch(&instruction);
120            }
121            Opcode::LUI | Opcode::AUIPC => {
122                self.execute_utype(&instruction);
123            }
124            Opcode::ECALL => self.execute_ecall(&instruction)?,
125            Opcode::EBREAK | Opcode::UNIMP => {
126                unreachable!("Invalid opcode for `execute_instruction`: {:?}", instruction.opcode)
127            }
128        }
129
130        self.shape_checker.handle_instruction(
131            &instruction,
132            self.core.needs_bump_clk_high(),
133            instruction.is_memory_load_instruction() && instruction.op_a == 0,
134            self.core.needs_state_bump(&instruction),
135        );
136
137        Ok(self.core.advance())
138    }
139
140    /// Splice a minimal trace, outputting a minimal trace for the NEXT shard.
141    pub fn splice<T: MinimalTrace>(&self, trace: T) -> Option<SplicedMinimalTrace<T>> {
142        // If the trace has been exhausted, then the last splice is all thats needed.
143        if self.core.is_trace_end() || self.core.is_done() {
144            return None;
145        }
146
147        let total_mem_reads = trace.num_mem_reads();
148
149        Some(SplicedMinimalTrace::new(
150            trace,
151            self.core.registers().iter().map(|v| v.value).collect::<Vec<_>>().try_into().unwrap(),
152            self.core.pc(),
153            self.core.clk(),
154            total_mem_reads as usize - self.core.mem_reads.len(),
155            self.hint_lens_idx,
156        ))
157    }
158
159    // Indicate that a new shard is starting.
160    fn start_new_shard(&mut self) {
161        self.shape_checker.reset(self.core.clk());
162        self.core.register_refresh();
163    }
164}
165
166impl<'a> SplicingVM<'a> {
167    /// Create a new full-tracing VM from a minimal trace.
168    pub fn new<T: MinimalTrace>(
169        trace: &'a T,
170        program: Arc<Program>,
171        touched_addresses: &'a mut CompressedMemory,
172        proof_nonce: [u32; PROOF_NONCE_NUM_WORDS],
173        opts: SP1CoreOpts,
174    ) -> Self {
175        let program_len = program.instructions.len() as u64;
176        let sharding_threshold = opts.sharding_threshold;
177        Self {
178            core: CoreVM::new(trace, program, opts, proof_nonce),
179            touched_addresses,
180            hint_lens_idx: 0,
181            shape_checker: ShapeChecker::new(program_len, trace.clk_start(), sharding_threshold),
182        }
183    }
184
185    /// Execute a load instruction.
186    ///
187    /// This method will update the local memory access for the memory read, the register read,
188    /// and the register write.
189    ///
190    /// It will also emit the memory instruction event and the events for the load instruction.
191    #[inline]
192    pub fn execute_load(&mut self, instruction: &Instruction) -> Result<(), ExecutionError> {
193        let LoadResult { addr, mr_record, .. } = self.core.execute_load(instruction)?;
194
195        // Ensure the address is aligned to 8 bytes.
196        self.touched_addresses.insert(addr & !0b111, true);
197
198        self.shape_checker.handle_mem_event(addr, mr_record.prev_timestamp);
199
200        Ok(())
201    }
202
203    /// Execute a store instruction.
204    ///
205    /// This method will update the local memory access for the memory read, the register read,
206    /// and the register write.
207    ///
208    /// It will also emit the memory instruction event and the events for the store instruction.
209    #[inline]
210    pub fn execute_store(&mut self, instruction: &Instruction) -> Result<(), ExecutionError> {
211        let StoreResult { addr, mw_record, .. } = self.core.execute_store(instruction)?;
212
213        // Ensure the address is aligned to 8 bytes.
214        self.touched_addresses.insert(addr & !0b111, true);
215
216        self.shape_checker.handle_mem_event(addr, mw_record.prev_timestamp);
217
218        Ok(())
219    }
220
221    /// Execute an ALU instruction and emit the events.
222    #[inline]
223    pub fn execute_alu(&mut self, instruction: &Instruction) {
224        let _ = self.core.execute_alu(instruction);
225    }
226
227    /// Execute a jump instruction and emit the events.
228    #[inline]
229    pub fn execute_jump(&mut self, instruction: &Instruction) {
230        let _ = self.core.execute_jump(instruction);
231    }
232
233    /// Execute a branch instruction and emit the events.
234    #[inline]
235    pub fn execute_branch(&mut self, instruction: &Instruction) {
236        let _ = self.core.execute_branch(instruction);
237    }
238
239    /// Execute a U-type instruction and emit the events.   
240    #[inline]
241    pub fn execute_utype(&mut self, instruction: &Instruction) {
242        let _ = self.core.execute_utype(instruction);
243    }
244
245    /// Execute an ecall instruction and emit the events.
246    #[inline]
247    pub fn execute_ecall(&mut self, instruction: &Instruction) -> Result<(), ExecutionError> {
248        let code = self.core.read_code();
249
250        if code.should_send() == 1 {
251            if self.core.is_retained_syscall(code) {
252                self.shape_checker.handle_retained_syscall(code);
253            } else {
254                self.shape_checker.syscall_sent();
255            }
256        }
257
258        if code == SyscallCode::COMMIT || code == SyscallCode::COMMIT_DEFERRED_PROOFS {
259            self.shape_checker.handle_commit();
260        }
261
262        let _ = CoreVM::execute_ecall(self, instruction, code)?;
263
264        if code == SyscallCode::HINT_LEN {
265            self.hint_lens_idx += 1;
266        }
267
268        Ok(())
269    }
270}
271
272impl<'a> SyscallRuntime<'a> for SplicingVM<'a> {
273    const TRACING: bool = false;
274
275    fn core(&self) -> &CoreVM<'a> {
276        &self.core
277    }
278
279    fn core_mut(&mut self) -> &mut CoreVM<'a> {
280        &mut self.core
281    }
282
283    fn rr(&mut self, register: usize) -> MemoryReadRecord {
284        let record = SyscallRuntime::rr(self.core_mut(), register);
285
286        record
287    }
288
289    fn mw(&mut self, addr: u64) -> MemoryWriteRecord {
290        let record = SyscallRuntime::mw(self.core_mut(), addr);
291
292        self.shape_checker.handle_mem_event(addr, record.prev_timestamp);
293
294        record
295    }
296
297    fn mr(&mut self, addr: u64) -> MemoryReadRecord {
298        let record = SyscallRuntime::mr(self.core_mut(), addr);
299
300        self.shape_checker.handle_mem_event(addr, record.prev_timestamp);
301
302        record
303    }
304
305    fn mr_slice(&mut self, addr: u64, len: usize) -> Vec<MemoryReadRecord> {
306        let records = SyscallRuntime::mr_slice(self.core_mut(), addr, len);
307
308        for (i, record) in records.iter().enumerate() {
309            self.shape_checker.handle_mem_event(addr + i as u64 * 8, record.prev_timestamp);
310        }
311
312        records
313    }
314
315    fn mw_slice(&mut self, addr: u64, len: usize) -> Vec<MemoryWriteRecord> {
316        let records = SyscallRuntime::mw_slice(self.core_mut(), addr, len);
317
318        for (i, record) in records.iter().enumerate() {
319            self.shape_checker.handle_mem_event(addr + i as u64 * 8, record.prev_timestamp);
320        }
321
322        records
323    }
324}
325
326/// A minimal trace implentation that starts at a different point in the trace,
327/// but reuses the same memory reads and hint lens.
328///
329/// Note: This type implements [`Serialize`] but it is serialized as a [`TraceChunk`].
330///
331/// In order to deserialize this type, you must use the [`TraceChunk`] type.
332#[derive(Debug, Clone)]
333pub struct SplicedMinimalTrace<T: MinimalTrace> {
334    inner: T,
335    start_registers: [u64; 32],
336    start_pc: u64,
337    start_clk: u64,
338    memory_reads_idx: usize,
339    hint_lens_idx: usize,
340    last_clk: u64,
341    // Normally unused but can be set for the cluster.
342    last_mem_reads_idx: usize,
343}
344
345impl<T: MinimalTrace> SplicedMinimalTrace<T> {
346    /// Create a new spliced minimal trace.
347    #[tracing::instrument(name = "SplicedMinimalTrace::new", skip(inner), level = "trace")]
348    pub fn new(
349        inner: T,
350        start_registers: [u64; 32],
351        start_pc: u64,
352        start_clk: u64,
353        memory_reads_idx: usize,
354        hint_lens_idx: usize,
355    ) -> Self {
356        Self {
357            inner,
358            start_registers,
359            start_pc,
360            start_clk,
361            memory_reads_idx,
362            hint_lens_idx,
363            last_clk: 0,
364            last_mem_reads_idx: 0,
365        }
366    }
367
368    /// Create a new spliced minimal trace from a minimal trace without any splicing.
369    #[tracing::instrument(
370        name = "SplicedMinimalTrace::new_full_trace",
371        skip(trace),
372        level = "trace"
373    )]
374    pub fn new_full_trace(trace: T) -> Self {
375        let start_registers = trace.start_registers();
376        let start_pc = trace.pc_start();
377        let start_clk = trace.clk_start();
378
379        tracing::trace!("start_pc: {}", start_pc);
380        tracing::trace!("start_clk: {}", start_clk);
381        tracing::trace!("trace.num_mem_reads(): {}", trace.num_mem_reads());
382        tracing::trace!("trace.hint_lens(): {:?}", trace.hint_lens().len());
383
384        Self::new(trace, start_registers, start_pc, start_clk, 0, 0)
385    }
386
387    /// Set the last clock of the spliced minimal trace.
388    pub fn set_last_clk(&mut self, clk: u64) {
389        self.last_clk = clk;
390    }
391
392    /// Set the last memory reads index of the spliced minimal trace.
393    pub fn set_last_mem_reads_idx(&mut self, mem_reads_idx: usize) {
394        self.last_mem_reads_idx = mem_reads_idx;
395    }
396}
397
398impl<T: MinimalTrace> MinimalTrace for SplicedMinimalTrace<T> {
399    fn start_registers(&self) -> [u64; 32] {
400        self.start_registers
401    }
402
403    fn pc_start(&self) -> u64 {
404        self.start_pc
405    }
406
407    fn clk_start(&self) -> u64 {
408        self.start_clk
409    }
410
411    fn clk_end(&self) -> u64 {
412        self.last_clk
413    }
414
415    fn num_mem_reads(&self) -> u64 {
416        self.inner.num_mem_reads() - self.memory_reads_idx as u64
417    }
418
419    fn mem_reads(&self) -> MemReads<'_> {
420        let mut reads = self.inner.mem_reads();
421        reads.advance(self.memory_reads_idx);
422
423        reads
424    }
425
426    fn hint_lens(&self) -> &[usize] {
427        let slice = self.inner.hint_lens();
428
429        &slice[self.hint_lens_idx..]
430    }
431}
432
433impl<T: MinimalTrace> Serialize for SplicedMinimalTrace<T> {
434    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
435    where
436        S: serde::Serializer,
437    {
438        let len = self.last_mem_reads_idx - self.memory_reads_idx;
439        let mem_reads = unsafe {
440            let mem_reads_buf = Arc::new_uninit_slice(len);
441            let start_mem_reads = self.mem_reads();
442            let src_ptr = start_mem_reads.head_raw();
443            std::ptr::copy_nonoverlapping(src_ptr, mem_reads_buf.as_ptr() as *mut MemValue, len);
444            mem_reads_buf.assume_init()
445        };
446
447        let trace = TraceChunk {
448            start_registers: self.start_registers,
449            pc_start: self.start_pc,
450            clk_start: self.start_clk,
451            clk_end: self.last_clk,
452            // Just copy the whole hint_lens buffer,
453            // its small enough that sending the whole buffer is fine (for now).
454            hint_lens: self.hint_lens().to_vec(),
455            mem_reads,
456        };
457
458        trace.serialize(serializer)
459    }
460}
461
#[cfg(test)]
mod tests {
    use sp1_jit::MemValue;

    use super::*;

    /// A spliced trace must serialize as a `TraceChunk` that starts at the splice point. The
    /// skipped memory read and hint len are dropped, and the splice's own start state is used.
    #[test]
    fn test_serialize_spliced_minimal_trace() {
        let source = TraceChunk {
            start_registers: [1; 32],
            pc_start: 2,
            clk_start: 3,
            clk_end: 4,
            hint_lens: vec![5, 6, 7],
            mem_reads: Arc::new([MemValue { clk: 8, value: 9 }, MemValue { clk: 10, value: 11 }]),
        };

        // Skip the first memory read and the first hint len; end after the second read.
        let mut spliced = SplicedMinimalTrace::new(source, [2; 32], 2, 3, 1, 1);
        spliced.set_last_clk(2);
        spliced.set_last_mem_reads_idx(2);

        let bytes = bincode::serialize(&spliced).unwrap();
        let roundtrip: TraceChunk = bincode::deserialize(&bytes).unwrap();

        assert_eq!(
            roundtrip,
            TraceChunk {
                start_registers: [2; 32],
                pc_start: 2,
                clk_start: 3,
                clk_end: 2,
                hint_lens: vec![6, 7],
                mem_reads: Arc::new([MemValue { clk: 10, value: 11 }]),
            }
        );
    }
}