1pub use nix::libc;
2use nix::libc::c_void;
3use nix::sys::mman::MapFlags;
4pub use nix::sys::mman::ProtFlags;
5use nix::sys::signal;
6use nix::unistd;
7use parking_lot::Mutex;
8use std::mem;
9use std::num::NonZeroUsize;
10use std::os::fd::RawFd;
11use std::sync::atomic::{AtomicBool, AtomicPtr, Ordering};
12use std::sync::Arc;
13
// Size in bytes of a usize/pointer; used for the fault-message wire format
// (little-endian address + one access-flag byte) on the handler pipes.
const ADDR_SIZE: usize = std::mem::size_of::<usize>();
15
/// Whether a page fault was caused by a read or a write access.
#[derive(Debug, PartialEq, Eq)]
pub enum AccessType {
    Read,
    Write,
}
21
/// A backing store that supplies page contents on demand.
pub trait PageStore {
    /// Called by the pager thread when the page at `offset` (spanning
    /// `length` bytes) faults with the given `access` type.
    ///
    /// Returns an iterator of byte chunks that are copied, in order, into the
    /// faulted page. Returning `None` leaves the fault unserved, in which case
    /// the fault is forwarded to the previously-installed signal handler.
    fn page_fault(
        &mut self, offset: usize, length: usize, access: AccessType,
    ) -> Option<Box<dyn Iterator<Item = Box<dyn AsRef<[u8]> + '_>> + '_>>;
}
28
/// An `mmap`ed region of memory.
pub struct MappedMemory {
    // Base address of the mapping; written once at construction.
    base: AtomicPtr<u8>,
    // Length of the mapping in bytes (whole pages).
    length: usize,
    // Whether `drop` should munmap the region (true when we chose the address;
    // a caller-provided base is left mapped).
    unmap: bool,
    // Keeps SharedMemory handles alive while their windows are mapped here.
    shared: Mutex<Vec<SharedMemory>>,
}
35
36impl MappedMemory {
37 pub fn new(base: Option<*mut u8>, mut length: usize, page_size: usize, flags: ProtFlags) -> Result<Self, Error> {
38 let rem = length & (page_size - 1);
39 match base {
40 Some(base) => {
41 if (base as usize) % page_size != 0 {
42 return Err(Error::BaseNotAligned)
43 }
44 if rem != 0 {
45 return Err(Error::LengthNotAligned)
46 }
47 }
48 None => {
49 if rem != 0 {
50 length += page_size - rem
51 }
52 }
53 }
54
55 let new_base = unsafe {
56 nix::sys::mman::mmap(
57 match base {
58 Some(b) => Some(NonZeroUsize::new(b as usize).ok_or(Error::NullBase)?),
59 None => None,
60 },
61 NonZeroUsize::new(length).ok_or(Error::LengthIsZero)?,
62 flags,
63 match base {
64 Some(_) => MapFlags::MAP_FIXED,
65 None => MapFlags::empty(),
66 } | MapFlags::MAP_PRIVATE |
67 MapFlags::MAP_ANONYMOUS,
68 Option::<std::fs::File>::None,
69 0,
70 )
71 .map_err(Error::UnixError)?
72 } as *mut u8;
73
74 if let Some(base) = base {
75 if base != new_base {
76 return Err(Error::NotSupported)
77 }
78 }
79
80 Ok(Self {
81 base: new_base.into(),
82 length,
83 unmap: base.is_none(),
84 shared: Mutex::new(Vec::new()),
85 })
86 }
87 #[inline(always)]
88 pub fn base(&self) -> *mut u8 {
89 unsafe { *self.base.as_ptr() }
90 }
91
92 #[inline(always)]
93 pub fn as_slice(&self) -> &mut [u8] {
94 unsafe { std::slice::from_raw_parts_mut(self.base(), self.length) }
95 }
96
97 pub fn make_shared(&self, offset: usize, shm: &SharedMemory, flags: ProtFlags) -> Result<(), Error> {
98 let len = shm.0.size;
99 if offset + len >= self.length {
100 return Err(Error::MemoryOverflow)
101 }
102 unsafe {
103 nix::sys::mman::mmap(
104 Some(NonZeroUsize::new(self.base().add(offset) as usize).unwrap()),
105 NonZeroUsize::new(len).unwrap(),
106 flags,
107 MapFlags::MAP_FIXED | MapFlags::MAP_SHARED,
108 Some(&shm.0.fd),
109 0,
110 )
111 .map_err(Error::UnixError)?;
112 }
113 self.shared.lock().push(shm.clone());
115 Ok(())
116 }
117}
118
impl Drop for MappedMemory {
    /// Unmaps the region, but only when this object created the mapping
    /// itself (`unmap` is false for caller-provided base addresses).
    fn drop(&mut self) {
        if self.unmap {
            unsafe {
                nix::sys::mman::munmap(self.base() as *mut c_void, self.length).unwrap();
            }
        }
    }
}
128
/// A memory region whose pages are populated lazily from a `PageStore` by
/// trapping SIGSEGV/SIGBUS on first access.
pub struct PagedMemory<'a> {
    // The underlying PROT_NONE mapping, shared with the global manager entry.
    mem: Arc<MappedMemory>,
    // Page granularity used for faulting and protection changes.
    page_size: usize,
    // Ties the region's lifetime to borrowed memory (see `from_raw`).
    _phantom: std::marker::PhantomData<&'a ()>,
}
134
// A registered paged region, tracked by the global manager so the pager
// thread can route faults by address.
struct PagedMemoryEntry {
    // Start address of the region.
    start: usize,
    // Length of the region in bytes.
    len: usize,
    mem: Arc<MappedMemory>,
    // Store that supplies page contents on fault.
    store: Box<dyn PageStore + Send + 'static>,
    page_size: usize,
}
142
/// Errors produced by this module.
#[derive(Debug, PartialEq, Eq)]
pub enum Error {
    // A fixed base address was not page-aligned.
    BaseNotAligned,
    // A fixed base address of zero was requested.
    NullBase,
    // A fixed-base mapping's length was not a multiple of the page size.
    LengthNotAligned,
    LengthIsZero,
    // sysconf could not report the system page size.
    PageSizeNotAvail,
    NotSupported,
    // Wrapped errno from the underlying system call.
    UnixError(nix::errno::Errno),
    // New region overlaps an already-registered one.
    MemoryOverlap,
    // Requested window extends past the end of the region.
    MemoryOverflow,
}
155
// Spin flag serializing access to the handler pipes across faulting threads
// (only atomics and raw read/write are async-signal-safe in the handler).
static HANDLER_SPIN: AtomicBool = AtomicBool::new(false);
// (read_end, write_end) of the pipe carrying the pager thread's one-byte
// verdict back TO the signal handler; initialized by `handler_init`.
static mut TO_HANDLER: (RawFd, RawFd) = (0, 1);
// (read_end, write_end) of the pipe carrying fault descriptors FROM the
// signal handler to the pager thread; initialized by `handler_init`.
static mut FROM_HANDLER: (RawFd, RawFd) = (0, 1);
// Previously-installed SIGSEGV/SIGBUS actions, written once by
// `register_signal_handlers_` and chained to for faults that are not ours.
static mut PREV_SIGSEGV: mem::MaybeUninit<signal::SigAction> = mem::MaybeUninit::uninit();
static mut PREV_SIGBUS: mem::MaybeUninit<signal::SigAction> = mem::MaybeUninit::uninit();
161
/// Core of the fault handler: ships the faulting address plus a read/write
/// flag to the pager thread and blocks on its one-byte verdict.
///
/// Returns `true` when the fault was NOT handled and should be forwarded to
/// the previously-installed signal handler.
#[inline]
fn handle_page_fault_(info: *mut libc::siginfo_t, ctx: *mut c_void) -> bool {
    let (tx, rx, addr, ctx) = unsafe {
        // rx: read end of the verdict pipe; tx: write end of the fault pipe.
        let (rx, _) = TO_HANDLER;
        let (_, tx) = FROM_HANDLER;
        (tx, rx, (*info).si_addr() as usize, &mut *(ctx as *mut libc::ucontext_t))
    };
    // On x86-64, bit 1 of the page-fault error code distinguishes writes (1)
    // from reads (0); on other architectures we conservatively report a write.
    #[cfg(target_arch = "x86_64")]
    let flag = ((ctx.uc_mcontext.gregs[libc::REG_ERR as usize] & 0x2) >> 1) as u8;
    #[cfg(not(target_arch = "x86_64"))]
    let flag = 1;
    // Wire format: little-endian address followed by the access flag byte.
    let mut buff = [0; ADDR_SIZE + 1];
    buff[..ADDR_SIZE].copy_from_slice(&addr.to_le_bytes());
    buff[ADDR_SIZE] = flag;
    // Spin lock: atomics are async-signal-safe, a mutex would not be.
    while HANDLER_SPIN.swap(true, Ordering::Acquire) {
        std::thread::yield_now();
    }
    if unistd::write(tx, &buff).is_err() {
        // Pipe unusable — fall back to the previous handler.
        HANDLER_SPIN.swap(false, Ordering::Release);
        return true
    }
    // Block until the pager thread answers with its verdict.
    unistd::read(rx, &mut buff[..1]).ok();
    HANDLER_SPIN.swap(false, Ordering::Release);
    // 1 = not ours: caller should chain to the previous signal action.
    buff[0] == 1
}
188
/// SIGSEGV/SIGBUS entry point: delegates to the pager; when the fault was not
/// one of ours, re-dispatches to the previously-installed action for that
/// signal (default, ignore, 1-arg handler, or SA_SIGINFO 3-arg handler).
extern "C" fn handle_page_fault(signum: libc::c_int, info: *mut libc::siginfo_t, ctx: *mut c_void) {
    if !handle_page_fault_(info, ctx) {
        return
    }

    unsafe {
        let previous_signal = signal::Signal::try_from(signum).expect("invalid signum");
        let previous = *(match previous_signal {
            signal::SIGSEGV => PREV_SIGSEGV.as_ptr(),
            signal::SIGBUS => PREV_SIGBUS.as_ptr(),
            _ => panic!("unknown signal: {}", previous_signal),
        });

        match previous.handler() {
            signal::SigHandler::SigDfl => {
                // Reinstall the default action and re-raise so the default
                // disposition takes effect.
                signal::signal(previous_signal, signal::SigHandler::SigDfl).expect("fail to reset signal handler");
                let _ = signal::raise(previous_signal);
            }
            signal::SigHandler::SigIgn => {}
            signal::SigHandler::SigAction(handler)
                if previous.flags() & signal::SaFlags::SA_SIGINFO == signal::SaFlags::SA_SIGINFO =>
            {
                // Previous handler expected the 3-argument SA_SIGINFO form.
                handler(signum, info, ctx);
            }
            signal::SigHandler::Handler(handler) => handler(signum),
            _ => panic!("unexpected signal handler"),
        }
    }
}
221
/// Installs `handler` for SIGSEGV and SIGBUS, stashing the previously
/// installed actions in `PREV_SIGSEGV`/`PREV_SIGBUS` for later chaining.
///
/// # Safety
/// Writes the `PREV_*` statics without synchronization and mutates
/// process-global signal state; call once, before any thread can fault.
unsafe fn register_signal_handlers_(handler: extern "C" fn(i32, *mut libc::siginfo_t, *mut c_void)) {
    let register = |slot: *mut signal::SigAction, signal: signal::Signal| {
        let sig_action = signal::SigAction::new(
            signal::SigHandler::SigAction(handler),
            // SA_SIGINFO: 3-arg handler; SA_NODEFER: don't mask the signal
            // while handling; SA_ONSTACK: use the alternate stack if set up.
            signal::SaFlags::SA_NODEFER | signal::SaFlags::SA_SIGINFO | signal::SaFlags::SA_ONSTACK,
            signal::SigSet::empty(),
        );

        // `sigaction` returns the previous action, which we keep for chaining.
        *slot = signal::sigaction(signal, &sig_action).expect("fail to register signal handler");
    };

    register(PREV_SIGSEGV.as_mut_ptr(), signal::SIGSEGV);
    register(PREV_SIGBUS.as_mut_ptr(), signal::SIGBUS);
}
249
/// Installs the userspace-paging SIGSEGV/SIGBUS handlers.
///
/// # Safety
/// Same contract as `register_signal_handlers_`: mutates process-global
/// signal state and must run before concurrent faults are possible.
pub unsafe fn register_signal_handlers() {
    register_signal_handlers_(handle_page_fault);
}
253
// Registry of live paged regions, kept sorted by start address so lookups
// and overlap checks are a simple ordered scan.
struct PagedMemoryManager {
    entries: Vec<PagedMemoryEntry>,
}
257
258impl PagedMemoryManager {
259 fn insert(&mut self, entry: PagedMemoryEntry) -> bool {
260 for (i, PagedMemoryEntry { start, len, .. }) in self.entries.iter().enumerate() {
261 if entry.start + entry.len <= *start {
262 self.entries.insert(i, entry);
264 return true
265 }
266 if entry.start < *start + *len {
267 return false
269 }
270 }
271 self.entries.push(entry);
272 true
273 }
274
275 fn remove(&mut self, start_: usize, len_: usize) {
276 for (i, PagedMemoryEntry { start, len, .. }) in self.entries.iter().enumerate() {
277 if *start == start_ && *len == len_ {
278 self.entries.remove(i);
279 return
280 }
281 }
282 panic!(
283 "failed to locate PagedMemoryEntry (start = 0x{:x}, end = 0x{:x})",
284 start_,
285 start_ + len_
286 )
287 }
288}
289
// Global registry of all live PagedMemory regions, consulted by the pager
// thread to route faults and updated on region creation/destruction.
static MANAGER: Mutex<PagedMemoryManager> = Mutex::new(PagedMemoryManager { entries: Vec::new() });
291
/// One-time setup: creates the handler<->pager pipes, installs the signal
/// handlers, and spawns the pager thread that services page faults.
fn handler_init() {
    let to_handler = nix::unistd::pipe().expect("fail to create pipe to the handler");
    let from_handler = nix::unistd::pipe().expect("fail to create pipe from the handler");
    unsafe {
        TO_HANDLER = to_handler;
        FROM_HANDLER = from_handler;
        register_signal_handlers();
    }
    // Publish the pipe fds before any thread can fault and read them.
    std::sync::atomic::fence(Ordering::SeqCst);
    std::thread::spawn(move || {
        // Pager thread: read fault descriptors, load pages, answer verdicts.
        let from_handler = from_handler.0;
        let to_handler = to_handler.1;
        let mut buff = [0; ADDR_SIZE + 1];
        loop {
            unistd::read(from_handler, &mut buff).unwrap();
            let addr = usize::from_le_bytes(buff[..ADDR_SIZE].try_into().unwrap());
            // Flag byte: 0 = read fault, anything else = write fault.
            let (access_type, mprotect_flag) = match buff[ADDR_SIZE] {
                0 => (AccessType::Read, ProtFlags::PROT_READ),
                _ => (AccessType::Write, ProtFlags::PROT_READ | ProtFlags::PROT_WRITE),
            };
            let mut mgr = MANAGER.lock();
            // Verdict 1 = not ours; the handler then chains to the previous
            // signal action.
            let mut fallback = 1;
            for entry in mgr.entries.iter_mut() {
                if entry.start <= addr && addr < entry.start + entry.len {
                    // Round the faulting address down to its page boundary
                    // (page_size is assumed to be a power of two).
                    let page_mask = usize::MAX ^ (entry.page_size - 1);
                    let page_addr = addr & page_mask;
                    let slice = entry.mem.as_slice();
                    let base = slice.as_ptr() as usize;
                    let page_offset = page_addr - base;
                    if let Some(page) = entry.store.page_fault(page_offset, entry.page_size, access_type) {
                        // Temporarily make the page writable so the store's
                        // chunks can be copied in.
                        unsafe {
                            nix::sys::mman::mprotect(
                                page_addr as *mut c_void,
                                entry.page_size,
                                ProtFlags::PROT_READ | ProtFlags::PROT_WRITE,
                            )
                            .expect("mprotect failed");
                        }
                        // Copy the store's chunks into the page, in order.
                        let target = &mut slice[page_offset..page_offset + entry.page_size];
                        let mut base = 0;
                        for chunk in page {
                            let chunk = (*chunk).as_ref();
                            let chunk_len = chunk.len();
                            target[base..base + chunk_len].copy_from_slice(&chunk);
                            base += chunk_len;
                        }
                    }
                    // Restrict the page to what the access actually needs.
                    unsafe {
                        nix::sys::mman::mprotect(page_addr as *mut c_void, entry.page_size, mprotect_flag)
                            .expect("mprotect failed");
                    }
                    fallback = 0;
                    break
                }
            }
            unistd::write(to_handler, &[fallback]).unwrap();
        }
    });
}
354
355impl<'a> PagedMemory<'a> {
356 pub unsafe fn from_raw<S: PageStore + Send + 'static>(
360 base: *mut u8, length: usize, store: S, page_size: Option<usize>,
361 ) -> Result<PagedMemory<'static>, Error> {
362 let mem: &'static mut [u8] = std::slice::from_raw_parts_mut(base, length);
363 Self::new_(Some(mem.as_ptr() as *mut u8), mem.len(), store, page_size)
364 }
365
366 pub fn new<S: PageStore + Send + 'static>(
368 length: usize, store: S, page_size: Option<usize>,
369 ) -> Result<PagedMemory<'static>, Error> {
370 Self::new_(None, length, store, page_size)
371 }
372
373 fn new_<'b, S: PageStore + Send + 'static>(
374 base: Option<*mut u8>, length: usize, store: S, page_size: Option<usize>,
375 ) -> Result<PagedMemory<'b>, Error> {
376 static INIT: std::sync::Once = std::sync::Once::new();
377 INIT.call_once(|| handler_init());
378 let page_size = match page_size {
379 Some(s) => s,
380 None => get_page_size()?,
381 };
382 let mem = std::sync::Arc::new(MappedMemory::new(base, length, page_size, ProtFlags::PROT_NONE)?);
383 let mut mgr = MANAGER.lock();
384 if !mgr.insert(PagedMemoryEntry {
385 start: mem.base() as usize,
386 len: length,
387 mem: mem.clone(),
388 store: Box::new(store),
389 page_size,
390 }) {
391 return Err(Error::MemoryOverlap)
392 }
393
394 Ok(PagedMemory {
395 mem,
396 page_size,
397 _phantom: std::marker::PhantomData,
398 })
399 }
400
401 pub fn as_slice_mut(&mut self) -> &mut [u8] {
431 self.mem.as_slice()
432 }
433
434 pub fn as_slice(&self) -> &[u8] {
435 self.mem.as_slice()
436 }
437
438 pub fn as_raw_parts(&self) -> (*mut u8, usize) {
439 let s = self.mem.as_slice();
440 (s.as_mut_ptr(), s.len())
441 }
442
443 pub fn page_size(&self) -> usize {
445 self.page_size
446 }
447
448 pub fn mark_read_only(&self, offset: usize, length: usize) {
451 assert!(offset + length <= self.mem.length);
452 unsafe {
453 nix::sys::mman::mprotect(self.mem.base().add(offset) as *mut c_void, length, ProtFlags::PROT_READ)
454 .expect("mprotect failed");
455 }
456 }
457
458 pub fn release_page(&self, page_offset: usize) {
462 if page_offset & (self.page_size - 1) != 0 || page_offset >= self.mem.length {
463 panic!("invalid page offset: {:x}", page_offset);
464 }
465 let page_addr = self.mem.base() as usize + page_offset;
466 unsafe {
467 nix::sys::mman::mmap(
468 Some(NonZeroUsize::new(page_addr).unwrap()),
469 NonZeroUsize::new(self.page_size).unwrap(),
470 ProtFlags::PROT_NONE,
471 MapFlags::MAP_FIXED | MapFlags::MAP_PRIVATE | MapFlags::MAP_ANONYMOUS,
472 Option::<std::fs::File>::None,
473 0,
474 )
475 .expect("mmap failed");
476 }
477 }
478
479 pub fn release_all_pages(&self) {
480 unsafe {
481 nix::sys::mman::mmap(
482 Some(NonZeroUsize::new(self.mem.base() as usize).unwrap()),
483 NonZeroUsize::new(self.mem.length).unwrap(),
484 ProtFlags::PROT_NONE,
485 MapFlags::MAP_FIXED | MapFlags::MAP_PRIVATE | MapFlags::MAP_ANONYMOUS,
486 Option::<std::fs::File>::None,
487 0,
488 )
489 .expect("mmap failed");
490 }
491 self.mem.shared.lock().clear();
492 }
493
494 pub fn make_shared(&self, offset: usize, shm: &SharedMemory) -> Result<(), Error> {
495 self.mem.make_shared(offset, shm, ProtFlags::PROT_NONE)
496 }
497}
498
impl<'a> Drop for PagedMemory<'a> {
    /// Unregisters the region from the global manager; the mapping itself is
    /// released when the last `Arc<MappedMemory>` clone is dropped.
    fn drop(&mut self) {
        let mut mgr = MANAGER.lock();
        mgr.remove(self.mem.base() as usize, self.mem.length);
    }
}
505
/// A `PageStore` backed by an in-memory `Vec<u8>` holding the full contents.
pub struct VecPageStore(Vec<u8>);
507
508impl VecPageStore {
509 pub fn new(vec: Vec<u8>) -> Self {
510 Self(vec)
511 }
512}
513
514impl PageStore for VecPageStore {
515 fn page_fault(
516 &mut self, offset: usize, length: usize, _access: AccessType,
517 ) -> Option<Box<dyn Iterator<Item = Box<dyn AsRef<[u8]> + '_>> + '_>> {
518 #[cfg(debug_assertions)]
519 println!(
520 "{:?} loading page at 0x{:x} access={:?}",
521 self as *mut Self, offset, _access,
522 );
523 Some(Box::new(std::iter::once(
524 Box::new(&self.0[offset..offset + length]) as Box<dyn AsRef<[u8]>>
525 )))
526 }
527}
528
/// Handle to an anonymous shared-memory object; clones share the same
/// underlying file descriptor via `Arc`.
#[derive(Clone)]
pub struct SharedMemory(Arc<SharedMemoryInner>);

