1use std::sync::Arc;
2
3use serde::Serialize;
4use sp1_hypercube::air::PROOF_NONCE_NUM_WORDS;
5use sp1_jit::{MemReads, MemValue, MinimalTrace, TraceChunk};
6
7use crate::{
8 events::{MemoryReadRecord, MemoryWriteRecord},
9 vm::{
10 memory::CompressedMemory,
11 results::{CycleResult, LoadResult, StoreResult},
12 shapes::ShapeChecker,
13 syscall::SyscallRuntime,
14 CoreVM,
15 },
16 ExecutionError, Instruction, Opcode, Program, SP1CoreOpts, SyscallCode,
17};
18
/// A VM that re-executes a recorded minimal trace while tracking per-shard shape
/// constraints, so the trace can be split ("spliced") into shard-sized chunks.
pub struct SplicingVM<'a> {
    /// The underlying core VM that performs the actual instruction execution.
    pub core: CoreVM<'a>,
    /// Tracks event counts for the current shard and decides where shard boundaries go.
    pub shape_checker: ShapeChecker,
    /// Accumulates the 8-byte-aligned addresses touched by loads and stores.
    pub touched_addresses: &'a mut CompressedMemory,
    /// Number of `HINT_LEN` syscalls executed so far; used as an index into the
    /// trace's hint-length list when splicing.
    pub hint_lens_idx: usize,
}
34
impl SplicingVM<'_> {
    /// Runs the VM until a shard boundary is reached or the program halts.
    ///
    /// Returns `CycleResult::ShardBoundary` when the shape checker decides the
    /// current shard is full (or the underlying trace chunk ends), and
    /// `CycleResult::Done(true)` when the program has finished.
    pub fn execute(&mut self) -> Result<CycleResult, ExecutionError> {
        // Nothing left to execute.
        if self.core.is_done() {
            return Ok(CycleResult::Done(true));
        }

        loop {
            let mut result = self.execute_instruction()?;

            // A full shard forces a boundary even when execution could continue.
            if !result.is_done() && self.shape_checker.check_shard_limit() {
                result = CycleResult::ShardBoundary;
            }

            match result {
                // Still within the current shard: keep executing.
                CycleResult::Done(false) => {}
                // A trace end is surfaced to the caller as a shard boundary as well.
                CycleResult::ShardBoundary | CycleResult::TraceEnd => {
                    self.start_new_shard();
                    return Ok(CycleResult::ShardBoundary);
                }
                CycleResult::Done(true) => {
                    return Ok(CycleResult::Done(true));
                }
            }
        }
    }

    /// Fetches, decodes, and executes a single instruction, reports it to the shape
    /// checker, and advances the core VM state.
    pub fn execute_instruction(&mut self) -> Result<CycleResult, ExecutionError> {
        let instruction = self.core.fetch();
        if instruction.is_none() {
            unreachable!("Fetching the next instruction failed");
        }

        // SAFETY: `instruction` was just checked to be `Some` above.
        let instruction = unsafe { *instruction.unwrap_unchecked() };

        // Dispatch on the opcode class; each helper forwards to the core VM and
        // performs any splicing-specific bookkeeping.
        match &instruction.opcode {
            Opcode::ADD
            | Opcode::ADDI
            | Opcode::SUB
            | Opcode::XOR
            | Opcode::OR
            | Opcode::AND
            | Opcode::SLL
            | Opcode::SLLW
            | Opcode::SRL
            | Opcode::SRA
            | Opcode::SRLW
            | Opcode::SRAW
            | Opcode::SLT
            | Opcode::SLTU
            | Opcode::MUL
            | Opcode::MULHU
            | Opcode::MULHSU
            | Opcode::MULH
            | Opcode::MULW
            | Opcode::DIVU
            | Opcode::REMU
            | Opcode::DIV
            | Opcode::REM
            | Opcode::DIVW
            | Opcode::ADDW
            | Opcode::SUBW
            | Opcode::DIVUW
            | Opcode::REMUW
            | Opcode::REMW => {
                self.execute_alu(&instruction);
            }
            Opcode::LB
            | Opcode::LBU
            | Opcode::LH
            | Opcode::LHU
            | Opcode::LW
            | Opcode::LWU
            | Opcode::LD => self.execute_load(&instruction)?,
            Opcode::SB | Opcode::SH | Opcode::SW | Opcode::SD => {
                self.execute_store(&instruction)?;
            }
            Opcode::JAL | Opcode::JALR => {
                self.execute_jump(&instruction);
            }
            Opcode::BEQ | Opcode::BNE | Opcode::BLT | Opcode::BGE | Opcode::BLTU | Opcode::BGEU => {
                self.execute_branch(&instruction);
            }
            Opcode::LUI | Opcode::AUIPC => {
                self.execute_utype(&instruction);
            }
            Opcode::ECALL => self.execute_ecall(&instruction)?,
            Opcode::EBREAK | Opcode::UNIMP => {
                unreachable!("Invalid opcode for `execute_instruction`: {:?}", instruction.opcode)
            }
        }

        // Record the instruction's shape contribution.
        // NOTE(review): `op_a == 0` on a load presumably marks a load whose result is
        // discarded into x0 — confirm against `ShapeChecker::handle_instruction`.
        self.shape_checker.handle_instruction(
            &instruction,
            self.core.needs_bump_clk_high(),
            instruction.is_memory_load_instruction() && instruction.op_a == 0,
            self.core.needs_state_bump(&instruction),
        );

        Ok(self.core.advance())
    }

    /// Creates a spliced view of `trace` that resumes from the VM's current state
    /// (registers, pc, clk, memory-read cursor, hint cursor).
    ///
    /// Returns `None` when the trace chunk is exhausted or the program is done,
    /// i.e. there is nothing left to splice.
    pub fn splice<T: MinimalTrace>(&self, trace: T) -> Option<SplicedMinimalTrace<T>> {
        if self.core.is_trace_end() || self.core.is_done() {
            return None;
        }

        let total_mem_reads = trace.num_mem_reads();

        Some(SplicedMinimalTrace::new(
            trace,
            self.core.registers().iter().map(|v| v.value).collect::<Vec<_>>().try_into().unwrap(),
            self.core.pc(),
            self.core.clk(),
            // Number of memory reads already consumed: total minus those remaining
            // in the core VM's cursor.
            total_mem_reads as usize - self.core.mem_reads.len(),
            self.hint_lens_idx,
        ))
    }

    /// Resets per-shard state at a shard boundary: shape-checker counters restart at
    /// the current clock, and the core refreshes its registers (presumably their
    /// timestamps — confirm against `CoreVM::register_refresh`).
    fn start_new_shard(&mut self) {
        self.shape_checker.reset(self.core.clk());
        self.core.register_refresh();
    }
}
165
impl<'a> SplicingVM<'a> {
    /// Creates a new [`SplicingVM`] that replays `trace` against `program`.
    ///
    /// `touched_addresses` accumulates the 8-byte-aligned memory addresses accessed
    /// during execution; `proof_nonce` and `opts` are forwarded to the core VM.
    pub fn new<T: MinimalTrace>(
        trace: &'a T,
        program: Arc<Program>,
        touched_addresses: &'a mut CompressedMemory,
        proof_nonce: [u32; PROOF_NONCE_NUM_WORDS],
        opts: SP1CoreOpts,
    ) -> Self {
        // Captured before `program`/`opts` are moved into `CoreVM::new`.
        let program_len = program.instructions.len() as u64;
        let sharding_threshold = opts.sharding_threshold;
        Self {
            core: CoreVM::new(trace, program, opts, proof_nonce),
            touched_addresses,
            hint_lens_idx: 0,
            shape_checker: ShapeChecker::new(program_len, trace.clk_start(), sharding_threshold),
        }
    }

