1use std::sync::Arc;
2
3use serde::Serialize;
4use sp1_hypercube::air::PROOF_NONCE_NUM_WORDS;
5use sp1_jit::{MemReads, MemValue, MinimalTrace, TraceChunk};
6
7use crate::{
8 events::{MemoryReadRecord, MemoryWriteRecord},
9 vm::{
10 memory::CompressedMemory,
11 results::{CycleResult, LoadResult, StoreResult},
12 shapes::ShapeChecker,
13 syscall::SyscallRuntime,
14 CoreVM,
15 },
16 ExecutionError, Instruction, Opcode, Program, SP1CoreOpts, SyscallCode,
17};
18
/// Drives a [`CoreVM`] built from a [`MinimalTrace`], using a [`ShapeChecker`] to decide where
/// shard boundaries fall and producing [`SplicedMinimalTrace`]s (via `splice`) at those points.
pub struct SplicingVM<'a> {
    /// The underlying executor that fetches, executes, and advances instructions.
    pub core: CoreVM<'a>,
    /// Shard-size accounting; `execute` ends the current shard when `check_shard_limit` is true.
    pub shape_checker: ShapeChecker,
    /// Addresses touched by loads and stores, inserted with their low three bits cleared.
    pub touched_addresses: &'a mut CompressedMemory,
    /// Number of `HINT_LEN` syscalls executed so far; `splice` uses it as the hint-length offset.
    pub hint_lens_idx: usize,
}
34
35impl SplicingVM<'_> {
36 pub fn execute(&mut self) -> Result<CycleResult, ExecutionError> {
38 if self.core.is_done() {
39 return Ok(CycleResult::Done(true));
40 }
41
42 loop {
43 let mut result = self.execute_instruction()?;
44
45 if !result.is_done() && self.shape_checker.check_shard_limit() {
47 result = CycleResult::ShardBoundary;
48 }
49
50 match result {
51 CycleResult::Done(false) => {}
52 CycleResult::ShardBoundary | CycleResult::TraceEnd => {
53 self.start_new_shard();
54 return Ok(CycleResult::ShardBoundary);
55 }
56 CycleResult::Done(true) => {
57 return Ok(CycleResult::Done(true));
58 }
59 }
60 }
61 }
62
63 pub fn execute_instruction(&mut self) -> Result<CycleResult, ExecutionError> {
65 let instruction = self.core.fetch();
66 if instruction.is_none() {
67 unreachable!("Fetching the next instruction failed");
68 }
69
70 let instruction = unsafe { *instruction.unwrap_unchecked() };
72
73 match &instruction.opcode {
74 Opcode::ADD
75 | Opcode::ADDI
76 | Opcode::SUB
77 | Opcode::XOR
78 | Opcode::OR
79 | Opcode::AND
80 | Opcode::SLL
81 | Opcode::SLLW
82 | Opcode::SRL
83 | Opcode::SRA
84 | Opcode::SRLW
85 | Opcode::SRAW
86 | Opcode::SLT
87 | Opcode::SLTU
88 | Opcode::MUL
89 | Opcode::MULHU
90 | Opcode::MULHSU
91 | Opcode::MULH
92 | Opcode::MULW
93 | Opcode::DIVU
94 | Opcode::REMU
95 | Opcode::DIV
96 | Opcode::REM
97 | Opcode::DIVW
98 | Opcode::ADDW
99 | Opcode::SUBW
100 | Opcode::DIVUW
101 | Opcode::REMUW
102 | Opcode::REMW => {
103 self.execute_alu(&instruction);
104 }
105 Opcode::LB
106 | Opcode::LBU
107 | Opcode::LH
108 | Opcode::LHU
109 | Opcode::LW
110 | Opcode::LWU
111 | Opcode::LD => self.execute_load(&instruction)?,
112 Opcode::SB | Opcode::SH | Opcode::SW | Opcode::SD => {
113 self.execute_store(&instruction)?;
114 }
115 Opcode::JAL | Opcode::JALR => {
116 self.execute_jump(&instruction);
117 }
118 Opcode::BEQ | Opcode::BNE | Opcode::BLT | Opcode::BGE | Opcode::BLTU | Opcode::BGEU => {
119 self.execute_branch(&instruction);
120 }
121 Opcode::LUI | Opcode::AUIPC => {
122 self.execute_utype(&instruction);
123 }
124 Opcode::ECALL => self.execute_ecall(&instruction)?,
125 Opcode::EBREAK | Opcode::UNIMP => {
126 unreachable!("Invalid opcode for `execute_instruction`: {:?}", instruction.opcode)
127 }
128 }
129
130 self.shape_checker.handle_instruction(
131 &instruction,
132 self.core.needs_bump_clk_high(),
133 instruction.is_memory_load_instruction() && instruction.op_a == 0,
134 self.core.needs_state_bump(&instruction),
135 );
136
137 Ok(self.core.advance())
138 }
139
140 pub fn splice<T: MinimalTrace>(&self, trace: T) -> Option<SplicedMinimalTrace<T>> {
142 if self.core.is_trace_end() || self.core.is_done() {
144 return None;
145 }
146
147 let total_mem_reads = trace.num_mem_reads();
148
149 Some(SplicedMinimalTrace::new(
150 trace,
151 self.core.registers().iter().map(|v| v.value).collect::<Vec<_>>().try_into().unwrap(),
152 self.core.pc(),
153 self.core.clk(),
154 total_mem_reads as usize - self.core.mem_reads.len(),
155 self.hint_lens_idx,
156 ))
157 }
158
159 fn start_new_shard(&mut self) {
161 self.shape_checker.reset(self.core.clk());
162 self.core.register_refresh();
163 }
164}
165
166impl<'a> SplicingVM<'a> {
167 #[tracing::instrument(name = "SplicingVM::new", skip_all)]
169 pub fn new<T: MinimalTrace>(
170 trace: &'a T,
171 program: Arc<Program>,
172 touched_addresses: &'a mut CompressedMemory,
173 proof_nonce: [u32; PROOF_NONCE_NUM_WORDS],
174 opts: SP1CoreOpts,
175 ) -> Self {
176 let program_len = program.instructions.len() as u64;
177 let sharding_threshold = opts.sharding_threshold;
178 Self {
179 core: CoreVM::new(trace, program, opts, proof_nonce),
180 touched_addresses,
181 hint_lens_idx: 0,
182 shape_checker: ShapeChecker::new(program_len, trace.clk_start(), sharding_threshold),
183 }
184 }
185
186 #[inline]
193 pub fn execute_load(&mut self, instruction: &Instruction) -> Result<(), ExecutionError> {
194 let LoadResult { addr, mr_record, .. } = self.core.execute_load(instruction)?;
195
196 self.touched_addresses.insert(addr & !0b111, true);
198
199 self.shape_checker.handle_mem_event(addr, mr_record.prev_timestamp);
200
201 Ok(())
202 }
203
204 #[inline]
211 pub fn execute_store(&mut self, instruction: &Instruction) -> Result<(), ExecutionError> {
212 let StoreResult { addr, mw_record, .. } = self.core.execute_store(instruction)?;
213
214 self.touched_addresses.insert(addr & !0b111, true);
216
217 self.shape_checker.handle_mem_event(addr, mw_record.prev_timestamp);
218
219 Ok(())
220 }
221
222 #[inline]
224 pub fn execute_alu(&mut self, instruction: &Instruction) {
225 let _ = self.core.execute_alu(instruction);
226 }
227
228 #[inline]
230 pub fn execute_jump(&mut self, instruction: &Instruction) {
231 let _ = self.core.execute_jump(instruction);
232 }
233
234 #[inline]
236 pub fn execute_branch(&mut self, instruction: &Instruction) {
237 let _ = self.core.execute_branch(instruction);
238 }
239
240 #[inline]
242 pub fn execute_utype(&mut self, instruction: &Instruction) {
243 let _ = self.core.execute_utype(instruction);
244 }
245
246 #[inline]
248 pub fn execute_ecall(&mut self, instruction: &Instruction) -> Result<(), ExecutionError> {
249 let code = self.core.read_code();
250
251 if code.should_send() == 1 {
252 if self.core.is_retained_syscall(code) {
253 self.shape_checker.handle_retained_syscall(code);
254 } else {
255 self.shape_checker.syscall_sent();
256 }
257 }
258
259 let _ = CoreVM::execute_ecall(self, instruction, code)?;
260
261 if code == SyscallCode::HINT_LEN {
262 self.hint_lens_idx += 1;
263 }
264
265 Ok(())
266 }
267}
268
269impl<'a> SyscallRuntime<'a> for SplicingVM<'a> {
270 const TRACING: bool = false;
271
272 fn core(&self) -> &CoreVM<'a> {
273 &self.core
274 }
275
276 fn core_mut(&mut self) -> &mut CoreVM<'a> {
277 &mut self.core
278 }
279
280 fn rr(&mut self, register: usize) -> MemoryReadRecord {
281 let record = SyscallRuntime::rr(self.core_mut(), register);
282
283 record
284 }
285
286 fn mw(&mut self, addr: u64) -> MemoryWriteRecord {
287 let record = SyscallRuntime::mw(self.core_mut(), addr);
288
289 self.shape_checker.handle_mem_event(addr, record.prev_timestamp);
290
291 record
292 }
293
294 fn mr(&mut self, addr: u64) -> MemoryReadRecord {
295 let record = SyscallRuntime::mr(self.core_mut(), addr);
296
297 self.shape_checker.handle_mem_event(addr, record.prev_timestamp);
298
299 record
300 }
301
302 fn mr_slice(&mut self, addr: u64, len: usize) -> Vec<MemoryReadRecord> {
303 let records = SyscallRuntime::mr_slice(self.core_mut(), addr, len);
304
305 for (i, record) in records.iter().enumerate() {
306 self.shape_checker.handle_mem_event(addr + i as u64 * 8, record.prev_timestamp);
307 }
308
309 records
310 }
311
312 fn mw_slice(&mut self, addr: u64, len: usize) -> Vec<MemoryWriteRecord> {
313 let records = SyscallRuntime::mw_slice(self.core_mut(), addr, len);
314
315 for (i, record) in records.iter().enumerate() {
316 self.shape_checker.handle_mem_event(addr + i as u64 * 8, record.prev_timestamp);
317 }
318
319 records
320 }
321}
322
/// A view over an inner [`MinimalTrace`] that begins at a splice point.
///
/// The view overrides the start registers, pc, and clock. It skips the first
/// `memory_reads_idx` memory reads and `hint_lens_idx` hint lengths of `inner`. It also records
/// where the splice ends (`last_clk`, `last_mem_reads_idx`) so that it can be serialized as a
/// `TraceChunk`.
#[derive(Debug, Clone)]
pub struct SplicedMinimalTrace<T: MinimalTrace> {
    /// The trace being spliced.
    inner: T,
    /// Register values at the splice point.
    start_registers: [u64; 32],
    /// Program counter at the splice point.
    start_pc: u64,
    /// Clock at the splice point.
    start_clk: u64,
    /// Index into `inner`'s memory reads where this splice starts (reads before it are skipped).
    memory_reads_idx: usize,
    /// Index into `inner`'s hint lengths where this splice starts.
    hint_lens_idx: usize,
    /// End clock reported by `clk_end` and written on serialization.
    last_clk: u64,
    /// Exclusive end index into `inner`'s memory reads, in the same index space as
    /// `memory_reads_idx`; only reads in `memory_reads_idx..last_mem_reads_idx` are serialized.
    last_mem_reads_idx: usize,
}
341
342impl<T: MinimalTrace> SplicedMinimalTrace<T> {
343 #[tracing::instrument(name = "SplicedMinimalTrace::new", skip(inner), level = "trace")]
345 pub fn new(
346 inner: T,
347 start_registers: [u64; 32],
348 start_pc: u64,
349 start_clk: u64,
350 memory_reads_idx: usize,
351 hint_lens_idx: usize,
352 ) -> Self {
353 Self {
354 inner,
355 start_registers,
356 start_pc,
357 start_clk,
358 memory_reads_idx,
359 hint_lens_idx,
360 last_clk: 0,
361 last_mem_reads_idx: 0,
362 }
363 }
364
365 #[tracing::instrument(
367 name = "SplicedMinimalTrace::new_full_trace",
368 skip(trace),
369 level = "trace"
370 )]
371 pub fn new_full_trace(trace: T) -> Self {
372 let start_registers = trace.start_registers();
373 let start_pc = trace.pc_start();
374 let start_clk = trace.clk_start();
375
376 tracing::trace!("start_pc: {}", start_pc);
377 tracing::trace!("start_clk: {}", start_clk);
378 tracing::trace!("trace.num_mem_reads(): {}", trace.num_mem_reads());
379 tracing::trace!("trace.hint_lens(): {:?}", trace.hint_lens().len());
380
381 Self::new(trace, start_registers, start_pc, start_clk, 0, 0)
382 }
383
384 pub fn set_last_clk(&mut self, clk: u64) {
386 self.last_clk = clk;
387 }
388
389 pub fn set_last_mem_reads_idx(&mut self, mem_reads_idx: usize) {
391 self.last_mem_reads_idx = mem_reads_idx;
392 }
393}
394
395impl<T: MinimalTrace> MinimalTrace for SplicedMinimalTrace<T> {
396 fn start_registers(&self) -> [u64; 32] {
397 self.start_registers
398 }
399
400 fn pc_start(&self) -> u64 {
401 self.start_pc
402 }
403
404 fn clk_start(&self) -> u64 {
405 self.start_clk
406 }
407
408 fn clk_end(&self) -> u64 {
409 self.last_clk
410 }
411
412 fn num_mem_reads(&self) -> u64 {
413 self.inner.num_mem_reads() - self.memory_reads_idx as u64
414 }
415
416 fn mem_reads(&self) -> MemReads<'_> {
417 let mut reads = self.inner.mem_reads();
418 reads.advance(self.memory_reads_idx);
419
420 reads
421 }
422
423 fn hint_lens(&self) -> &[usize] {
424 let slice = self.inner.hint_lens();
425
426 &slice[self.hint_lens_idx..]
427 }
428}
429
430impl<T: MinimalTrace> Serialize for SplicedMinimalTrace<T> {
431 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
432 where
433 S: serde::Serializer,
434 {
435 let len = self.last_mem_reads_idx - self.memory_reads_idx;
436 let mem_reads = unsafe {
437 let mem_reads_buf = Arc::new_uninit_slice(len);
438 let start_mem_reads = self.mem_reads();
439 let src_ptr = start_mem_reads.head_raw();
440 std::ptr::copy_nonoverlapping(src_ptr, mem_reads_buf.as_ptr() as *mut MemValue, len);
441 mem_reads_buf.assume_init()
442 };
443
444 let trace = TraceChunk {
445 start_registers: self.start_registers,
446 pc_start: self.start_pc,
447 clk_start: self.start_clk,
448 clk_end: self.last_clk,
449 hint_lens: self.hint_lens().to_vec(),
452 mem_reads,
453 };
454
455 trace.serialize(serializer)
456 }
457}
458
#[cfg(test)]
mod tests {
    use sp1_jit::MemValue;

    use super::*;

    /// Serializing a splice must yield a `TraceChunk` that contains only the spliced window.
    #[test]
    fn test_serialize_spliced_minimal_trace() {
        let source = TraceChunk {
            start_registers: [1; 32],
            pc_start: 2,
            clk_start: 3,
            clk_end: 4,
            hint_lens: vec![5, 6, 7],
            mem_reads: Arc::new([MemValue { clk: 8, value: 9 }, MemValue { clk: 10, value: 11 }]),
        };

        // Skip the first memory read and the first hint length; the window ends at read index 2.
        let mut spliced = SplicedMinimalTrace::new(source, [2; 32], 2, 3, 1, 1);
        spliced.set_last_mem_reads_idx(2);
        spliced.set_last_clk(2);

        let bytes = bincode::serialize(&spliced).unwrap();
        let round_tripped: TraceChunk = bincode::deserialize(&bytes).unwrap();

        assert_eq!(
            round_tripped,
            TraceChunk {
                start_registers: [2; 32],
                pc_start: 2,
                clk_start: 3,
                clk_end: 2,
                hint_lens: vec![6, 7],
                mem_reads: Arc::new([MemValue { clk: 10, value: 11 }]),
            }
        );
    }
}