struct SharedMemoryInner {
    // Owned descriptor for the shm object (memfd on Linux, shm_open fd on
    // macOS/iOS); closed when the last clone drops.
    fd: std::os::fd::OwnedFd,
    // Size of the object in bytes, as set by ftruncate.
    size: usize,
}
536
537impl SharedMemory {
538 pub fn new(size: usize) -> Result<Self, Error> {
539 #[cfg(target_os = "linux")]
540 let fd = nix::sys::memfd::memfd_create(c"userspace-paging", nix::sys::memfd::MemFdCreateFlag::empty())
541 .map_err(Error::UnixError)?;
542 #[cfg(any(target_os = "macos", target_os = "ios"))]
543 let fd = {
544 let mut failed = 0;
545 let mut result = None;
546 loop {
547 if failed > 10 {
548 break
549 }
550 use nix::fcntl::OFlag;
551 use nix::sys::stat::Mode;
552 use rand::distributions::{Alphanumeric, DistString};
553
554 let name = format!("/usp-{}", Alphanumeric.sample_string(&mut rand::thread_rng(), 16));
556 if let Ok(fd) = nix::sys::mman::shm_open(
557 &*name,
558 OFlag::O_RDWR | OFlag::O_CREAT | OFlag::O_EXCL | OFlag::O_NOFOLLOW,
559 Mode::S_IRUSR | Mode::S_IWUSR,
560 ) {
561 result = Some(fd);
562 break
563 }
564 failed += 1;
565 }
566 match result {
567 None => return Err(Error::NotSupported),
568 Some(fd) => fd,
569 }
570 };
571 nix::unistd::ftruncate(&fd, size as libc::off_t).map_err(Error::UnixError)?;
572 Ok(Self(Arc::new(SharedMemoryInner { fd, size })))
573 }
574}
575
576pub fn get_page_size() -> Result<usize, Error> {
577 Ok(unistd::sysconf(unistd::SysconfVar::PAGE_SIZE)
578 .map_err(Error::UnixError)?
579 .ok_or(Error::PageSizeNotAvail)? as usize)
580}
581
#[cfg(test)]
mod tests {
    use super::*;
    use lazy_static::lazy_static;