    /// Executes a load instruction, marking the enclosing 8-byte word as touched and
    /// reporting the memory event to the shape checker.
    #[inline]
    pub fn execute_load(&mut self, instruction: &Instruction) -> Result<(), ExecutionError> {
        let LoadResult { addr, mr_record, .. } = self.core.execute_load(instruction)?;

        // Align down to the enclosing 8-byte word.
        self.touched_addresses.insert(addr & !0b111, true);

        self.shape_checker.handle_mem_event(addr, mr_record.prev_timestamp);

        Ok(())
    }

    /// Executes a store instruction, marking the enclosing 8-byte word as touched and
    /// reporting the memory event to the shape checker.
    #[inline]
    pub fn execute_store(&mut self, instruction: &Instruction) -> Result<(), ExecutionError> {
        let StoreResult { addr, mw_record, .. } = self.core.execute_store(instruction)?;

        // Align down to the enclosing 8-byte word.
        self.touched_addresses.insert(addr & !0b111, true);

        self.shape_checker.handle_mem_event(addr, mw_record.prev_timestamp);

        Ok(())
    }

    /// Executes an ALU instruction via the core VM; the result record is not needed
    /// for splicing.
    #[inline]
    pub fn execute_alu(&mut self, instruction: &Instruction) {
        let _ = self.core.execute_alu(instruction);
    }

