1use std::sync::Arc;
2
3use serde::Serialize;
4use sp1_hypercube::air::PROOF_NONCE_NUM_WORDS;
5use sp1_jit::{MemReads, MemValue, MinimalTrace, TraceChunk};
6
7use crate::{
8 events::{MemoryReadRecord, MemoryWriteRecord},
9 vm::{
10 memory::CompressedMemory,
11 results::{CycleResult, LoadResult, StoreResult},
12 shapes::{ShapeChecker, HALT_AREA, HALT_HEIGHT},
13 syscall::SyscallRuntime,
14 CoreVM,
15 },
16 ExecutionError, Instruction, Opcode, Program, SP1CoreOpts, ShardingThreshold, SyscallCode,
17};
18
/// A VM that re-executes a recorded trace while tracking shard shape limits and
/// touched memory, so the run can be split ("spliced") into per-shard chunks.
pub struct SplicingVM<'a> {
    /// The underlying core VM that performs the actual instruction execution.
    pub core: CoreVM<'a>,
    /// Tracks per-shard resource usage and decides when a shard boundary is needed.
    pub shape_checker: ShapeChecker,
    /// Accumulates the 8-byte-aligned addresses touched by loads/stores
    /// (addresses are masked with `!0b111` before insertion).
    pub touched_addresses: &'a mut CompressedMemory,
}
32
impl SplicingVM<'_> {
    /// Runs the VM until the current shard is complete or the program halts.
    ///
    /// Returns `Ok(CycleResult::ShardBoundary)` when a shard limit or the end
    /// of the current trace chunk is reached (per-shard state is reset before
    /// returning), and `Ok(CycleResult::Done(true))` when execution finished.
    pub fn execute(&mut self) -> Result<CycleResult, ExecutionError> {
        if self.core.is_done() {
            return Ok(CycleResult::Done(true));
        }

        loop {
            let mut result = self.execute_instruction()?;

            // Force a shard boundary once the shape checker reports that the
            // current shard has hit its limits.
            if !result.is_done() && self.shape_checker.check_shard_limit() {
                result = CycleResult::ShardBoundary;
            }

            match result {
                // Not done and no boundary: keep executing.
                CycleResult::Done(false) => {}
                // The end of a trace chunk is reported to the caller as a
                // shard boundary as well.
                CycleResult::ShardBoundary | CycleResult::TraceEnd => {
                    self.start_new_shard();
                    return Ok(CycleResult::ShardBoundary);
                }
                CycleResult::Done(true) => {
                    return Ok(CycleResult::Done(true));
                }
            }
        }
    }

    /// Fetches and executes a single instruction, reports it to the shape
    /// checker, and advances the core VM state.
    pub fn execute_instruction(&mut self) -> Result<CycleResult, ExecutionError> {
        let instruction = self.core.fetch();
        if instruction.is_none() {
            unreachable!("Fetching the next instruction failed");
        }

        // SAFETY: `instruction` was checked to be `Some` above, so
        // `unwrap_unchecked` cannot observe a `None`. The fetched value is
        // assumed valid to dereference/copy — NOTE(review): confirm
        // `CoreVM::fetch`'s guarantee on the returned pointer/reference.
        let instruction = unsafe { *instruction.unwrap_unchecked() };

        match &instruction.opcode {
            Opcode::ADD
            | Opcode::ADDI
            | Opcode::SUB
            | Opcode::XOR
            | Opcode::OR
            | Opcode::AND
            | Opcode::SLL
            | Opcode::SLLW
            | Opcode::SRL
            | Opcode::SRA
            | Opcode::SRLW
            | Opcode::SRAW
            | Opcode::SLT
            | Opcode::SLTU
            | Opcode::MUL
            | Opcode::MULHU
            | Opcode::MULHSU
            | Opcode::MULH
            | Opcode::MULW
            | Opcode::DIVU
            | Opcode::REMU
            | Opcode::DIV
            | Opcode::REM
            | Opcode::DIVW
            | Opcode::ADDW
            | Opcode::SUBW
            | Opcode::DIVUW
            | Opcode::REMUW
            | Opcode::REMW => {
                self.execute_alu(&instruction);
            }
            Opcode::LB
            | Opcode::LBU
            | Opcode::LH
            | Opcode::LHU
            | Opcode::LW
            | Opcode::LWU
            | Opcode::LD => self.execute_load(&instruction)?,
            Opcode::SB | Opcode::SH | Opcode::SW | Opcode::SD => {
                self.execute_store(&instruction)?;
            }
            Opcode::JAL | Opcode::JALR => {
                self.execute_jump(&instruction);
            }
            Opcode::BEQ | Opcode::BNE | Opcode::BLT | Opcode::BGE | Opcode::BLTU | Opcode::BGEU => {
                self.execute_branch(&instruction);
            }
            Opcode::LUI | Opcode::AUIPC => {
                self.execute_utype(&instruction);
            }
            Opcode::ECALL => self.execute_ecall(&instruction)?,
            Opcode::EBREAK | Opcode::UNIMP => {
                unreachable!("Invalid opcode for `execute_instruction`: {:?}", instruction.opcode)
            }
        }

        // Report the executed instruction along with the extra conditions the
        // shape checker accounts for: a clk-high bump, a memory load whose
        // destination operand (`op_a`) is register 0, and a state bump.
        self.shape_checker.handle_instruction(
            &instruction,
            self.core.needs_bump_clk_high(),
            instruction.is_memory_load_instruction() && instruction.op_a == 0,
            self.core.needs_state_bump(&instruction),
        );

        Ok(self.core.advance())
    }

    /// Builds a [`SplicedMinimalTrace`] that resumes `trace` from the VM's
    /// current state (registers, pc, clk, remaining memory reads).
    ///
    /// Returns `None` when the trace is exhausted or execution is done, i.e.
    /// there is nothing left to splice.
    pub fn splice<T: MinimalTrace>(&self, trace: T) -> Option<SplicedMinimalTrace<T>> {
        if self.core.is_trace_end() || self.core.is_done() {
            return None;
        }

        let total_mem_reads = trace.num_mem_reads();

        Some(SplicedMinimalTrace::new(
            trace,
            self.core.registers().iter().map(|v| v.value).collect::<Vec<_>>().try_into().unwrap(),
            self.core.pc(),
            self.core.clk(),
            // Number of reads already consumed by the VM: skip them in the splice.
            total_mem_reads as usize - self.core.mem_reads.len(),
        ))
    }

    /// Resets per-shard state at a shard boundary: the shape checker (anchored
    /// at the current clk) and the core's register refresh bookkeeping.
    fn start_new_shard(&mut self) {
        self.shape_checker.reset(self.core.clk());
        self.core.register_refresh();
    }
}
162
163impl<'a> SplicingVM<'a> {
164 pub fn new<T: MinimalTrace>(
166 trace: &'a T,
167 program: Arc<Program>,
168 touched_addresses: &'a mut CompressedMemory,
169 proof_nonce: [u32; PROOF_NONCE_NUM_WORDS],
170 opts: SP1CoreOpts,
171 ) -> Self {
172 let program_len = program.instructions.len() as u64;
173 let ShardingThreshold { element_threshold, height_threshold } = opts.sharding_threshold;
174 assert!(
175 element_threshold >= HALT_AREA && height_threshold >= HALT_HEIGHT,
176 "invalid sharding threshold"
177 );
178 Self {
179 core: CoreVM::new(trace, program, opts, proof_nonce),
180 touched_addresses,
181 shape_checker: ShapeChecker::new(
182 program_len,
183 trace.clk_start(),
184 ShardingThreshold {
185 element_threshold: element_threshold - HALT_AREA,
186 height_threshold: height_threshold - HALT_HEIGHT,
187 },
188 ),
189 }
190 }
191
192 #[inline]
199 pub fn execute_load(&mut self, instruction: &Instruction) -> Result<(), ExecutionError> {
200 let LoadResult { addr, mr_record, .. } = self.core.execute_load(instruction)?;
201
202 self.touched_addresses.insert(addr & !0b111, true);
204
205 self.shape_checker.handle_mem_event(addr, mr_record.prev_timestamp);
206
207 Ok(())
208 }
209
210 #[inline]
217 pub fn execute_store(&mut self, instruction: &Instruction) -> Result<(), ExecutionError> {
218 let StoreResult { addr, mw_record, .. } = self.core.execute_store(instruction)?;
219
220 self.touched_addresses.insert(addr & !0b111, true);
222
223 self.shape_checker.handle_mem_event(addr, mw_record.prev_timestamp);
224
225 Ok(())
226 }
227
228 #[inline]
230 pub fn execute_alu(&mut self, instruction: &Instruction) {
231 let _ = self.core.execute_alu(instruction);
232 }
233
234 #[inline]
236 pub fn execute_jump(&mut self, instruction: &Instruction) {
237 let _ = self.core.execute_jump(instruction);
238 }
239
240 #[inline]
242 pub fn execute_branch(&mut self, instruction: &Instruction) {
243 let _ = self.core.execute_branch(instruction);
244 }
245
246 #[inline]
248 pub fn execute_utype(&mut self, instruction: &Instruction) {
249 let _ = self.core.execute_utype(instruction);
250 }
251
252 #[inline]
254 pub fn execute_ecall(&mut self, instruction: &Instruction) -> Result<(), ExecutionError> {
255 let code = self.core.read_code();
256
257 if code.should_send() == 1 {
258 if self.core.is_retained_syscall(code) {
259 self.shape_checker.handle_retained_syscall(code);
260 } else {
261 self.shape_checker.syscall_sent();
262 }
263 }
264
265 if code == SyscallCode::COMMIT || code == SyscallCode::COMMIT_DEFERRED_PROOFS {
266 self.shape_checker.handle_commit();
267 }
268
269 let _ = CoreVM::execute_ecall(self, instruction, code)?;
270
271 Ok(())
272 }
273}
274
275impl<'a> SyscallRuntime<'a> for SplicingVM<'a> {
276 const TRACING: bool = false;
277
278 fn core(&self) -> &CoreVM<'a> {
279 &self.core
280 }
281
282 fn core_mut(&mut self) -> &mut CoreVM<'a> {
283 &mut self.core
284 }
285
286 fn rr(&mut self, register: usize) -> MemoryReadRecord {
287 let record = SyscallRuntime::rr(self.core_mut(), register);
288
289 record
290 }
291
292 fn mw(&mut self, addr: u64) -> MemoryWriteRecord {
293 let record = SyscallRuntime::mw(self.core_mut(), addr);
294
295 self.shape_checker.handle_mem_event(addr, record.prev_timestamp);
296
297 record
298 }
299
300 fn mr(&mut self, addr: u64) -> MemoryReadRecord {
301 let record = SyscallRuntime::mr(self.core_mut(), addr);
302
303 self.shape_checker.handle_mem_event(addr, record.prev_timestamp);
304
305 record
306 }
307
308 fn mr_slice(&mut self, addr: u64, len: usize) -> Vec<MemoryReadRecord> {
309 let records = SyscallRuntime::mr_slice(self.core_mut(), addr, len);
310
311 for (i, record) in records.iter().enumerate() {
312 self.shape_checker.handle_mem_event(addr + i as u64 * 8, record.prev_timestamp);
313 }
314
315 records
316 }
317
318 fn mw_slice(&mut self, addr: u64, len: usize) -> Vec<MemoryWriteRecord> {
319 let records = SyscallRuntime::mw_slice(self.core_mut(), addr, len);
320
321 for (i, record) in records.iter().enumerate() {
322 self.shape_checker.handle_mem_event(addr + i as u64 * 8, record.prev_timestamp);
323 }
324
325 records
326 }
327}
328
/// A view over a [`MinimalTrace`] that starts partway through the underlying
/// trace: the starting execution state is overridden and the first
/// `memory_reads_idx` memory reads are skipped.
#[derive(Debug, Clone)]
pub struct SplicedMinimalTrace<T: MinimalTrace> {
    // The underlying trace being spliced.
    inner: T,
    // Register file at the splice point.
    start_registers: [u64; 32],
    // Program counter at the splice point.
    start_pc: u64,
    // Clock at the splice point.
    start_clk: u64,
    // Number of leading memory reads of `inner` that precede the splice point.
    memory_reads_idx: usize,
    // Final clock of this slice; 0 until set via `set_last_clk`.
    last_clk: u64,
    // Exclusive end index of this slice's memory reads; 0 until set via
    // `set_last_mem_reads_idx`. Used by the `Serialize` impl.
    last_mem_reads_idx: usize,
}
346
347impl<T: MinimalTrace> SplicedMinimalTrace<T> {
348 #[tracing::instrument(name = "SplicedMinimalTrace::new", skip(inner), level = "trace")]
350 pub fn new(
351 inner: T,
352 start_registers: [u64; 32],
353 start_pc: u64,
354 start_clk: u64,
355 memory_reads_idx: usize,
356 ) -> Self {
357 Self {
358 inner,
359 start_registers,
360 start_pc,
361 start_clk,
362 memory_reads_idx,
363 last_clk: 0,
364 last_mem_reads_idx: 0,
365 }
366 }
367
368 #[tracing::instrument(
370 name = "SplicedMinimalTrace::new_full_trace",
371 skip(trace),
372 level = "trace"
373 )]
374 pub fn new_full_trace(trace: T) -> Self {
375 let start_registers = trace.start_registers();
376 let start_pc = trace.pc_start();
377 let start_clk = trace.clk_start();
378
379 tracing::trace!("start_pc: {}", start_pc);
380 tracing::trace!("start_clk: {}", start_clk);
381 tracing::trace!("trace.num_mem_reads(): {}", trace.num_mem_reads());
382
383 Self::new(trace, start_registers, start_pc, start_clk, 0)
384 }
385
386 pub fn set_last_clk(&mut self, clk: u64) {
388 self.last_clk = clk;
389 }
390
391 pub fn set_last_mem_reads_idx(&mut self, mem_reads_idx: usize) {
393 self.last_mem_reads_idx = mem_reads_idx;
394 }
395}
396
impl<T: MinimalTrace> MinimalTrace for SplicedMinimalTrace<T> {
    // The spliced start state overrides the inner trace's start state.
    fn start_registers(&self) -> [u64; 32] {
        self.start_registers
    }