    lazy_static! {
        // System page size, queried once and shared by every test.
        static ref PAGE_SIZE: usize = unistd::sysconf(unistd::SysconfVar::PAGE_SIZE).unwrap().unwrap() as usize;
    }

    // First reads of a fresh mapping fault in the values seeded in the
    // backing store.
    #[test]
    fn test1() {
        for _ in 0..100 {
            let mut v = Vec::new();
            v.resize(*PAGE_SIZE * 100, 0);
            v[0] = 42;
            v[*PAGE_SIZE * 10 + 1] = 43;
            v[*PAGE_SIZE * 20 + 1] = 44;

            let pm = PagedMemory::new(*PAGE_SIZE * 100, VecPageStore::new(v), None).unwrap();
            let m = pm.as_slice();
            assert_eq!(m[0], 42);
            assert_eq!(m[*PAGE_SIZE * 10 + 1], 43);
            assert_eq!(m[*PAGE_SIZE * 20 + 1], 44);
        }
    }

    // Two coexisting regions stay independent: a write through one must not
    // disturb reads from the other.
    #[test]
    fn test2() {
        for _ in 0..100 {
            let mut v = Vec::new();
            v.resize(*PAGE_SIZE * 100, 0);
            v[0] = 1;
            v[*PAGE_SIZE * 10 + 1] = 2;
            v[*PAGE_SIZE * 20 + 1] = 3;

            let pm1 = PagedMemory::new(*PAGE_SIZE * 100, VecPageStore::new(v), None).unwrap();

            let mut v = Vec::new();
            v.resize(*PAGE_SIZE * 100, 0);
            for (i, v) in v.iter_mut().enumerate() {
                *v = i as u8;
            }
            let mut pm2 = PagedMemory::new(*PAGE_SIZE * 100, VecPageStore::new(v), None).unwrap();

            let m2 = pm2.as_slice_mut();
            let m1 = pm1.as_slice();

            assert_eq!(m2[100], 100);
            m2[100] = 0;
            assert_eq!(m2[100], 0);

            assert_eq!(m1[0], 1);
            assert_eq!(m1[*PAGE_SIZE * 10 + 1], 2);
            assert_eq!(m1[*PAGE_SIZE * 20 + 1], 3);
        }
    }