    /// Executes a jump instruction via the core VM.
    #[inline]
    pub fn execute_jump(&mut self, instruction: &Instruction) {
        let _ = self.core.execute_jump(instruction);
    }

    /// Executes a branch instruction via the core VM.
    #[inline]
    pub fn execute_branch(&mut self, instruction: &Instruction) {
        let _ = self.core.execute_branch(instruction);
    }

    /// Executes a U-type (LUI/AUIPC) instruction via the core VM.
    #[inline]
    pub fn execute_utype(&mut self, instruction: &Instruction) {
        let _ = self.core.execute_utype(instruction);
    }

    /// Executes an ECALL: updates the shape checker's syscall accounting before
    /// delegating to the core VM, and counts `HINT_LEN` syscalls for hint splicing.
    #[inline]
    pub fn execute_ecall(&mut self, instruction: &Instruction) -> Result<(), ExecutionError> {
        let code = self.core.read_code();

        // Syscalls whose events are retained in this shard are accounted differently
        // from those sent out for external processing.
        if code.should_send() == 1 {
            if self.core.is_retained_syscall(code) {
                self.shape_checker.handle_retained_syscall(code);
            } else {
                self.shape_checker.syscall_sent();
            }
        }

        if code == SyscallCode::COMMIT || code == SyscallCode::COMMIT_DEFERRED_PROOFS {
            self.shape_checker.handle_commit();
        }

        // Delegate to the shared ecall implementation; the SyscallRuntime impl below
        // routes its memory accesses back through this VM's shape checker.
        let _ = CoreVM::execute_ecall(self, instruction, code)?;

        // Track how many hint lengths have been consumed so a splice can resume at
        // the right position in the trace's hint-length list.
        if code == SyscallCode::HINT_LEN {
            self.hint_lens_idx += 1;
        }

        Ok(())
    }
}
271
272impl<'a> SyscallRuntime<'a> for SplicingVM<'a> {
273 const TRACING: bool = false;
274
275 fn core(&self) -> &CoreVM<'a> {
276 &self.core
277 }
278
279 fn core_mut(&mut self) -> &mut CoreVM<'a> {
280 &mut self.core
281 }
282
283 fn rr(&mut self, register: usize) -> MemoryReadRecord {
284 let record = SyscallRuntime::rr(self.core_mut(), register);
285
286 record
287 }
288
289 fn mw(&mut self, addr: u64) -> MemoryWriteRecord {
290 let record = SyscallRuntime::mw(self.core_mut(), addr);
291
292 self.shape_checker.handle_mem_event(addr, record.prev_timestamp);
293
294 record
295 }
296
297 fn mr(&mut self, addr: u64) -> MemoryReadRecord {
298 let record = SyscallRuntime::mr(self.core_mut(), addr);
299
300 self.shape_checker.handle_mem_event(addr, record.prev_timestamp);
301
302 record
303 }
304
305 fn mr_slice(&mut self, addr: u64, len: usize) -> Vec<MemoryReadRecord> {
306 let records = SyscallRuntime::mr_slice(self.core_mut(), addr, len);
307
308 for (i, record) in records.iter().enumerate() {
309 self.shape_checker.handle_mem_event(addr + i as u64 * 8, record.prev_timestamp);
310 }
311
312 records
313 }
314
315 fn mw_slice(&mut self, addr: u64, len: usize) -> Vec<MemoryWriteRecord> {
316 let records = SyscallRuntime::mw_slice(self.core_mut(), addr, len);
317
318 for (i, record) in records.iter().enumerate() {
319 self.shape_checker.handle_mem_event(addr + i as u64 * 8, record.prev_timestamp);
320 }
321
322 records
323 }
324}
325
/// A view over a [`MinimalTrace`] that starts at a splice point (a shard boundary)
/// rather than at the trace's beginning.
///
/// The `last_*` fields mark the end of the spliced window and default to 0; they
/// must be set via the setters before the trace is serialized.
#[derive(Debug, Clone)]
pub struct SplicedMinimalTrace<T: MinimalTrace> {
    // The full underlying trace being windowed.
    inner: T,
    // Register file at the splice point.
    start_registers: [u64; 32],
    // Program counter at the splice point.
    start_pc: u64,
    // Clock at the splice point.
    start_clk: u64,
    // Number of memory reads of `inner` consumed before the splice point.
    memory_reads_idx: usize,
    // Number of hint lengths of `inner` consumed before the splice point.
    hint_lens_idx: usize,
    // Clock at the end of the window (set via `set_last_clk`).
    last_clk: u64,
    // Memory-read index at the end of the window (set via `set_last_mem_reads_idx`).
    last_mem_reads_idx: usize,
}
344
impl<T: MinimalTrace> SplicedMinimalTrace<T> {
    /// Creates a spliced view over `inner` starting from the given registers, pc, and
    /// clk, skipping the first `memory_reads_idx` memory reads and `hint_lens_idx`
    /// hint lengths.
    ///
    /// `last_clk` and `last_mem_reads_idx` start at 0 and must be set via the setters
    /// below before serialization.
    #[tracing::instrument(name = "SplicedMinimalTrace::new", skip(inner), level = "trace")]
    pub fn new(
        inner: T,
        start_registers: [u64; 32],
        start_pc: u64,
        start_clk: u64,
        memory_reads_idx: usize,
        hint_lens_idx: usize,
    ) -> Self {
        Self {
            inner,
            start_registers,
            start_pc,
            start_clk,
            memory_reads_idx,
            hint_lens_idx,
            last_clk: 0,
            last_mem_reads_idx: 0,
        }
    }

