Skip to main content

sp1_jit/
risc.rs

1use crate::shm::ConsumerGuard;
2
3use std::{marker::PhantomData, ops::Deref, sync::Arc};
4
5use memmap2::Mmap;
6use serde::{Deserialize, Serialize};
7
/// A RISC-V register, `X0` through `X31`.
///
/// `#[repr(u8)]` together with the explicit discriminants makes each variant's numeric value
/// equal to its register index, so `reg as u8` yields that index.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u8)]
pub enum RiscRegister {
    X0 = 0,
    X1 = 1,
    X2 = 2,
    X3 = 3,
    X4 = 4,
    X5 = 5,
    X6 = 6,
    X7 = 7,
    X8 = 8,
    X9 = 9,
    X10 = 10,
    X11 = 11,
    X12 = 12,
    X13 = 13,
    X14 = 14,
    X15 = 15,
    X16 = 16,
    X17 = 17,
    X18 = 18,
    X19 = 19,
    X20 = 20,
    X21 = 21,
    X22 = 22,
    X23 = 23,
    X24 = 24,
    X25 = 25,
    X26 = 26,
    X27 = 27,
    X28 = 28,
    X29 = 29,
    X30 = 30,
    X31 = 31,
}
44
45impl RiscRegister {
46    pub fn all_registers() -> &'static [RiscRegister] {
47        &[
48            RiscRegister::X0,
49            RiscRegister::X1,
50            RiscRegister::X2,
51            RiscRegister::X3,
52            RiscRegister::X4,
53            RiscRegister::X5,
54            RiscRegister::X6,
55            RiscRegister::X7,
56            RiscRegister::X8,
57            RiscRegister::X9,
58            RiscRegister::X10,
59            RiscRegister::X11,
60            RiscRegister::X12,
61            RiscRegister::X13,
62            RiscRegister::X14,
63            RiscRegister::X15,
64            RiscRegister::X16,
65            RiscRegister::X17,
66            RiscRegister::X18,
67            RiscRegister::X19,
68            RiscRegister::X20,
69            RiscRegister::X21,
70            RiscRegister::X22,
71            RiscRegister::X23,
72            RiscRegister::X24,
73            RiscRegister::X25,
74            RiscRegister::X26,
75            RiscRegister::X27,
76            RiscRegister::X28,
77            RiscRegister::X29,
78            RiscRegister::X30,
79            RiscRegister::X31,
80        ]
81    }
82}
83
/// ALU operations can either have register or immediate operands.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RiscOperand {
    /// The operand is the value held in a register.
    Register(RiscRegister),
    /// The operand is a constant, stored as a signed 32-bit value.
    Immediate(i32),
}
90
91impl From<RiscRegister> for RiscOperand {
92    fn from(reg: RiscRegister) -> Self {
93        RiscOperand::Register(reg)
94    }
95}
96
97impl From<u32> for RiscOperand {
98    fn from(imm: u32) -> Self {
99        RiscOperand::Immediate(imm as i32)
100    }
101}
102
103impl From<i32> for RiscOperand {
104    fn from(imm: i32) -> Self {
105        RiscOperand::Immediate(imm)
106    }
107}
108
109impl From<u64> for RiscOperand {
110    fn from(imm: u64) -> Self {
111        RiscOperand::Immediate(imm as i32)
112    }
113}
114
/// A value paired with a clock.
///
/// `#[repr(C)]` fixes the layout to two consecutive `u64`s. Trace-chunk buffers store arrays of
/// this type directly after the header, and those arrays are read in place (see `MemReads`).
#[repr(C)]
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct MemValue {
    /// Clock associated with the value (a `PageProtValue`'s `timestamp` converts into this).
    pub clk: u64,
    /// The value itself.
    pub value: u64,
}
121
/// Basic ELF information. It extracts information from `Program`, but
/// does not introduce a dependency on `Program`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct ElfInfo {
    /// Base program counter.
    /// NOTE(review): presumably the address of the first instruction — confirm.
    pub pc_base: u64,
    /// Number of instructions in the program.
    pub instruction_count: usize,
    /// Untrusted memory region, if any; `enable_untrusted_program` reports whether it is set.
    /// NOTE(review): the pair looks like a (start, end) address range — confirm.
    pub untrusted_memory: Option<(u64, u64)>,
}
130
131impl ElfInfo {
132    #[inline]
133    pub fn enable_untrusted_program(&self) -> bool {
134        self.untrusted_memory.is_some()
135    }
136}
137
/// A page-protection byte paired with a timestamp.
///
/// Converts to and from [`MemValue`] (`timestamp` <-> `clk`). Converting a `MemValue` back
/// panics if its value does not fit in a byte.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct PageProtValue {
    pub timestamp: u64,
    /// Protection bits; the default is `sp1_primitives::consts::DEFAULT_PAGE_PROT`.
    pub value: u8,
}
143
144impl Default for PageProtValue {
145    fn default() -> Self {
146        Self { timestamp: 0, value: sp1_primitives::consts::DEFAULT_PAGE_PROT }
147    }
148}
149
150impl From<PageProtValue> for MemValue {
151    fn from(v: PageProtValue) -> MemValue {
152        MemValue { clk: v.timestamp, value: v.value as u64 }
153    }
154}
155
156impl From<MemValue> for PageProtValue {
157    fn from(v: MemValue) -> PageProtValue {
158        assert!(v.value & !0xff == 0, "value = {:x}", v.value);
159        PageProtValue { timestamp: v.clk, value: (v.value & 0xff).try_into().unwrap() }
160    }
161}
162
/// A RISC-V interrupt. Right now this structure is only used for traps, but it
/// might be expanded later.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Interrupt {
    /// Trap code.
    pub code: u64,
}
170
/// A convenience structure for getting offsets of fields in a raw trace-chunk buffer
/// ([TraceChunkRaw]); such a buffer begins with exactly this layout.
///
/// `#[repr(C)]` keeps field order and offsets fixed, so `offset_of!` on this type matches the
/// buffer. The [`MemValue`] memory reads begin at offset `size_of::<TraceChunkHeader>()`.
#[repr(C)]
pub struct TraceChunkHeader {
    pub start_registers: [u64; 32],
    pub pc_start: u64,
    pub clk_start: u64,
    pub clk_end: u64,
    /// Number of [`MemValue`]s stored after the header.
    pub num_mem_reads: u64,
    pub global_clk_end: u64,
    // Pads the header's *size* to a multiple of 16 bytes (304 instead of 296), i.e. a whole number
    // of `MemValue`s, so the memory reads that follow start at a 16-byte offset.
    _padding: u64,
}
183
/// The bytes of a trace chunk (a [`TraceChunkHeader`] followed by [`MemValue`]s), backed either
/// by a memory map or by a shared-memory consumer guard.
///
/// Both backings are held in an `Arc`, so cloning only bumps a reference count.
#[derive(Clone)]
pub enum TraceChunkRaw {
    Mmap(Arc<Mmap>),
    Shm(Arc<ConsumerGuard>),
}
189
190impl TraceChunkRaw {
191    /// # Safety
192    ///
193    /// - The mmap must be a valid [`TraceChunkHeader`].
194    /// - The mmap must contain valid [`MemValue`]s in after the header.
195    /// - The `num_mem_reads` must be the number of [`MemValue`]s in the mmap after the header.
196    pub unsafe fn new(inner: Mmap) -> Self {
197        Self::Mmap(Arc::new(inner))
198    }
199
200    /// # Safety
201    ///
202    /// See TraceChunkRaw::new for similar safety requirements
203    pub unsafe fn from_shm(inner: ConsumerGuard) -> Self {
204        Self::Shm(Arc::new(inner))
205    }
206
207    fn as_ref(&self) -> &[u8] {
208        match self {
209            Self::Mmap(mmap) => mmap.as_ref(),
210            Self::Shm(shm) => shm.deref(),
211        }
212    }
213
214    fn as_ptr(&self) -> *const u8 {
215        match self {
216            Self::Mmap(mmap) => mmap.as_ptr(),
217            Self::Shm(shm) => shm.deref().as_ptr(),
218        }
219    }
220
221    fn len(&self) -> usize {
222        match self {
223            Self::Mmap(mmap) => mmap.len(),
224            Self::Shm(shm) => shm.deref().len(),
225        }
226    }
227
228    /// Fetching global_clk when trace ends.
229    /// For now, only native executor requires this data. So we implemented
230    /// it as a method of TraceChunkRaw, not as a trait method of MinimalTrace.
231    /// Adding it to MinimalTrace would complicate SplicingVM, while SplicingVM
232    /// does not really need global_clk now.
233    pub fn global_clk_end(&self) -> u64 {
234        let offset = std::mem::offset_of!(TraceChunkHeader, global_clk_end);
235
236        unsafe { std::ptr::read_unaligned(self.as_ptr().add(offset) as *const u64) }
237    }
238}
239
240impl MinimalTrace for TraceChunkRaw {
241    fn start_registers(&self) -> [u64; 32] {
242        let offset = std::mem::offset_of!(TraceChunkHeader, start_registers);
243
244        unsafe { std::ptr::read_unaligned(self.as_ptr().add(offset) as *const [u64; 32]) }
245    }
246
247    fn pc_start(&self) -> u64 {
248        let offset = std::mem::offset_of!(TraceChunkHeader, pc_start);
249
250        unsafe { std::ptr::read_unaligned(self.as_ptr().add(offset) as *const u64) }
251    }
252
253    fn clk_start(&self) -> u64 {
254        let offset = std::mem::offset_of!(TraceChunkHeader, clk_start);
255
256        unsafe { std::ptr::read_unaligned(self.as_ptr().add(offset) as *const u64) }
257    }
258
259    fn clk_end(&self) -> u64 {
260        let offset = std::mem::offset_of!(TraceChunkHeader, clk_end);
261
262        unsafe { std::ptr::read_unaligned(self.as_ptr().add(offset) as *const u64) }
263    }
264
265    fn num_mem_reads(&self) -> u64 {
266        let offset = std::mem::offset_of!(TraceChunkHeader, num_mem_reads);
267
268        unsafe { std::ptr::read_unaligned(self.as_ptr().add(offset) as *const u64) }
269    }
270
271    fn mem_reads(&self) -> MemReads<'_> {
272        let header_end = std::mem::size_of::<TraceChunkHeader>();
273        let len = self.num_mem_reads() as usize;
274
275        debug_assert!(self.len() - header_end >= len);
276
277        // SAFETY:
278        // - The memory is valid assuming num_mem_reads is correct.
279        // - The memory is technically always valid for reads since all bitpatterns are valid for
280        //   `MemValue`.
281        unsafe { MemReads::new(self.as_ptr().add(header_end) as *const MemValue, len) }
282    }
283}
284
/// A borrowed cursor over a contiguous run of [`MemValue`]s, read in place from a buffer.
///
/// `inner` points at the next element to yield and `end` points one past the last element.
pub struct MemReads<'a> {
    inner: *const MemValue,
    end: *const MemValue,
    /// Capture the lifetime of the buffer for safety reasons.
    _phantom: PhantomData<&'a ()>,
}
291
292impl<'a> MemReads<'a> {
293    /// # Safety
294    ///
295    /// - The underlying memory is valid and contains valid `MemValue`s.
296    /// - The length is the number of `MemValue`s in the underlying memory.
297    pub(crate) unsafe fn new(inner: *const MemValue, len: usize) -> Self {
298        debug_assert!(inner.is_aligned(), "MemReads ptr is not aligned");
299
300        Self { inner, end: inner.add(len), _phantom: PhantomData }
301    }
302
303    /// Advance the pointer by `n` elements.
304    ///
305    /// # Panics
306    ///
307    /// Panics if `n` is greater than the purported length of the underlying buffer.
308    pub fn advance(&mut self, n: usize) {
309        unsafe {
310            let advanced = self.inner.add(n);
311
312            if advanced > self.end {
313                panic!("Cannot advance by more than the length of the slice");
314            }
315
316            self.inner = advanced;
317        }
318    }
319
320    /// Get the raw pointer to the head of the slice.
321    pub fn head_raw(&self) -> *const MemValue {
322        self.inner
323    }
324
325    /// The remaining length of the slice from our current position.
326    #[must_use]
327    pub fn len(&self) -> usize {
328        unsafe { self.end.offset_from_unsigned(self.inner) }
329    }
330
331    /// Check if the iterator is empty.
332    #[must_use]
333    pub fn is_empty(&self) -> bool {
334        self.inner == self.end
335    }
336}
337
338impl<'a> Iterator for MemReads<'a> {
339    type Item = MemValue;
340
341    fn next(&mut self) -> Option<Self::Item> {
342        if self.inner == self.end {
343            None
344        } else {
345            let value = unsafe { std::ptr::read(self.inner) };
346            self.inner = unsafe { self.inner.add(1) };
347
348            Some(value)
349        }
350    }
351}
352
/// A trace chunk is all the data needed to continue the execution of a program at
/// pc_start/clk_start.
///
/// It is built from a raw buffer by [`TraceChunk::copy_from_bytes`]. That buffer starts with a
/// [`TraceChunkHeader`] and is followed by `num_mem_reads` [`MemValue`]s.
///
/// When we read this type from the buffer, we copy the registers and the pc/clk start and end
/// out of the header, then copy the memory reads (their count comes from the header's
/// `num_mem_reads` field) into an owned `Arc<[MemValue]>`. No pointer into the original buffer
/// is retained.
#[repr(C)]
#[derive(Default, Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct TraceChunk {
    pub start_registers: [u64; 32],
    pub pc_start: u64,
    pub clk_start: u64,
    pub clk_end: u64,
    // (De)serialized as a plain sequence by the helpers in the `ser` module.
    #[serde(serialize_with = "ser::serialize_mem_reads")]
    #[serde(deserialize_with = "ser::deserialize_mem_reads")]
    pub mem_reads: Arc<[MemValue]>,
}
374
375impl From<TraceChunkRaw> for TraceChunk {
376    fn from(raw: TraceChunkRaw) -> Self {
377        TraceChunk::copy_from_bytes(raw.as_ref())
378    }
379}
380
381impl TraceChunk {
382    /// Copy the bytes into a [TraceChunk]. We dont just back it with the original bytes,
383    /// since this type is likely to be sent off to worker for proving.
384    ///
385    /// # Note:
386    /// This method will panic if the buffer is not large enough,
387    /// or the number of reads causes an overflow.
388    pub fn copy_from_bytes(src: &[u8]) -> Self {
389        const HDR: usize = size_of::<TraceChunkHeader>();
390
391        /* ---------- 1. header must fit ---------- */
392        if src.len() < HDR {
393            panic!("TraceChunk header too small");
394        }
395
396        /* ---------- 2. copy-out the header ---------- */
397        // SAFETY:
398        // we just checked that `src` contains at least `HDR` bytes,
399        // and `read_unaligne
400        //
401        // Note: All bit patterns are valid for `TraceChunkRaw`.
402        let raw: TraceChunkHeader =
403            unsafe { core::ptr::read_unaligned(src.as_ptr() as *const TraceChunkHeader) };
404
405        /* ---------- 3. tail must fit ---------- */
406        let n_words = raw.num_mem_reads as usize;
407        let n_bytes = n_words.checked_mul(size_of::<MemValue>()).expect("Num mem reads too large");
408        let total = HDR.checked_add(n_bytes).expect("Num mem reads too large");
409        if src.len() < total {
410            panic!("TraceChunk tail too small");
411        }
412
413        /* ---------- 4. extract tail ---------- */
414        let tail = &src[HDR..total]; // only after the length check
415
416        let mem_reads = Arc::new_uninit_slice(n_words);
417
418        // SAFETY:
419        // - The tail contains valid u64s, so doing a bitwise copy preserves the validity and
420        //   endianness.
421        // - tail is likely unaligned, so casting to a u8 pointer gives the alignmnt guarantee the
422        //   compiler needs to do a copy.
423        // - `mem_reads` was just allocated to have enough space.
424        // - u8 has minimum alignment, so casting the pointer allocated by the vec is valid.
425        // - The cast from const -> mut is valid since there are no other references to the memory.
426        //
427        // This trick is mostly taken from [`std::ptr::read_unaligned`]
428        // see: <https://doc.rust-lang.org/src/core/ptr/mod.rs.html#1811>.
429        unsafe {
430            std::ptr::copy_nonoverlapping(tail.as_ptr(), mem_reads.as_ptr() as *mut u8, n_bytes)
431        };
432
433        Self {
434            start_registers: raw.start_registers,
435            pc_start: raw.pc_start,
436            clk_start: raw.clk_start,
437            clk_end: raw.clk_end,
438            // SAFETY: We know the memory is initialized, so we can assume it.
439            mem_reads: unsafe { mem_reads.assume_init() },
440        }
441    }
442}
443
/// A trait that represents a minimal trace.
///
/// A minimal trace is the minimum required information to re-execute from
/// `pc_start` and `clk_start` -> `clk_end`.
///
/// It effectively acts as an oracle for the results of memory read operations.
pub trait MinimalTrace: Clone + Send + Sync + 'static {
    /// Register values at the start of the chunk.
    fn start_registers(&self) -> [u64; 32];

    /// Program counter at which the chunk starts.
    fn pc_start(&self) -> u64;

    /// Clock at which the chunk starts.
    fn clk_start(&self) -> u64;

    /// Clock at which the chunk ends.
    fn clk_end(&self) -> u64;

    /// Number of recorded memory reads (the number of items yielded by `mem_reads`).
    fn num_mem_reads(&self) -> u64;

    /// Iterator over the recorded memory-read results, borrowing from `self`.
    fn mem_reads(&self) -> MemReads<'_>;
}
463
464impl MinimalTrace for TraceChunk {
465    fn start_registers(&self) -> [u64; 32] {
466        self.start_registers
467    }
468
469    fn pc_start(&self) -> u64 {
470        self.pc_start
471    }
472
473    fn clk_start(&self) -> u64 {
474        self.clk_start
475    }
476
477    fn clk_end(&self) -> u64 {
478        self.clk_end
479    }
480
481    fn num_mem_reads(&self) -> u64 {
482        self.mem_reads.len() as u64
483    }
484
485    fn mem_reads(&self) -> MemReads<'_> {
486        // SAFETY:
487        // - The memory is technically always valid for reads since all bitpatterns are valid for
488        //   `MemValue`.
489        // - the length comes directly from the Vec, which we know to be valid.
490        unsafe { MemReads::new(self.mem_reads.as_ptr(), self.mem_reads.len()) }
491    }
492}
493
494mod ser {
495    use super::*;
496    use serde::{Deserializer, Serializer};
497
498    pub fn serialize_mem_reads<S: Serializer>(
499        mem_reads: &Arc<[MemValue]>,
500        serializer: S,
501    ) -> Result<S::Ok, S::Error> {
502        let as_vec: Vec<MemValue> = Vec::from(&mem_reads[..]);
503
504        Vec::serialize(&as_vec, serializer)
505    }
506
507    pub fn deserialize_mem_reads<'a, D: Deserializer<'a>>(
508        deserializer: D,
509    ) -> Result<Arc<[MemValue]>, D::Error> {
510        let as_vec = Vec::deserialize(deserializer)?;
511
512        Ok(as_vec.into())
513    }
514
515    #[test]
516    #[cfg(test)]
517    fn test_mem_reads() {
518        let mem_reads = Arc::new([MemValue { clk: 0, value: 0 }, MemValue { clk: 1, value: 1 }]);
519        let trace = TraceChunk {
520            start_registers: [5; 32],
521            pc_start: 6,
522            clk_start: 7,
523            clk_end: 8,
524            mem_reads,
525        };
526
527        let serialized = bincode::serialize(&trace).unwrap();
528        let deserialized = bincode::deserialize(&serialized).unwrap();
529
530        assert_eq!(trace, deserialized);
531    }
532}
533
#[cfg(test)]
mod tests {
    use super::*;

