// zkvmc_core/memory/forkable.rs
use crate::{send_sync_ptr::SendSyncPtr, traits::Reset};
use eyre::{Result, WrapErr};
use memfd::{Memfd, MemfdOptions};
use rustix::mm::{self, MapFlags, ProtFlags};
use std::{
    ptr::{self, NonNull},
    rc::Rc,
};

9
10#[repr(C)]
12pub struct ForkableMemory {
13 memory: SendSyncPtr<[u8]>,
14 fd: Rc<Memfd>,
15}
16
17impl ForkableMemory {
18 pub fn new(size: usize) -> Result<Self> {
19 let mfd = MemfdOptions::default().create(format!("sized-{size}"))?;
20 mfd.as_file().set_len(size as u64)?;
21 let ptr = unsafe {
22 mm::mmap(
23 ptr::null_mut(),
24 size,
25 ProtFlags::READ | ProtFlags::WRITE,
26 MapFlags::SHARED | MapFlags::NORESERVE,
27 mfd.as_file(),
28 0,
29 )?
30 };
31
32 let memory = std::ptr::slice_from_raw_parts_mut(ptr.cast(), size);
33
34 Ok(Self {
35 memory: NonNull::new(memory).unwrap().into(),
36 fd: Rc::new(mfd),
37 })
38 }
39
40 #[inline]
41 pub fn as_ptr(&self) -> *mut u8 {
42 self.memory.cast().as_ptr()
43 }
44
45 #[inline]
46 pub fn as_send_sync_ptr(&self) -> SendSyncPtr<u8> {
47 self.memory.cast()
48 }
49
50 pub fn fork(&self) -> Result<Self> {
52 let ptr = unsafe {
53 mm::mmap(
54 ptr::null_mut(),
55 self.memory.len(),
56 ProtFlags::READ | ProtFlags::WRITE,
57 MapFlags::PRIVATE | MapFlags::NORESERVE,
58 self.fd.as_file(),
59 0,
60 )?
61 };
62
63 let memory = std::ptr::slice_from_raw_parts_mut(ptr.cast(), self.memory.len());
64
65 Ok(Self {
66 memory: NonNull::new(memory).unwrap().into(),
67 fd: self.fd.clone(),
68 })
69 }
70}
71
72impl Reset for ForkableMemory {
73 fn reset(&mut self) {
74 let size = self.memory.len();
75 self.fd.as_file().set_len(0).unwrap();
77 self.fd.as_file().set_len(size as u64).unwrap();
78 let ptr = unsafe {
79 mm::mmap(
80 self.as_ptr().cast(),
81 size,
82 ProtFlags::READ | ProtFlags::WRITE,
83 MapFlags::SHARED | MapFlags::FIXED | MapFlags::NORESERVE,
84 self.fd.as_file(),
85 0,
86 )
87 .unwrap()
88 };
89 let memory = std::ptr::slice_from_raw_parts_mut(ptr.cast(), size);
90 self.memory = NonNull::new(memory).unwrap().into();
91 }
92}
93
94impl Drop for ForkableMemory {
95 fn drop(&mut self) {
96 unsafe {
97 let ptr = self.memory.as_ptr().cast();
98 let len = self.memory.len();
99 if len == 0 {
100 return;
101 }
102 rustix::mm::munmap(ptr, len).expect("munmap failed");
103 }
104 }
105}