1use crate::shm::ConsumerGuard;
2
3use std::{marker::PhantomData, ops::Deref, sync::Arc};
4
5use memmap2::Mmap;
6use serde::{Deserialize, Serialize};
7
/// The 32 general-purpose RISC-V integer registers, `x0`–`x31`.
///
/// `#[repr(u8)]` with explicit discriminants makes `reg as u8` equal to the
/// architectural register index, which the rest of the crate relies on when
/// encoding operands.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u8)]
pub enum RiscRegister {
    X0 = 0,
    X1 = 1,
    X2 = 2,
    X3 = 3,
    X4 = 4,
    X5 = 5,
    X6 = 6,
    X7 = 7,
    X8 = 8,
    X9 = 9,
    X10 = 10,
    X11 = 11,
    X12 = 12,
    X13 = 13,
    X14 = 14,
    X15 = 15,
    X16 = 16,
    X17 = 17,
    X18 = 18,
    X19 = 19,
    X20 = 20,
    X21 = 21,
    X22 = 22,
    X23 = 23,
    X24 = 24,
    X25 = 25,
    X26 = 26,
    X27 = 27,
    X28 = 28,
    X29 = 29,
    X30 = 30,
    X31 = 31,
}
44
45impl RiscRegister {
46 pub fn all_registers() -> &'static [RiscRegister] {
47 &[
48 RiscRegister::X0,
49 RiscRegister::X1,
50 RiscRegister::X2,
51 RiscRegister::X3,
52 RiscRegister::X4,
53 RiscRegister::X5,
54 RiscRegister::X6,
55 RiscRegister::X7,
56 RiscRegister::X8,
57 RiscRegister::X9,
58 RiscRegister::X10,
59 RiscRegister::X11,
60 RiscRegister::X12,
61 RiscRegister::X13,
62 RiscRegister::X14,
63 RiscRegister::X15,
64 RiscRegister::X16,
65 RiscRegister::X17,
66 RiscRegister::X18,
67 RiscRegister::X19,
68 RiscRegister::X20,
69 RiscRegister::X21,
70 RiscRegister::X22,
71 RiscRegister::X23,
72 RiscRegister::X24,
73 RiscRegister::X25,
74 RiscRegister::X26,
75 RiscRegister::X27,
76 RiscRegister::X28,
77 RiscRegister::X29,
78 RiscRegister::X30,
79 RiscRegister::X31,
80 ]
81 }
82}
83
/// An instruction operand: either a register reference or a 32-bit signed
/// immediate value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RiscOperand {
    Register(RiscRegister),
    Immediate(i32),
}
90
impl From<RiscRegister> for RiscOperand {
    fn from(reg: RiscRegister) -> Self {
        RiscOperand::Register(reg)
    }
}

// NOTE(review): `u32` immediates are reinterpreted bit-for-bit as `i32`, so
// values above `i32::MAX` become negative — confirm callers expect two's
// complement reinterpretation rather than a range check.
impl From<u32> for RiscOperand {
    fn from(imm: u32) -> Self {
        RiscOperand::Immediate(imm as i32)
    }
}

impl From<i32> for RiscOperand {
    fn from(imm: i32) -> Self {
        RiscOperand::Immediate(imm)
    }
}

