1use crate::shm::ConsumerGuard;
2
3use std::{marker::PhantomData, ops::Deref, sync::Arc};
4
5use memmap2::Mmap;
6use serde::{Deserialize, Serialize};
7
/// One of the 32 general-purpose registers, `X0` through `X31`.
///
/// The `u8` discriminant equals the register number, so `reg as u8` (or `reg as usize`)
/// yields the register index directly.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RiscRegister {
    X0 = 0,
    X1 = 1,
    X2 = 2,
    X3 = 3,
    X4 = 4,
    X5 = 5,
    X6 = 6,
    X7 = 7,
    X8 = 8,
    X9 = 9,
    X10 = 10,
    X11 = 11,
    X12 = 12,
    X13 = 13,
    X14 = 14,
    X15 = 15,
    X16 = 16,
    X17 = 17,
    X18 = 18,
    X19 = 19,
    X20 = 20,
    X21 = 21,
    X22 = 22,
    X23 = 23,
    X24 = 24,
    X25 = 25,
    X26 = 26,
    X27 = 27,
    X28 = 28,
    X29 = 29,
    X30 = 30,
    X31 = 31,
}

impl RiscRegister {
    /// Returns every register exactly once, ordered by register number (`X0` first).
    ///
    /// Because of that ordering, `all_registers()[i] as usize == i` for every index `i`.
    pub fn all_registers() -> &'static [RiscRegister] {
        use RiscRegister::*;

        static ALL: [RiscRegister; 32] = [
            X0, X1, X2, X3, X4, X5, X6, X7, X8, X9, X10, X11, X12, X13, X14, X15, X16, X17,
            X18, X19, X20, X21, X22, X23, X24, X25, X26, X27, X28, X29, X30, X31,
        ];