    // A SharedMemory window mapped into two regions aliases the same bytes:
    // a write through one region is visible through the other, while
    // unshared pages remain per-region.
    #[test]
    fn test_shared_memory() {
        let mut v = Vec::new();
        v.resize(*PAGE_SIZE * 100, 0);
        v[0] = 42;
        v[*PAGE_SIZE * 10 + 1] = 43;
        v[*PAGE_SIZE * 20 + 1] = 44;

        let shm = SharedMemory::new(*PAGE_SIZE).unwrap();
        let mut pm1 = PagedMemory::new(*PAGE_SIZE * 100, VecPageStore::new(v.clone()), None).unwrap();
        let pm2 = PagedMemory::new(*PAGE_SIZE * 100, VecPageStore::new(v), None).unwrap();
        pm1.make_shared(*PAGE_SIZE * 10, &shm).unwrap();
        pm2.make_shared(*PAGE_SIZE * 10, &shm).unwrap();

        assert_eq!(pm1.as_slice()[*PAGE_SIZE * 10 + 1], 43);
        assert_eq!(pm2.as_slice()[*PAGE_SIZE * 10 + 1], 43);
        pm1.as_slice_mut()[*PAGE_SIZE * 10 + 1] = 99;
        assert_eq!(pm2.as_slice()[*PAGE_SIZE * 10 + 1], 99);
        assert_eq!(pm1.as_slice()[*PAGE_SIZE * 10 + 1], 99);

        let m = pm1.as_slice();
        assert_eq!(m[0], 42);
        assert_eq!(m[*PAGE_SIZE * 20 + 1], 44);
    }