// NOTE(review): `u64` immediates are silently truncated to the low 32 bits
// before the sign reinterpretation — verify no caller passes wider values.
impl From<u64> for RiscOperand {
    fn from(imm: u64) -> Self {
        RiscOperand::Immediate(imm as i32)
    }
}
114
/// One recorded memory read: the clock at which it happened plus the 64-bit
/// value observed.
///
/// `#[repr(C)]` fixes the layout (two `u64`s, 16 bytes) so values can be
/// copied straight out of raw trace-chunk bytes.
#[repr(C)]
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct MemValue {
    // Clock cycle of the access.
    pub clk: u64,
    // Value read from memory at that cycle.
    pub value: u64,
}
121
/// Fixed-size header at the start of every raw trace-chunk buffer.
///
/// `#[repr(C)]` gives a stable layout so the header can be read directly
/// (unaligned) from mmap/shm bytes; it is immediately followed by
/// `num_mem_reads` `MemValue` records. `_padding` keeps the total size a
/// multiple of 16 bytes (checked by `test_trace_chunk_header_alignment`) so
/// the `MemValue` tail starts at a suitably aligned offset.
#[repr(C)]
pub struct TraceChunkHeader {
    // Register file snapshot at the start of the chunk.
    pub start_registers: [u64; 32],
    pub pc_start: u64,
    pub clk_start: u64,
    pub clk_end: u64,
    // Element count of the `MemValue` tail that follows this header.
    pub num_mem_reads: u64,
    pub global_clk_end: u64,
    _padding: u64,
}
134
/// Zero-copy view over the raw bytes of a trace chunk, backed either by a
/// memory-mapped file or a shared-memory consumer segment. Cloning only bumps
/// the `Arc` refcount.
#[derive(Clone)]
pub enum TraceChunkRaw {
    Mmap(Arc<Mmap>),
    Shm(Arc<ConsumerGuard>),
}
140
141impl TraceChunkRaw {
142 pub unsafe fn new(inner: Mmap) -> Self {
148 Self::Mmap(Arc::new(inner))
149 }
150
151 pub unsafe fn from_shm(inner: ConsumerGuard) -> Self {
155 Self::Shm(Arc::new(inner))
156 }
157
158 fn as_ref(&self) -> &[u8] {
159 match self {
160 Self::Mmap(mmap) => mmap.as_ref(),
161 Self::Shm(shm) => shm.deref(),
162 }
163 }
164
165 fn as_ptr(&self) -> *const u8 {
166 match self {
167 Self::Mmap(mmap) => mmap.as_ptr(),
168 Self::Shm(shm) => shm.deref().as_ptr(),
169 }
170 }
171
172 fn len(&self) -> usize {
173 match self {
174 Self::Mmap(mmap) => mmap.len(),
175 Self::Shm(shm) => shm.deref().len(),
176 }
177 }
178
179 pub fn global_clk_end(&self) -> u64 {
185 let offset = std::mem::offset_of!(TraceChunkHeader, global_clk_end);
186
187 unsafe { std::ptr::read_unaligned(self.as_ptr().add(offset) as *const u64) }
188 }
189}
190
191impl MinimalTrace for TraceChunkRaw {
192 fn start_registers(&self) -> [u64; 32] {
193 let offset = std::mem::offset_of!(TraceChunkHeader, start_registers);
194
195 unsafe { std::ptr::read_unaligned(self.as_ptr().add(offset) as *const [u64; 32]) }
196 }
197
198 fn pc_start(&self) -> u64 {
199 let offset = std::mem::offset_of!(TraceChunkHeader, pc_start);
200
201 unsafe { std::ptr::read_unaligned(self.as_ptr().add(offset) as *const u64) }
202 }
203
204 fn clk_start(&self) -> u64 {
205 let offset = std::mem::offset_of!(TraceChunkHeader, clk_start);
206
207 unsafe { std::ptr::read_unaligned(self.as_ptr().add(offset) as *const u64) }
208 }
209
210 fn clk_end(&self) -> u64 {
211 let offset = std::mem::offset_of!(TraceChunkHeader, clk_end);
212
213 unsafe { std::ptr::read_unaligned(self.as_ptr().add(offset) as *const u64) }
214 }
215
216 fn num_mem_reads(&self) -> u64 {
217 let offset = std::mem::offset_of!(TraceChunkHeader, num_mem_reads);
218
219 unsafe { std::ptr::read_unaligned(self.as_ptr().add(offset) as *const u64) }
220 }
221
222 fn mem_reads(&self) -> MemReads<'_> {
223 let header_end = std::mem::size_of::<TraceChunkHeader>();
224 let len = self.num_mem_reads() as usize;
225
226 debug_assert!(self.len() - header_end >= len);
227
228 unsafe { MemReads::new(self.as_ptr().add(header_end) as *const MemValue, len) }
233 }
234}
235
/// Forward-only cursor over the `MemValue` tail of a trace chunk.
///
/// `inner` is the next unread element and `end` is one-past-the-end; the
/// `PhantomData` ties the raw pointers' validity to the lifetime of the
/// backing buffer borrow.
pub struct MemReads<'a> {
    inner: *const MemValue,
    end: *const MemValue,
    _phantom: PhantomData<&'a ()>,
}
242
243impl<'a> MemReads<'a> {
244 pub(crate) unsafe fn new(inner: *const MemValue, len: usize) -> Self {
249 debug_assert!(inner.is_aligned(), "MemReads ptr is not aligned");
250
251 Self { inner, end: inner.add(len), _phantom: PhantomData }
252 }
253
254 pub fn advance(&mut self, n: usize) {
260 unsafe {
261 let advanced = self.inner.add(n);
262
263 if advanced > self.end {
264 panic!("Cannot advance by more than the length of the slice");
265 }
266
267 self.inner = advanced;
268 }
269 }
270
271 pub fn head_raw(&self) -> *const MemValue {
273 self.inner
274 }
275
276 #[must_use]
278 pub fn len(&self) -> usize {
279 unsafe { self.end.offset_from_unsigned(self.inner) }
280 }
281
282 #[must_use]
284 pub fn is_empty(&self) -> bool {
285 self.inner == self.end
286 }
287}
288
289impl<'a> Iterator for MemReads<'a> {
290 type Item = MemValue;
291
292 fn next(&mut self) -> Option<Self::Item> {
293 if self.inner == self.end {
294 None
295 } else {
296 let value = unsafe { std::ptr::read(self.inner) };
297 self.inner = unsafe { self.inner.add(1) };
298
299 Some(value)
300 }
301 }
302}
303
/// Owned, heap-backed form of a trace chunk (compare `TraceChunkRaw`, the
/// zero-copy view over mmap/shm bytes).
///
/// `mem_reads` is an `Arc<[MemValue]>` so clones are O(1); the custom serde
/// hooks in `ser` bridge it to a plain sequence on the wire.
#[repr(C)]
#[derive(Default, Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct TraceChunk {
    pub start_registers: [u64; 32],
    pub pc_start: u64,
    pub clk_start: u64,
    pub clk_end: u64,
    #[serde(serialize_with = "ser::serialize_mem_reads")]
    #[serde(deserialize_with = "ser::deserialize_mem_reads")]
    pub mem_reads: Arc<[MemValue]>,
}
325
/// Materializes a zero-copy raw chunk into an owned `TraceChunk` by copying
/// the header fields and the memory-read tail out of the raw bytes.
impl From<TraceChunkRaw> for TraceChunk {
    fn from(raw: TraceChunkRaw) -> Self {
        TraceChunk::copy_from_bytes(raw.as_ref())
    }
}
331
332impl TraceChunk {
333 pub fn copy_from_bytes(src: &[u8]) -> Self {
340 const HDR: usize = size_of::<TraceChunkHeader>();
341
342 if src.len() < HDR {
344 panic!("TraceChunk header too small");
345 }
346
347 let raw: TraceChunkHeader =
354 unsafe { core::ptr::read_unaligned(src.as_ptr() as *const TraceChunkHeader) };
355
356 let n_words = raw.num_mem_reads as usize;
358 let n_bytes = n_words.checked_mul(size_of::<MemValue>()).expect("Num mem reads too large");
359 let total = HDR.checked_add(n_bytes).expect("Num mem reads too large");
360 if src.len() < total {
361 panic!("TraceChunk tail too small");
362 }
363
364 let tail = &src[HDR..total]; let mem_reads = Arc::new_uninit_slice(n_words);
368
369 unsafe {
381 std::ptr::copy_nonoverlapping(tail.as_ptr(), mem_reads.as_ptr() as *mut u8, n_bytes)
382 };
383
384 Self {
385 start_registers: raw.start_registers,
386 pc_start: raw.pc_start,
387 clk_start: raw.clk_start,
388 clk_end: raw.clk_end,
389 mem_reads: unsafe { mem_reads.assume_init() },
391 }
392 }
393}
394
/// Minimal read-only view of an execution trace chunk, implemented by both
/// the zero-copy raw form (`TraceChunkRaw`) and the owned form (`TraceChunk`).
pub trait MinimalTrace: Clone + Send + Sync + 'static {
    /// Register file snapshot at the start of the chunk.
    fn start_registers(&self) -> [u64; 32];

    /// Program counter at the start of the chunk.
    fn pc_start(&self) -> u64;

    /// Clock value at the start of the chunk.
    fn clk_start(&self) -> u64;

    /// Clock value at the end of the chunk.
    fn clk_end(&self) -> u64;

    /// Number of recorded memory reads in the chunk.
    fn num_mem_reads(&self) -> u64;

    /// Cursor over the chunk's recorded memory reads.
    fn mem_reads(&self) -> MemReads<'_>;
}
414
impl MinimalTrace for TraceChunk {
    // The owned form stores every header field directly, so these are plain
    // field reads.
    fn start_registers(&self) -> [u64; 32] {
        self.start_registers
    }