        &ALL
    }
}
83
84#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
86pub enum RiscOperand {
87 Register(RiscRegister),
88 Immediate(i32),
89}
90
91impl From<RiscRegister> for RiscOperand {
92 fn from(reg: RiscRegister) -> Self {
93 RiscOperand::Register(reg)
94 }
95}
96
97impl From<u32> for RiscOperand {
98 fn from(imm: u32) -> Self {
99 RiscOperand::Immediate(imm as i32)
100 }
101}
102
103impl From<i32> for RiscOperand {
104 fn from(imm: i32) -> Self {
105 RiscOperand::Immediate(imm)
106 }
107}
108
109impl From<u64> for RiscOperand {
110 fn from(imm: u64) -> Self {
111 RiscOperand::Immediate(imm as i32)
112 }
113}
114
/// A `(clk, value)` pair, the record type of a trace chunk's memory-read stream.
///
/// `#[repr(C)]`: records are stored as raw bytes directly after a [`TraceChunkHeader`] and
/// are read back by pointer (`MemReads`, `TraceChunk::copy_from_bytes`), so the field order
/// and layout must not change.
#[repr(C)]
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct MemValue {
    /// Clock value associated with `value` (`PageProtValue::timestamp` converts into this).
    pub clk: u64,
    /// The 64-bit value itself.
    pub value: u64,
}
121
122#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
125pub struct ElfInfo {
126 pub pc_base: u64,
127 pub instruction_count: usize,
128 pub untrusted_memory: Option<(u64, u64)>,
129}
130
131impl ElfInfo {
132 #[inline]
133 pub fn enable_untrusted_program(&self) -> bool {
134 self.untrusted_memory.is_some()
135 }
136}
137
138#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
139pub struct PageProtValue {
140 pub timestamp: u64,
141 pub value: u8,
142}
143
144impl Default for PageProtValue {
145 fn default() -> Self {
146 Self { timestamp: 0, value: sp1_primitives::consts::DEFAULT_PAGE_PROT }
147 }
148}
149
150impl From<PageProtValue> for MemValue {
151 fn from(v: PageProtValue) -> MemValue {
152 MemValue { clk: v.timestamp, value: v.value as u64 }
153 }
154}
155
156impl From<MemValue> for PageProtValue {
157 fn from(v: MemValue) -> PageProtValue {
158 assert!(v.value & !0xff == 0, "value = {:x}", v.value);
159 PageProtValue { timestamp: v.clk, value: (v.value & 0xff).try_into().unwrap() }
160 }
161}
162
/// An interrupt, identified by a numeric `code`.
///
/// NOTE(review): the meaning of individual `code` values is not defined in this file —
/// confirm against the producer of these interrupts.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Interrupt {
    /// Numeric interrupt code.
    pub code: u64,
}
170
/// Fixed-size header at the start of every raw trace chunk.
///
/// `#[repr(C)]` because this layout is read straight out of byte buffers (via `offset_of!` in
/// `TraceChunkRaw` and `read_unaligned` in `TraceChunk::copy_from_bytes`). The header is
/// followed immediately by `num_mem_reads` [`MemValue`] records.
#[repr(C)]
pub struct TraceChunkHeader {
    /// Values of all 32 registers at the start of the chunk.
    pub start_registers: [u64; 32],
    /// Program counter at the start of the chunk.
    pub pc_start: u64,
    /// Clock value at the start of the chunk.
    pub clk_start: u64,
    /// Clock value at the end of the chunk.
    pub clk_end: u64,
    /// Number of `MemValue` records after the header (a count, not a byte length).
    pub num_mem_reads: u64,
    /// Read only through `TraceChunkRaw::global_clk_end`; not copied into `TraceChunk`.
    pub global_clk_end: u64,
    /// Pads the header to 38 * 8 = 304 bytes, a multiple of 16
    /// (checked by `test_trace_chunk_header_alignment`).
    _padding: u64,
}
183
184#[derive(Clone)]
185pub enum TraceChunkRaw {
186 Mmap(Arc<Mmap>),
187 Shm(Arc<ConsumerGuard>),
188}
189
190impl TraceChunkRaw {
191 pub unsafe fn new(inner: Mmap) -> Self {
197 Self::Mmap(Arc::new(inner))
198 }
199
200 pub unsafe fn from_shm(inner: ConsumerGuard) -> Self {
204 Self::Shm(Arc::new(inner))
205 }
206
207 fn as_ref(&self) -> &[u8] {
208 match self {
209 Self::Mmap(mmap) => mmap.as_ref(),
210 Self::Shm(shm) => shm.deref(),
211 }
212 }
213
214 fn as_ptr(&self) -> *const u8 {
215 match self {
216 Self::Mmap(mmap) => mmap.as_ptr(),
217 Self::Shm(shm) => shm.deref().as_ptr(),
218 }
219 }
220
221 fn len(&self) -> usize {
222 match self {
223 Self::Mmap(mmap) => mmap.len(),
224 Self::Shm(shm) => shm.deref().len(),
225 }
226 }
227
228 pub fn global_clk_end(&self) -> u64 {
234 let offset = std::mem::offset_of!(TraceChunkHeader, global_clk_end);
235
236 unsafe { std::ptr::read_unaligned(self.as_ptr().add(offset) as *const u64) }
237 }
238}
239
240impl MinimalTrace for TraceChunkRaw {
241 fn start_registers(&self) -> [u64; 32] {
242 let offset = std::mem::offset_of!(TraceChunkHeader, start_registers);
243
244 unsafe { std::ptr::read_unaligned(self.as_ptr().add(offset) as *const [u64; 32]) }
245 }
246
247 fn pc_start(&self) -> u64 {
248 let offset = std::mem::offset_of!(TraceChunkHeader, pc_start);
249
250 unsafe { std::ptr::read_unaligned(self.as_ptr().add(offset) as *const u64) }
251 }
252
253 fn clk_start(&self) -> u64 {
254 let offset = std::mem::offset_of!(TraceChunkHeader, clk_start);
255
256 unsafe { std::ptr::read_unaligned(self.as_ptr().add(offset) as *const u64) }
257 }
258
259 fn clk_end(&self) -> u64 {
260 let offset = std::mem::offset_of!(TraceChunkHeader, clk_end);
261
262 unsafe { std::ptr::read_unaligned(self.as_ptr().add(offset) as *const u64) }
263 }
264
265 fn num_mem_reads(&self) -> u64 {
266 let offset = std::mem::offset_of!(TraceChunkHeader, num_mem_reads);
267
268 unsafe { std::ptr::read_unaligned(self.as_ptr().add(offset) as *const u64) }
269 }
270
271 fn mem_reads(&self) -> MemReads<'_> {
272 let header_end = std::mem::size_of::<TraceChunkHeader>();
273 let len = self.num_mem_reads() as usize;
274
275 debug_assert!(self.len() - header_end >= len);
276
277 unsafe { MemReads::new(self.as_ptr().add(header_end) as *const MemValue, len) }
282 }
283}
284
285pub struct MemReads<'a> {
286 inner: *const MemValue,
287 end: *const MemValue,
288 _phantom: PhantomData<&'a ()>,
290}
291
292impl<'a> MemReads<'a> {
293 pub(crate) unsafe fn new(inner: *const MemValue, len: usize) -> Self {
298 debug_assert!(inner.is_aligned(), "MemReads ptr is not aligned");
299
300 Self { inner, end: inner.add(len), _phantom: PhantomData }
301 }
302
303 pub fn advance(&mut self, n: usize) {
309 unsafe {
310 let advanced = self.inner.add(n);
311
312 if advanced > self.end {
313 panic!("Cannot advance by more than the length of the slice");
314 }
315
316 self.inner = advanced;
317 }
318 }
319
320 pub fn head_raw(&self) -> *const MemValue {
322 self.inner
323 }
324
325 #[must_use]
327 pub fn len(&self) -> usize {
328 unsafe { self.end.offset_from_unsigned(self.inner) }
329 }
330
331 #[must_use]
333 pub fn is_empty(&self) -> bool {
334 self.inner == self.end
335 }
336}
337
338impl<'a> Iterator for MemReads<'a> {
339 type Item = MemValue;
340
341 fn next(&mut self) -> Option<Self::Item> {
342 if self.inner == self.end {
343 None
344 } else {
345 let value = unsafe { std::ptr::read(self.inner) };
346 self.inner = unsafe { self.inner.add(1) };
347
348 Some(value)
349 }
350 }
351}
352
/// An owned copy of a trace chunk: the header fields plus all memory-read records.
///
/// Built from raw bytes by [`TraceChunk::copy_from_bytes`] (also used by
/// `From<TraceChunkRaw>`). `mem_reads` is an `Arc<[MemValue]>`, so cloning a `TraceChunk`
/// shares the records instead of copying them.
#[repr(C)]
#[derive(Default, Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct TraceChunk {
    /// Values of all 32 registers at the start of the chunk.
    pub start_registers: [u64; 32],
    /// Program counter at the start of the chunk.
    pub pc_start: u64,
    /// Clock value at the start of the chunk.
    pub clk_start: u64,
    /// Clock value at the end of the chunk.
    pub clk_end: u64,
    /// The memory-read records, (de)serialized as a plain sequence by the `ser` helpers.
    #[serde(serialize_with = "ser::serialize_mem_reads")]
    #[serde(deserialize_with = "ser::deserialize_mem_reads")]
    pub mem_reads: Arc<[MemValue]>,
}
374
375impl From<TraceChunkRaw> for TraceChunk {
376 fn from(raw: TraceChunkRaw) -> Self {
377 TraceChunk::copy_from_bytes(raw.as_ref())
378 }
379}
380
381impl TraceChunk {
382 pub fn copy_from_bytes(src: &[u8]) -> Self {
389 const HDR: usize = size_of::<TraceChunkHeader>();
390
391 if src.len() < HDR {
393 panic!("TraceChunk header too small");
394 }
395
396 let raw: TraceChunkHeader =
403 unsafe { core::ptr::read_unaligned(src.as_ptr() as *const TraceChunkHeader) };
404
405 let n_words = raw.num_mem_reads as usize;
407 let n_bytes = n_words.checked_mul(size_of::<MemValue>()).expect("Num mem reads too large");
408 let total = HDR.checked_add(n_bytes).expect("Num mem reads too large");
409 if src.len() < total {
410 panic!("TraceChunk tail too small");
411 }
412
413 let tail = &src[HDR..total]; let mem_reads = Arc::new_uninit_slice(n_words);
417
418 unsafe {
430 std::ptr::copy_nonoverlapping(tail.as_ptr(), mem_reads.as_ptr() as *mut u8, n_bytes)
431 };
432
433 Self {
434 start_registers: raw.start_registers,
435 pc_start: raw.pc_start,
436 clk_start: raw.clk_start,
437 clk_end: raw.clk_end,
438 mem_reads: unsafe { mem_reads.assume_init() },
440 }
441 }
442}
443
/// Read-only view of a trace chunk's header fields and memory reads.
///
/// Implemented by the owned [`TraceChunk`] and by [`TraceChunkRaw`], which reads the same
/// fields directly from its backing buffer. The `Clone + Send + Sync + 'static` bounds let
/// implementors be cloned and shared across threads.
pub trait MinimalTrace: Clone + Send + Sync + 'static {
    /// Values of all 32 registers at the start of the chunk.
    fn start_registers(&self) -> [u64; 32];

    /// Program counter at the start of the chunk.
    fn pc_start(&self) -> u64;

    /// Clock value at the start of the chunk.
    fn clk_start(&self) -> u64;

    /// Clock value at the end of the chunk.
    fn clk_end(&self) -> u64;

    /// Number of memory-read records in the chunk.
    fn num_mem_reads(&self) -> u64;

    /// Iterator over the chunk's memory-read records, borrowing from `self`.
    fn mem_reads(&self) -> MemReads<'_>;
}
463
464impl MinimalTrace for TraceChunk {
465 fn start_registers(&self) -> [u64; 32] {
466 self.start_registers
467 }
468
469 fn pc_start(&self) -> u64 {
470 self.pc_start
471 }
472
473 fn clk_start(&self) -> u64 {
474 self.clk_start
475 }
476
477 fn clk_end(&self) -> u64 {
478 self.clk_end
479 }
480
481 fn num_mem_reads(&self) -> u64 {
482 self.mem_reads.len() as u64
483 }
484
485 fn mem_reads(&self) -> MemReads<'_> {
486 unsafe { MemReads::new(self.mem_reads.as_ptr(), self.mem_reads.len()) }
491 }
492}
493
494mod ser {
495 use super::*;
496 use serde::{Deserializer, Serializer};
497
498 pub fn serialize_mem_reads<S: Serializer>(
499 mem_reads: &Arc<[MemValue]>,
500 serializer: S,
501 ) -> Result<S::Ok, S::Error> {
502 let as_vec: Vec<MemValue> = Vec::from(&mem_reads[..]);
503
504 Vec::serialize(&as_vec, serializer)
505 }
506
507 pub fn deserialize_mem_reads<'a, D: Deserializer<'a>>(
508 deserializer: D,
509 ) -> Result<Arc<[MemValue]>, D::Error> {
510 let as_vec = Vec::deserialize(deserializer)?;
511
512 Ok(as_vec.into())
513 }
514
515 #[test]
516 #[cfg(test)]
517 fn test_mem_reads() {
518 let mem_reads = Arc::new([MemValue { clk: 0, value: 0 }, MemValue { clk: 1, value: 1 }]);
519 let trace = TraceChunk {
520 start_registers: [5; 32],
521 pc_start: 6,
522 clk_start: 7,
523 clk_end: 8,
524 mem_reads,
525 };
526
527 let serialized = bincode::serialize(&trace).unwrap();
528 let deserialized = bincode::deserialize(&serialized).unwrap();
529
530 assert_eq!(trace, deserialized);
531 }
532}
533
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_trace_chunk_header_alignment() {
        assert_eq!(std::mem::size_of::<TraceChunkHeader>() % 16, 0);
        // The memory-read payload starts right after the header, so the header size must
        // also keep that payload aligned for `MemValue` (checked by `MemReads::new`).
        assert_eq!(std::mem::size_of::<TraceChunkHeader>() % std::mem::align_of::<MemValue>(), 0);
    }

    /// Builds a raw chunk in native byte order: registers 0..32, pc/clk fields, and two
    /// memory reads `(1, 10)` and `(2, 20)`.
    fn sample_chunk_bytes() -> Vec<u8> {
        let mut words: Vec<u64> = (0..32).collect();
        // pc_start, clk_start, clk_end, num_mem_reads, global_clk_end, _padding
        words.extend([0x1000, 7, 9, 2, 99, 0]);
        // Two `MemValue { clk, value }` records.
        words.extend([1, 10, 2, 20]);
        words.iter().flat_map(|w| w.to_ne_bytes()).collect()
    }

    #[test]
    fn test_copy_from_bytes_parses_header_and_reads() {
        let chunk = TraceChunk::copy_from_bytes(&sample_chunk_bytes());

        let expected_regs: Vec<u64> = (0..32).collect();
        assert_eq!(chunk.start_registers.to_vec(), expected_regs);
        assert_eq!((chunk.pc_start, chunk.clk_start, chunk.clk_end), (0x1000, 7, 9));
        assert_eq!(chunk.num_mem_reads(), 2);

        let reads: Vec<MemValue> = chunk.mem_reads().collect();
        assert_eq!(reads, vec![MemValue { clk: 1, value: 10 }, MemValue { clk: 2, value: 20 }]);
    }

    #[test]
    fn test_mem_reads_advance() {
        let chunk = TraceChunk::copy_from_bytes(&sample_chunk_bytes());
        let mut reads = chunk.mem_reads();

        reads.advance(1);
        assert_eq!(reads.len(), 1);
        assert_eq!(reads.next(), Some(MemValue { clk: 2, value: 20 }));
        assert!(reads.is_empty());
        assert_eq!(reads.next(), None);
    }

    #[test]
    #[should_panic(expected = "TraceChunk tail too small")]
    fn test_copy_from_bytes_rejects_truncated_tail() {
        let bytes = sample_chunk_bytes();
        let _ = TraceChunk::copy_from_bytes(&bytes[..bytes.len() - 1]);
    }
}