    /// Wraps a complete trace without skipping anything (both cursors start at 0).
    #[tracing::instrument(
        name = "SplicedMinimalTrace::new_full_trace",
        skip(trace),
        level = "trace"
    )]
    pub fn new_full_trace(trace: T) -> Self {
        let start_registers = trace.start_registers();
        let start_pc = trace.pc_start();
        let start_clk = trace.clk_start();

        tracing::trace!("start_pc: {}", start_pc);
        tracing::trace!("start_clk: {}", start_clk);
        tracing::trace!("trace.num_mem_reads(): {}", trace.num_mem_reads());
        tracing::trace!("trace.hint_lens(): {:?}", trace.hint_lens().len());

        Self::new(trace, start_registers, start_pc, start_clk, 0, 0)
    }

    /// Sets the clock at the end of the spliced window.
    pub fn set_last_clk(&mut self, clk: u64) {
        self.last_clk = clk;
    }

    /// Sets the memory-read index at the end of the spliced window.
    pub fn set_last_mem_reads_idx(&mut self, mem_reads_idx: usize) {
        self.last_mem_reads_idx = mem_reads_idx;
    }
}
397
398impl<T: MinimalTrace> MinimalTrace for SplicedMinimalTrace<T> {
399 fn start_registers(&self) -> [u64; 32] {
400 self.start_registers
401 }
402
403 fn pc_start(&self) -> u64 {
404 self.start_pc
405 }
406
407 fn clk_start(&self) -> u64 {
408 self.start_clk
409 }
410
411 fn clk_end(&self) -> u64 {
412 self.last_clk
413 }
414
415 fn num_mem_reads(&self) -> u64 {
416 self.inner.num_mem_reads() - self.memory_reads_idx as u64
417 }
418
419 fn mem_reads(&self) -> MemReads<'_> {
420 let mut reads = self.inner.mem_reads();
421 reads.advance(self.memory_reads_idx);
422
423 reads
424 }
425
426 fn hint_lens(&self) -> &[usize] {
427 let slice = self.inner.hint_lens();
428
429 &slice[self.hint_lens_idx..]
430 }
431}
432
impl<T: MinimalTrace> Serialize for SplicedMinimalTrace<T> {
    /// Serializes the spliced window as a standalone `TraceChunk`.
    ///
    /// Only the memory reads in `[memory_reads_idx, last_mem_reads_idx)` are copied,
    /// so `set_last_mem_reads_idx` and `set_last_clk` must be called first; otherwise
    /// `last_mem_reads_idx` is 0 and the subtraction below underflows (panics in
    /// debug builds).
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        // Number of memory reads belonging to this window.
        let len = self.last_mem_reads_idx - self.memory_reads_idx;
        // SAFETY: the `Arc` slice is freshly allocated and uniquely owned, so
        // writing through the cast pointer before `assume_init` is sound, and the
        // copy fully initializes all `len` elements. This assumes `head_raw()`
        // points at at least `len` valid `MemValue`s past the splice point —
        // NOTE(review): confirm against `MemReads::head_raw`'s contract.
        let mem_reads = unsafe {
            let mem_reads_buf = Arc::new_uninit_slice(len);
            let start_mem_reads = self.mem_reads();
            let src_ptr = start_mem_reads.head_raw();
            std::ptr::copy_nonoverlapping(src_ptr, mem_reads_buf.as_ptr() as *mut MemValue, len);
            mem_reads_buf.assume_init()
        };

