use std::sync::Arc;

use serde::Serialize;
use sp1_hypercube::air::PROOF_NONCE_NUM_WORDS;
use sp1_jit::{MemReads, MemValue, MinimalTrace, TraceChunk};

use crate::{
    events::{MemoryReadRecord, MemoryWriteRecord},
    vm::{
        memory::CompressedMemory,
        results::{CycleResult, LoadResult, StoreResult},
        shapes::{ShapeChecker, HALT_AREA, HALT_HEIGHT},
        syscall::SyscallRuntime,
        CoreVM,
    },
    ExecutionError, Instruction, Opcode, Program, SP1CoreOpts, ShardingThreshold, SyscallCode,
};

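/// A VM that executes a program over a recorded [`MinimalTrace`] while tracking
/// shard shapes, so the trace can be split ("spliced") into per-shard chunks at
/// shard boundaries.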
pub struct SplicingVM<'a> {
    /// The underlying core VM that actually executes instructions.
    pub core: CoreVM<'a>,
    /// Tracks per-shard trace area and height against the sharding thresholds.
    pub shape_checker: ShapeChecker,
    /// The set of double-word-aligned addresses touched by loads and stores.
    pub touched_addresses: &'a mut CompressedMemory,
}

impl SplicingVM<'_> {
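    /// Runs the VM until the next shard boundary or until the program halts.
    ///
    /// Returns `CycleResult::ShardBoundary` when a shard limit (or the end of
    /// the current trace) is reached, and `CycleResult::Done(true)` on halt.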
    pub fn execute(&mut self) -> Result<CycleResult, ExecutionError> {
        if self.core.is_done() {
            return Ok(CycleResult::Done(true));
        }

        loop {
            let mut result = self.execute_instruction()?;

            // Force a shard boundary if the shape checker reports a limit was hit.
            if !result.is_done() && self.shape_checker.check_shard_limit() {
                result = CycleResult::ShardBoundary;
            }

            match result {
                CycleResult::Done(false) => {}
                CycleResult::ShardBoundary | CycleResult::TraceEnd => {
                    self.start_new_shard();
                    return Ok(CycleResult::ShardBoundary);
                }
                CycleResult::Done(true) => {
                    return Ok(CycleResult::Done(true));
                }
            }
        }
    }

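    /// Executes a single instruction, updates the shape checker, and advances
    /// the core VM, returning the resulting cycle state.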
    pub fn execute_instruction(&mut self) -> Result<CycleResult, ExecutionError> {
        let instruction = self.core.fetch();
        if instruction.is_none() {
            unreachable!("Fetching the next instruction failed");
        }

        // SAFETY: `instruction` was just checked to be `Some`.
        let instruction = unsafe { *instruction.unwrap_unchecked() };

        match &instruction.opcode {
            Opcode::ADD
            | Opcode::ADDI
            | Opcode::SUB
            | Opcode::XOR
            | Opcode::OR
            | Opcode::AND
            | Opcode::SLL
            | Opcode::SLLW
            | Opcode::SRL
            | Opcode::SRA
            | Opcode::SRLW
            | Opcode::SRAW
            | Opcode::SLT
            | Opcode::SLTU
            | Opcode::MUL
            | Opcode::MULHU
            | Opcode::MULHSU
            | Opcode::MULH
            | Opcode::MULW
            | Opcode::DIVU
            | Opcode::REMU
            | Opcode::DIV
            | Opcode::REM
            | Opcode::DIVW
            | Opcode::ADDW
            | Opcode::SUBW
            | Opcode::DIVUW
            | Opcode::REMUW
            | Opcode::REMW => {
                self.execute_alu(&instruction);
            }
            Opcode::LB
            | Opcode::LBU
            | Opcode::LH
            | Opcode::LHU
            | Opcode::LW
            | Opcode::LWU
            | Opcode::LD => self.execute_load(&instruction)?,
            Opcode::SB | Opcode::SH | Opcode::SW | Opcode::SD => {
                self.execute_store(&instruction)?;
            }
            Opcode::JAL | Opcode::JALR => {
                self.execute_jump(&instruction);
            }
            Opcode::BEQ | Opcode::BNE | Opcode::BLT | Opcode::BGE | Opcode::BLTU | Opcode::BGEU => {
                self.execute_branch(&instruction);
            }
            Opcode::LUI | Opcode::AUIPC => {
                self.execute_utype(&instruction);
            }
            Opcode::ECALL => self.execute_ecall(&instruction)?,
            Opcode::EBREAK | Opcode::UNIMP => {
                unreachable!("Invalid opcode for `execute_instruction`: {:?}", instruction.opcode)
            }
        }

        self.shape_checker.handle_instruction(
            &instruction,
            self.core.needs_bump_clk_high(),
            instruction.is_alu_instruction() && instruction.op_a == 0,
            instruction.is_memory_load_instruction() && instruction.op_a == 0,
            self.core.needs_state_bump(&instruction),
        );

        Ok(self.core.advance())
    }

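    /// Splices the given trace at the VM's current state, returning a
    /// [`SplicedMinimalTrace`] that resumes from the current registers, pc,
    /// and clk. Returns `None` if the trace has ended or the program is done.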
    pub fn splice<T: MinimalTrace>(&self, trace: T) -> Option<SplicedMinimalTrace<T>> {
        if self.core.is_trace_end() || self.core.is_done() {
            return None;
        }

        let total_mem_reads = trace.num_mem_reads();

        Some(SplicedMinimalTrace::new(
            trace,
            self.core.registers().iter().map(|v| v.value).collect::<Vec<_>>().try_into().unwrap(),
            self.core.pc(),
            self.core.clk(),
            // The number of memory reads already consumed by this VM.
            total_mem_reads as usize - self.core.mem_reads.len(),
        ))
    }

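    /// Resets the shape checker and refreshes the registers for the new shard.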
    fn start_new_shard(&mut self) {
        self.shape_checker.reset(self.core.clk());
        self.core.register_refresh();
    }
}

impl<'a> SplicingVM<'a> {
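    /// Creates a new `SplicingVM` over `trace`, reserving headroom in the
    /// sharding thresholds for the final HALT rows.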
    pub fn new<T: MinimalTrace>(
        trace: &'a T,
        program: Arc<Program>,
        touched_addresses: &'a mut CompressedMemory,
        proof_nonce: [u32; PROOF_NONCE_NUM_WORDS],
        opts: SP1CoreOpts,
    ) -> Self {
        let program_len = program.instructions.len() as u64;
        let ShardingThreshold { element_threshold, height_threshold } = opts.sharding_threshold;
        assert!(
            element_threshold >= HALT_AREA && height_threshold >= HALT_HEIGHT,
            "invalid sharding threshold"
        );
        Self {
            core: CoreVM::new(trace, program, opts, proof_nonce),
            touched_addresses,
            shape_checker: ShapeChecker::new(
                program_len,
                trace.clk_start(),
                // Reserve space for the HALT rows in every shard.
                ShardingThreshold {
                    element_threshold: element_threshold - HALT_AREA,
                    height_threshold: height_threshold - HALT_HEIGHT,
                },
            ),
        }
    }

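    /// Executes a load, records the touched double-word-aligned address, and
    /// reports the memory event to the shape checker.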
    #[inline]
    pub fn execute_load(&mut self, instruction: &Instruction) -> Result<(), ExecutionError> {
        let LoadResult { addr, mr_record, .. } = self.core.execute_load(instruction)?;

        // Track the containing double word (clear the low three address bits).
        self.touched_addresses.insert(addr & !0b111, true);

        self.shape_checker.handle_mem_event(addr, mr_record.prev_timestamp);

        Ok(())
    }

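    /// Executes a store, records the touched double-word-aligned address, and
    /// reports the memory event to the shape checker.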
    #[inline]
    pub fn execute_store(&mut self, instruction: &Instruction) -> Result<(), ExecutionError> {
        let StoreResult { addr, mw_record, .. } = self.core.execute_store(instruction)?;

        self.touched_addresses.insert(addr & !0b111, true);

        self.shape_checker.handle_mem_event(addr, mw_record.prev_timestamp);

        Ok(())
    }

    /// Executes an ALU instruction, discarding the core VM's result.
    #[inline]
    pub fn execute_alu(&mut self, instruction: &Instruction) {
        let _ = self.core.execute_alu(instruction);
    }

    /// Executes a jump instruction.
    #[inline]
    pub fn execute_jump(&mut self, instruction: &Instruction) {
        let _ = self.core.execute_jump(instruction);
    }

    /// Executes a branch instruction.
    #[inline]
    pub fn execute_branch(&mut self, instruction: &Instruction) {
        let _ = self.core.execute_branch(instruction);
    }

    /// Executes a U-type (`LUI`/`AUIPC`) instruction.
    #[inline]
    pub fn execute_utype(&mut self, instruction: &Instruction) {
        let _ = self.core.execute_utype(instruction);
    }

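    /// Executes an `ECALL`, updating the shape checker based on whether the
    /// syscall is retained in the current shard or counted as sent, with
    /// special handling for the `COMMIT` syscalls.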
    #[inline]
    pub fn execute_ecall(&mut self, instruction: &Instruction) -> Result<(), ExecutionError> {
        let code = self.core.read_code();

        if code.should_send() == 1 {
            if self.core.is_retained_syscall(code) {
                self.shape_checker.handle_retained_syscall(code);
            } else {
                self.shape_checker.syscall_sent();
            }
        }

        if code == SyscallCode::COMMIT || code == SyscallCode::COMMIT_DEFERRED_PROOFS {
            self.shape_checker.handle_commit();
        }

        let _ = CoreVM::execute_ecall(self, instruction, code)?;

        Ok(())
    }
}

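/// Syscall hooks that mirror the core VM's memory accessors while also
/// reporting each memory event to the shape checker.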
impl<'a> SyscallRuntime<'a> for SplicingVM<'a> {
    const TRACING: bool = false;

    fn core(&self) -> &CoreVM<'a> {
        &self.core
    }

    fn core_mut(&mut self) -> &mut CoreVM<'a> {
        &mut self.core
    }

    fn rr(&mut self, register: usize) -> MemoryReadRecord {
        // Unlike `mr`/`mw`, register reads are not reported to the shape checker.
        SyscallRuntime::rr(self.core_mut(), register)
    }

    fn mw(&mut self, addr: u64) -> MemoryWriteRecord {
        let record = SyscallRuntime::mw(self.core_mut(), addr);

        self.shape_checker.handle_mem_event(addr, record.prev_timestamp);

        record
    }

    fn mr(&mut self, addr: u64) -> MemoryReadRecord {
        let record = SyscallRuntime::mr(self.core_mut(), addr);

        self.shape_checker.handle_mem_event(addr, record.prev_timestamp);

        record
    }

    fn mr_slice(&mut self, addr: u64, len: usize) -> Vec<MemoryReadRecord> {
        let records = SyscallRuntime::mr_slice(self.core_mut(), addr, len);

        // Each element in the slice occupies 8 bytes.
        for (i, record) in records.iter().enumerate() {
            self.shape_checker.handle_mem_event(addr + i as u64 * 8, record.prev_timestamp);
        }

        records
    }

    fn mw_slice(&mut self, addr: u64, len: usize) -> Vec<MemoryWriteRecord> {
        let records = SyscallRuntime::mw_slice(self.core_mut(), addr, len);

        for (i, record) in records.iter().enumerate() {
            self.shape_checker.handle_mem_event(addr + i as u64 * 8, record.prev_timestamp);
        }

        records
    }
}

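/// A view over a [`MinimalTrace`] that starts at a spliced-in VM state
/// (registers, pc, clk) and skips the memory reads already consumed before
/// the splice point.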
#[derive(Debug, Clone)]
pub struct SplicedMinimalTrace<T: MinimalTrace> {
    inner: T,
    /// The register values at the splice point.
    start_registers: [u64; 32],
    start_pc: u64,
    start_clk: u64,
    /// The index of the first memory read belonging to this spliced trace.
    memory_reads_idx: usize,
    last_clk: u64,
    /// One past the index of the last memory read belonging to this spliced trace.
    last_mem_reads_idx: usize,
}

impl<T: MinimalTrace> SplicedMinimalTrace<T> {
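    /// Creates a spliced trace starting at the given registers, pc, clk, and
    /// memory-read offset; the end markers are set later via the setters.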
    #[tracing::instrument(name = "SplicedMinimalTrace::new", skip(inner), level = "trace")]
    pub fn new(
        inner: T,
        start_registers: [u64; 32],
        start_pc: u64,
        start_clk: u64,
        memory_reads_idx: usize,
    ) -> Self {
        Self {
            inner,
            start_registers,
            start_pc,
            start_clk,
            memory_reads_idx,
            last_clk: 0,
            last_mem_reads_idx: 0,
        }
    }

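    /// Creates a spliced trace covering the full `trace` from its own start
    /// state, with a memory-read offset of zero.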
    #[tracing::instrument(
        name = "SplicedMinimalTrace::new_full_trace",
        skip(trace),
        level = "trace"
    )]
    pub fn new_full_trace(trace: T) -> Self {
        let start_registers = trace.start_registers();
        let start_pc = trace.pc_start();
        let start_clk = trace.clk_start();

        tracing::trace!("start_pc: {}", start_pc);
        tracing::trace!("start_clk: {}", start_clk);
        tracing::trace!("trace.num_mem_reads(): {}", trace.num_mem_reads());

        Self::new(trace, start_registers, start_pc, start_clk, 0)
    }

    /// Records the final clk of this spliced trace.
    pub fn set_last_clk(&mut self, clk: u64) {
        self.last_clk = clk;
    }

    /// Records the exclusive end index of this spliced trace's memory reads.
    pub fn set_last_mem_reads_idx(&mut self, mem_reads_idx: usize) {
        self.last_mem_reads_idx = mem_reads_idx;
    }
}

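/// Presents the spliced view: the start state comes from the splice point, and
/// the memory reads are offset past those already consumed.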
impl<T: MinimalTrace> MinimalTrace for SplicedMinimalTrace<T> {
    fn start_registers(&self) -> [u64; 32] {
        self.start_registers
    }

    fn pc_start(&self) -> u64 {
        self.start_pc
    }

    fn clk_start(&self) -> u64 {
        self.start_clk
    }

    fn clk_end(&self) -> u64 {
        self.last_clk
    }

    fn num_mem_reads(&self) -> u64 {
        self.inner.num_mem_reads() - self.memory_reads_idx as u64
    }

    fn mem_reads(&self) -> MemReads<'_> {
        let mut reads = self.inner.mem_reads();
        reads.advance(self.memory_reads_idx);

        reads
    }
}

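/// Serializes the spliced window as a [`TraceChunk`], copying only the memory
/// reads in `[memory_reads_idx, last_mem_reads_idx)`.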
impl<T: MinimalTrace> Serialize for SplicedMinimalTrace<T> {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let len = self.last_mem_reads_idx - self.memory_reads_idx;
        // SAFETY: assumes `head_raw()` points at at least `len` initialized
        // `MemValue`s; the freshly allocated `Arc` slice is uniquely owned, and
        // exactly `len` values are copied into it before `assume_init`.
        let mem_reads = unsafe {
            let mem_reads_buf = Arc::new_uninit_slice(len);
            let start_mem_reads = self.mem_reads();
            let src_ptr = start_mem_reads.head_raw();
            std::ptr::copy_nonoverlapping(src_ptr, mem_reads_buf.as_ptr() as *mut MemValue, len);
            mem_reads_buf.assume_init()
        };

        let trace = TraceChunk {
            start_registers: self.start_registers,
            pc_start: self.start_pc,
            clk_start: self.start_clk,
            clk_end: self.last_clk,
            mem_reads,
        };

        trace.serialize(serializer)
    }
}

#[cfg(test)]
mod tests {
    use sp1_jit::MemValue;

    use super::*;

    #[test]
    fn test_serialize_spliced_minimal_trace() {
        let trace_chunk = TraceChunk {
            start_registers: [1; 32],
            pc_start: 2,
            clk_start: 3,
            clk_end: 4,
            mem_reads: Arc::new([MemValue { clk: 8, value: 9 }, MemValue { clk: 10, value: 11 }]),
        };

        // Splice past the first memory read; only the second should survive.
        let mut trace = SplicedMinimalTrace::new(trace_chunk, [2; 32], 2, 3, 1);
        trace.set_last_mem_reads_idx(2);
        trace.set_last_clk(2);

        let serialized = bincode::serialize(&trace).unwrap();
        let deserialized: TraceChunk = bincode::deserialize(&serialized).unwrap();

        let expected = TraceChunk {
            start_registers: [2; 32],
            pc_start: 2,
            clk_start: 3,
            clk_end: 2,
            mem_reads: Arc::new([MemValue { clk: 10, value: 11 }]),
        };

        assert_eq!(deserialized, expected);
    }
}