    fn pc_start(&self) -> u64 {
        self.start_pc
    }

    fn clk_start(&self) -> u64 {
        self.start_clk
    }

    fn clk_end(&self) -> u64 {
        self.last_clk
    }

    fn num_mem_reads(&self) -> u64 {
        // Reads before the splice point are not part of this trace.
        self.inner.num_mem_reads() - self.memory_reads_idx as u64
    }

    fn mem_reads(&self) -> MemReads<'_> {
        // Skip the reads consumed before the splice point so the cursor starts
        // at this slice's first read.
        let mut reads = self.inner.mem_reads();
        reads.advance(self.memory_reads_idx);

        reads
    }
}
425
426impl<T: MinimalTrace> Serialize for SplicedMinimalTrace<T> {
427 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
428 where
429 S: serde::Serializer,
430 {
431 let len = self.last_mem_reads_idx - self.memory_reads_idx;
432 let mem_reads = unsafe {
433 let mem_reads_buf = Arc::new_uninit_slice(len);
434 let start_mem_reads = self.mem_reads();
435 let src_ptr = start_mem_reads.head_raw();
436 std::ptr::copy_nonoverlapping(src_ptr, mem_reads_buf.as_ptr() as *mut MemValue, len);
437 mem_reads_buf.assume_init()
438 };
439
440 let trace = TraceChunk {
441 start_registers: self.start_registers,
442 pc_start: self.start_pc,
443 clk_start: self.start_clk,
444 clk_end: self.last_clk,
445 mem_reads,
446 };
447
448 trace.serialize(serializer)
449 }
450}
451
#[cfg(test)]
mod tests {
    use sp1_jit::MemValue;

    use super::*;

    /// Splicing one read off a two-read chunk and serializing must produce a
    /// `TraceChunk` that carries the overridden start state, the configured
    /// end clock, and only the remaining memory read.
    #[test]
    fn test_serialize_spliced_minimal_trace() {
        let chunk = TraceChunk {
            start_registers: [1; 32],
            pc_start: 2,
            clk_start: 3,
            clk_end: 4,
            mem_reads: Arc::new([MemValue { clk: 8, value: 9 }, MemValue { clk: 10, value: 11 }]),
        };

        // Splice past the first memory read and mark where the slice ends.
        let mut spliced = SplicedMinimalTrace::new(chunk, [2; 32], 2, 3, 1);
        spliced.set_last_mem_reads_idx(2);
        spliced.set_last_clk(2);

        let bytes = bincode::serialize(&spliced).unwrap();
        let roundtripped: TraceChunk = bincode::deserialize(&bytes).unwrap();

        let expected = TraceChunk {
            start_registers: [2; 32],
            pc_start: 2,
            clk_start: 3,
            clk_end: 2,
            mem_reads: Arc::new([MemValue { clk: 10, value: 11 }]),
        };

        assert_eq!(roundtripped, expected);
    }
}