    fn pc_start(&self) -> u64 {
        self.pc_start
    }

    fn clk_start(&self) -> u64 {
        self.clk_start
    }

    fn clk_end(&self) -> u64 {
        self.clk_end
    }

    fn num_mem_reads(&self) -> u64 {
        self.mem_reads.len() as u64
    }

    fn mem_reads(&self) -> MemReads<'_> {
        // SAFETY: `Arc<[MemValue]>` yields an aligned, valid pointer to
        // exactly `len()` elements, and the returned cursor borrows `self`,
        // so the buffer outlives it.
        unsafe { MemReads::new(self.mem_reads.as_ptr(), self.mem_reads.len()) }
    }
}
444
445mod ser {
446 use super::*;
447 use serde::{Deserializer, Serializer};
448
449 pub fn serialize_mem_reads<S: Serializer>(
450 mem_reads: &Arc<[MemValue]>,
451 serializer: S,
452 ) -> Result<S::Ok, S::Error> {
453 let as_vec: Vec<MemValue> = Vec::from(&mem_reads[..]);
454
455 Vec::serialize(&as_vec, serializer)
456 }
457
458 pub fn deserialize_mem_reads<'a, D: Deserializer<'a>>(
459 deserializer: D,
460 ) -> Result<Arc<[MemValue]>, D::Error> {
461 let as_vec = Vec::deserialize(deserializer)?;
462
463 Ok(as_vec.into())
464 }
465
466 #[test]
467 #[cfg(test)]
468 fn test_mem_reads() {
469 let mem_reads = Arc::new([MemValue { clk: 0, value: 0 }, MemValue { clk: 1, value: 1 }]);
470 let trace = TraceChunk {
471 start_registers: [5; 32],
472 pc_start: 6,
473 clk_start: 7,
474 clk_end: 8,
475 mem_reads,
476 };
477
478 let serialized = bincode::serialize(&trace).unwrap();
479 let deserialized = bincode::deserialize(&serialized).unwrap();
480
481 assert_eq!(trace, deserialized);
482 }
483}
484
#[cfg(test)]
mod tests {
    use super::*;

    // The raw readers place `MemValue`s (16 bytes each) immediately after the
    // header, so the header size must stay a multiple of 16 to keep that tail
    // aligned. `_padding` exists to maintain this invariant; this test keeps
    // future field additions honest.
    #[test]
    fn test_trace_chunk_header_alignment() {
        assert_eq!(std::mem::size_of::<TraceChunkHeader>() % 16, 0);
    }
}