        // Re-pack the window as a self-contained chunk and serialize that.
        let trace = TraceChunk {
            start_registers: self.start_registers,
            pc_start: self.start_pc,
            clk_start: self.start_clk,
            clk_end: self.last_clk,
            hint_lens: self.hint_lens().to_vec(),
            mem_reads,
        };

        trace.serialize(serializer)
    }
}
461
#[cfg(test)]
mod tests {
    use sp1_jit::MemValue;

    use super::*;

    /// Splicing a chunk at index 1 (for both memory reads and hint lengths) and
    /// closing the window at index 2 should serialize exactly the second hint
    /// length and second memory read, with the overridden registers/pc/clk.
    #[test]
    fn test_serialize_spliced_minimal_trace() {
        let chunk = TraceChunk {
            start_registers: [1; 32],
            pc_start: 2,
            clk_start: 3,
            clk_end: 4,
            hint_lens: vec![5, 6, 7],
            mem_reads: Arc::new([MemValue { clk: 8, value: 9 }, MemValue { clk: 10, value: 11 }]),
        };

        let mut spliced = SplicedMinimalTrace::new(chunk, [2; 32], 2, 3, 1, 1);
        spliced.set_last_mem_reads_idx(2);
        spliced.set_last_clk(2);

        let bytes = bincode::serialize(&spliced).unwrap();
        let round_tripped: TraceChunk = bincode::deserialize(&bytes).unwrap();

        assert_eq!(
            round_tripped,
            TraceChunk {
                start_registers: [2; 32],
                pc_start: 2,
                clk_start: 3,
                clk_end: 2,
                hint_lens: vec![6, 7],
                mem_reads: Arc::new([MemValue { clk: 10, value: 11 }]),
            }
        );
    }
}