// zkvmc_core/memory/forkable.rs

1use crate::{send_sync_ptr::SendSyncPtr, traits::Reset};
2use eyre::Result;
3use memfd::{Memfd, MemfdOptions};
4use rustix::mm::{self, MapFlags, ProtFlags};
5use std::{
6    ptr::{self, NonNull},
7    rc::Rc,
8};
9
10/// Linear memory used between Guest and VM.
11#[repr(C)]
12pub struct ForkableMemory {
13    memory: SendSyncPtr<[u8]>,
14    fd: Rc<Memfd>,
15}
16
17impl ForkableMemory {
18    pub fn new(size: usize) -> Result<Self> {
19        let mfd = MemfdOptions::default().create(format!("sized-{size}"))?;
20        mfd.as_file().set_len(size as u64)?;
21        let ptr = unsafe {
22            mm::mmap(
23                ptr::null_mut(),
24                size,
25                ProtFlags::READ | ProtFlags::WRITE,
26                MapFlags::SHARED | MapFlags::NORESERVE,
27                mfd.as_file(),
28                0,
29            )?
30        };
31
32        let memory = std::ptr::slice_from_raw_parts_mut(ptr.cast(), size);
33
34        Ok(Self {
35            memory: NonNull::new(memory).unwrap().into(),
36            fd: Rc::new(mfd),
37        })
38    }
39
40    #[inline]
41    pub fn as_ptr(&self) -> *mut u8 {
42        self.memory.cast().as_ptr()
43    }
44
45    #[inline]
46    pub fn as_send_sync_ptr(&self) -> SendSyncPtr<u8> {
47        self.memory.cast()
48    }
49
50    /// Copy-on-write fork the memory used for revert execution, like Sp1 unconstrained.
51    pub fn fork(&self) -> Result<Self> {
52        let ptr = unsafe {
53            mm::mmap(
54                ptr::null_mut(),
55                self.memory.len(),
56                ProtFlags::READ | ProtFlags::WRITE,
57                MapFlags::PRIVATE | MapFlags::NORESERVE,
58                self.fd.as_file(),
59                0,
60            )?
61        };
62
63        let memory = std::ptr::slice_from_raw_parts_mut(ptr.cast(), self.memory.len());
64
65        Ok(Self {
66            memory: NonNull::new(memory).unwrap().into(),
67            fd: self.fd.clone(),
68        })
69    }
70}
71
72impl Reset for ForkableMemory {
73    fn reset(&mut self) {
74        let size = self.memory.len();
75        // NB: dealloc memory by reset to zero?
76        self.fd.as_file().set_len(0).unwrap();
77        self.fd.as_file().set_len(size as u64).unwrap();
78        let ptr = unsafe {
79            mm::mmap(
80                self.as_ptr().cast(),
81                size,
82                ProtFlags::READ | ProtFlags::WRITE,
83                MapFlags::SHARED | MapFlags::FIXED | MapFlags::NORESERVE,
84                self.fd.as_file(),
85                0,
86            )
87            .unwrap()
88        };
89        let memory = std::ptr::slice_from_raw_parts_mut(ptr.cast(), size);
90        self.memory = NonNull::new(memory).unwrap().into();
91    }
92}
93
94impl Drop for ForkableMemory {
95    fn drop(&mut self) {
96        unsafe {
97            let ptr = self.memory.as_ptr().cast();
98            let len = self.memory.len();
99            if len == 0 {
100                return;
101            }
102            rustix::mm::munmap(ptr, len).expect("munmap failed");
103        }
104    }
105}