    /// Encodes a raw trace-chunk buffer (header + memory reads, native endianness) behind one
    /// leading junk byte, so that `&buf[1..]` is a deliberately misaligned chunk.
    fn encode_chunk(reads: &[MemValue]) -> Vec<u8> {
        // start_registers = 0..32, then pc_start, clk_start, clk_end, num_mem_reads,
        // global_clk_end and the padding word, in `TraceChunkHeader` field order.
        let mut words: Vec<u64> = (0..32).collect();
        words.extend([100, 7, 8, reads.len() as u64, 9, 0]);
        for read in reads {
            words.extend([read.clk, read.value]);
        }

        let mut buf = vec![0xAA_u8];
        buf.extend(words.iter().flat_map(|w| w.to_ne_bytes()));
        buf
    }

    // The memory reads are stored directly after the header, and `MemReads` reads them in place
    // with `ptr::read`, which requires `MemValue` alignment. Keeping the header *size* a multiple
    // of 16 bytes (one `MemValue`) keeps that tail aligned whenever the buffer start is aligned.
    #[test]
    fn test_trace_chunk_header_alignment() {
        assert_eq!(std::mem::size_of::<TraceChunkHeader>() % 16, 0);
        assert_eq!(std::mem::size_of::<MemValue>(), 16);
        assert_eq!(std::mem::size_of::<TraceChunkHeader>() % std::mem::align_of::<MemValue>(), 0);
    }

