1use std::{marker::PhantomData, sync::Arc};
2
3use memmap2::Mmap;
4use serde::{Deserialize, Serialize};
5
/// One of the 32 integer registers, `X0` through `X31`.
///
/// The enum is `#[repr(u8)]` and each variant's discriminant equals its
/// register index, so `reg as u8` yields the register number.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u8)]
pub enum RiscRegister {
    X0 = 0,
    X1 = 1,
    X2 = 2,
    X3 = 3,
    X4 = 4,
    X5 = 5,
    X6 = 6,
    X7 = 7,
    X8 = 8,
    X9 = 9,
    X10 = 10,
    X11 = 11,
    X12 = 12,
    X13 = 13,
    X14 = 14,
    X15 = 15,
    X16 = 16,
    X17 = 17,
    X18 = 18,
    X19 = 19,
    X20 = 20,
    X21 = 21,
    X22 = 22,
    X23 = 23,
    X24 = 24,
    X25 = 25,
    X26 = 26,
    X27 = 27,
    X28 = 28,
    X29 = 29,
    X30 = 30,
    X31 = 31,
}
42
43impl RiscRegister {
44 pub fn all_registers() -> &'static [RiscRegister] {
45 &[
46 RiscRegister::X0,
47 RiscRegister::X1,
48 RiscRegister::X2,
49 RiscRegister::X3,
50 RiscRegister::X4,
51 RiscRegister::X5,
52 RiscRegister::X6,
53 RiscRegister::X7,
54 RiscRegister::X8,
55 RiscRegister::X9,
56 RiscRegister::X10,
57 RiscRegister::X11,
58 RiscRegister::X12,
59 RiscRegister::X13,
60 RiscRegister::X14,
61 RiscRegister::X15,
62 RiscRegister::X16,
63 RiscRegister::X17,
64 RiscRegister::X18,
65 RiscRegister::X19,
66 RiscRegister::X20,
67 RiscRegister::X21,
68 RiscRegister::X22,
69 RiscRegister::X23,
70 RiscRegister::X24,
71 RiscRegister::X25,
72 RiscRegister::X26,
73 RiscRegister::X27,
74 RiscRegister::X28,
75 RiscRegister::X29,
76 RiscRegister::X30,
77 RiscRegister::X31,
78 ]
79 }
80}
81
82#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
84pub enum RiscOperand {
85 Register(RiscRegister),
86 Immediate(i32),
87}
88
89impl From<RiscRegister> for RiscOperand {
90 fn from(reg: RiscRegister) -> Self {
91 RiscOperand::Register(reg)
92 }
93}
94
95impl From<u32> for RiscOperand {
96 fn from(imm: u32) -> Self {
97 RiscOperand::Immediate(imm as i32)
98 }
99}
100
101impl From<i32> for RiscOperand {
102 fn from(imm: i32) -> Self {
103 RiscOperand::Immediate(imm)
104 }
105}
106
107impl From<u64> for RiscOperand {
108 fn from(imm: u64) -> Self {
109 RiscOperand::Immediate(imm as i32)
110 }
111}
112
113#[repr(C)]
114#[derive(Default, Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
115pub struct MemValue {
116 pub clk: u64,
117 pub value: u64,
118}
119
/// Fixed-size header found at the start of a serialized trace chunk.
///
/// The type is `#[repr(C)]` and consists solely of `u64` data, so its fields
/// sit at fixed byte offsets; the raw-chunk accessors in this file rely on
/// those offsets. The memory reads follow immediately after the header.
#[repr(C)]
pub struct TraceChunkHeader {
    /// The 32 register values at the start of the chunk.
    pub start_registers: [u64; 32],
    /// The `pc_start` value of the chunk.
    pub pc_start: u64,
    /// The `clk_start` value of the chunk.
    pub clk_start: u64,
    /// The `clk_end` value of the chunk.
    pub clk_end: u64,
    /// Number of `MemValue` entries that follow the header.
    pub num_mem_reads: u64,
}
129
130#[repr(C)]
131#[derive(Clone)]
132pub struct TraceChunkRaw {
133 inner: Arc<Mmap>,
134 hint_lens: Vec<usize>,
135}
136
137impl TraceChunkRaw {
138 pub unsafe fn new(inner: Mmap, hint_lens: Vec<usize>) -> Self {
144 Self { inner: Arc::new(inner), hint_lens }
145 }
146}
147
148impl MinimalTrace for TraceChunkRaw {
149 fn start_registers(&self) -> [u64; 32] {
150 let offset = std::mem::offset_of!(TraceChunkHeader, start_registers);
151
152 unsafe { std::ptr::read_unaligned(self.inner.as_ptr().add(offset) as *const [u64; 32]) }
153 }
154
155 fn pc_start(&self) -> u64 {
156 let offset = std::mem::offset_of!(TraceChunkHeader, pc_start);
157
158 unsafe { std::ptr::read_unaligned(self.inner.as_ptr().add(offset) as *const u64) }
159 }
160
161 fn clk_start(&self) -> u64 {
162 let offset = std::mem::offset_of!(TraceChunkHeader, clk_start);
163
164 unsafe { std::ptr::read_unaligned(self.inner.as_ptr().add(offset) as *const u64) }
165 }
166
167 fn clk_end(&self) -> u64 {
168 let offset = std::mem::offset_of!(TraceChunkHeader, clk_end);
169
170 unsafe { std::ptr::read_unaligned(self.inner.as_ptr().add(offset) as *const u64) }
171 }
172
173 fn num_mem_reads(&self) -> u64 {
174 let offset = std::mem::offset_of!(TraceChunkHeader, num_mem_reads);
175
176 unsafe { std::ptr::read_unaligned(self.inner.as_ptr().add(offset) as *const u64) }
177 }
178
179 fn mem_reads(&self) -> MemReads<'_> {
180 let header_end = std::mem::size_of::<TraceChunkHeader>();
181 let len = self.num_mem_reads() as usize;
182
183 debug_assert!(self.inner.len() - header_end >= len);
184
185 unsafe { MemReads::new(self.inner.as_ptr().add(header_end) as *const MemValue, len) }
190 }
191
192 fn hint_lens(&self) -> &[usize] {
193 &self.hint_lens
194 }
195}
196
197pub struct MemReads<'a> {
198 inner: *const MemValue,
199 end: *const MemValue,
200 _phantom: PhantomData<&'a ()>,
202}
203
204impl<'a> MemReads<'a> {
205 pub(crate) unsafe fn new(inner: *const MemValue, len: usize) -> Self {
210 debug_assert!(inner.is_aligned(), "MemReads ptr is not aligned");
211
212 Self { inner, end: inner.add(len), _phantom: PhantomData }
213 }
214
215 pub fn advance(&mut self, n: usize) {
221 unsafe {
222 let advanced = self.inner.add(n);
223
224 if advanced > self.end {
225 panic!("Cannot advance by more than the length of the slice");
226 }
227
228 self.inner = advanced;
229 }
230 }
231
232 pub fn head_raw(&self) -> *const MemValue {
234 self.inner
235 }
236
237 #[must_use]
239 pub fn len(&self) -> usize {
240 unsafe { self.end.offset_from_unsigned(self.inner) }
241 }
242
243 #[must_use]
245 pub fn is_empty(&self) -> bool {
246 self.inner == self.end
247 }
248}
249
250impl<'a> Iterator for MemReads<'a> {
251 type Item = MemValue;
252
253 fn next(&mut self) -> Option<Self::Item> {
254 if self.inner == self.end {
255 None
256 } else {
257 let value = unsafe { std::ptr::read(self.inner) };
258 self.inner = unsafe { self.inner.add(1) };
259
260 Some(value)
261 }
262 }
263}
264
265#[repr(C)]
276#[derive(Default, Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
277pub struct TraceChunk {
278 pub start_registers: [u64; 32],
279 pub pc_start: u64,
280 pub clk_start: u64,
281 pub clk_end: u64,
282 pub hint_lens: Vec<usize>,
283 #[serde(serialize_with = "ser::serialize_mem_reads")]
284 #[serde(deserialize_with = "ser::deserialize_mem_reads")]
285 pub mem_reads: Arc<[MemValue]>,
286}
287
288impl From<TraceChunkRaw> for TraceChunk {
289 fn from(raw: TraceChunkRaw) -> Self {
290 TraceChunk::copy_from_bytes(raw.hint_lens, raw.inner.as_ref())
291 }
292}
293
294impl TraceChunk {
295 pub fn copy_from_bytes(hint_lens: Vec<usize>, src: &[u8]) -> Self {
302 const HDR: usize = size_of::<TraceChunkHeader>();
303
304 if src.len() < HDR {
306 panic!("TraceChunk header too small");
307 }
308
309 let raw: TraceChunkHeader =
316 unsafe { core::ptr::read_unaligned(src.as_ptr() as *const TraceChunkHeader) };
317
318 let n_words = raw.num_mem_reads as usize;
320 let n_bytes = n_words.checked_mul(size_of::<MemValue>()).expect("Num mem reads too large");
321 let total = HDR.checked_add(n_bytes).expect("Num mem reads too large");
322 if src.len() < total {
323 panic!("TraceChunk tail too small");
324 }
325
326 let tail = &src[HDR..total]; let mem_reads = Arc::new_uninit_slice(n_words);
330
331 unsafe {
343 std::ptr::copy_nonoverlapping(tail.as_ptr(), mem_reads.as_ptr() as *mut u8, n_bytes)
344 };
345
346 Self {
347 start_registers: raw.start_registers,
348 pc_start: raw.pc_start,
349 clk_start: raw.clk_start,
350 clk_end: raw.clk_end,
351 hint_lens,
352 mem_reads: unsafe { mem_reads.assume_init() },
354 }
355 }
356}
357
358pub trait MinimalTrace: Clone + Send + Sync + 'static {
365 fn start_registers(&self) -> [u64; 32];
366
367 fn pc_start(&self) -> u64;
368
369 fn clk_start(&self) -> u64;
370
371 fn clk_end(&self) -> u64;
372
373 fn num_mem_reads(&self) -> u64;
374
375 fn mem_reads(&self) -> MemReads<'_>;
376
377 fn hint_lens(&self) -> &[usize];
378}
379
380impl MinimalTrace for TraceChunk {
381 fn start_registers(&self) -> [u64; 32] {
382 self.start_registers
383 }
384
385 fn pc_start(&self) -> u64 {
386 self.pc_start
387 }
388
389 fn clk_start(&self) -> u64 {
390 self.clk_start
391 }
392
393 fn clk_end(&self) -> u64 {
394 self.clk_end
395 }
396
397 fn num_mem_reads(&self) -> u64 {
398 self.mem_reads.len() as u64
399 }
400
401 fn mem_reads(&self) -> MemReads<'_> {
402 unsafe { MemReads::new(self.mem_reads.as_ptr(), self.mem_reads.len()) }
407 }
408
409 fn hint_lens(&self) -> &[usize] {
410 &self.hint_lens
411 }
412}
413
414mod ser {
415 use super::*;
416 use serde::{Deserializer, Serializer};
417
418 pub fn serialize_mem_reads<S: Serializer>(
419 mem_reads: &Arc<[MemValue]>,
420 serializer: S,
421 ) -> Result<S::Ok, S::Error> {
422 let as_vec: Vec<MemValue> = Vec::from(&mem_reads[..]);
423
424 Vec::serialize(&as_vec, serializer)
425 }
426
427 pub fn deserialize_mem_reads<'a, D: Deserializer<'a>>(
428 deserializer: D,
429 ) -> Result<Arc<[MemValue]>, D::Error> {
430 let as_vec = Vec::deserialize(deserializer)?;
431
432 Ok(as_vec.into())
433 }
434
435 #[test]
436 #[cfg(test)]
437 fn test_mem_reads() {
438 let mem_reads = Arc::new([MemValue { clk: 0, value: 0 }, MemValue { clk: 1, value: 1 }]);
439 let trace = TraceChunk {
440 start_registers: [5; 32],
441 pc_start: 6,
442 clk_start: 7,
443 clk_end: 8,
444 hint_lens: vec![1, 2, 3],
445 mem_reads,
446 };
447
448 let serialized = bincode::serialize(&trace).unwrap();
449 let deserialized = bincode::deserialize(&serialized).unwrap();
450
451 assert_eq!(trace, deserialized);
452 }
453}