    // Released pages fault back in from the store with their original
    // contents, repeatedly.
    #[test]
    fn test_release_page() {
        let mut v = Vec::new();
        v.resize(*PAGE_SIZE * 20, 0);
        v[0] = 42;
        v[*PAGE_SIZE * 10 + 1] = 43;

        let pm = PagedMemory::new(*PAGE_SIZE * 100, VecPageStore::new(v), None).unwrap();
        let m = pm.as_slice();
        assert_eq!(m[0], 42);
        assert_eq!(m[*PAGE_SIZE * 10 + 1], 43);
        for _ in 0..5 {
            pm.release_page(0);
            pm.release_page(*PAGE_SIZE * 10);
            assert_eq!(m[0], 42);
            assert_eq!(m[*PAGE_SIZE * 10 + 1], 43);
        }
    }

    // Random-order access across the whole region still yields the store's
    // bytes (seeded RNG keeps the shuffle reproducible).
    #[test]
    fn out_of_order_scan() {
        let mut v = Vec::new();
        v.resize(*PAGE_SIZE * 100, 0);
        for (i, v) in v.iter_mut().enumerate() {
            *v = i as u8;
        }
        let store = VecPageStore::new(v);
        let pm = PagedMemory::new(*PAGE_SIZE * 100, store, None).unwrap();
        use rand::{seq::SliceRandom, SeedableRng};
        use rand_chacha::ChaChaRng;
        let seed = [0; 32];
        let mut rng = ChaChaRng::from_seed(seed);

        let m = pm.as_slice();
        let mut idxes = Vec::new();
        for i in 0..m.len() {
            idxes.push(i);
        }
        idxes.shuffle(&mut rng);
        for i in idxes.into_iter() {
            #[cfg(debug_assertions)]
            {
                let x = m[i];
                println!("m[0x{:08x}] = {}", i, x);
            }
            assert_eq!(m[i], i as u8);
        }
    }
}