    #[test]
    fn test_copy_from_bytes_unaligned() {
        let reads = [MemValue { clk: 1, value: 2 }, MemValue { clk: 3, value: 4 }];
        let buf = encode_chunk(&reads);

        let chunk = TraceChunk::copy_from_bytes(&buf[1..]);

        let expected_regs: [u64; 32] = std::array::from_fn(|i| i as u64);
        assert_eq!(chunk.start_registers, expected_regs);
        assert_eq!((chunk.pc_start, chunk.clk_start, chunk.clk_end), (100, 7, 8));
        assert_eq!(&chunk.mem_reads[..], &reads[..]);
        assert_eq!(chunk.num_mem_reads(), 2);
        assert_eq!(chunk.mem_reads().collect::<Vec<_>>(), reads);
    }

    #[test]
    #[should_panic(expected = "TraceChunk tail too small")]
    fn test_copy_from_bytes_truncated_tail() {
        let buf = encode_chunk(&[MemValue { clk: 1, value: 2 }]);
        let _ = TraceChunk::copy_from_bytes(&buf[1..buf.len() - 1]);
    }

    #[test]
    fn test_mem_reads_advance() {
        let chunk = TraceChunk {
            mem_reads: Arc::from(vec![MemValue { clk: 1, value: 2 }, MemValue { clk: 3, value: 4 }]),
            ..Default::default()
        };

        let mut reads = chunk.mem_reads();
        reads.advance(1);
        assert_eq!(reads.len(), 1);
        assert_eq!(reads.next(), Some(MemValue { clk: 3, value: 4 }));
        assert!(reads.is_empty());
        reads.advance(0);
        assert_eq!(reads.next(), None);
    }

    #[test]
    #[should_panic(expected = "Cannot advance by more than the length of the slice")]
    fn test_mem_reads_advance_past_end() {
        let chunk =
            TraceChunk { mem_reads: Arc::from(vec![MemValue::default()]), ..Default::default() };
        chunk.mem_reads().advance(2);
    }
}
539    #[test]
540    fn test_trace_chunk_header_alignment() {
541        assert_eq!(std::mem::size_of::<TraceChunkHeader>() % 16, 0);
542    }
543}