userspace_pagefault/
lib.rs

1#![warn(static_mut_refs)]
2pub use nix::libc;
3use nix::libc::c_void;
4use nix::sys::mman::MapFlags;
5pub use nix::sys::mman::ProtFlags;
6use nix::sys::signal;
7use nix::unistd;
8use parking_lot::Mutex;
9use std::num::NonZeroUsize;
10use std::os::fd::{AsRawFd, BorrowedFd, IntoRawFd, RawFd};
11use std::ptr::NonNull;
12use std::sync::Arc;
13use std::sync::atomic::{AtomicBool, AtomicPtr, Ordering};
14
15mod machdep;
16
/// Number of bytes used to encode a faulting address in the signal-handler pipe protocol.
const ADDR_SIZE: usize = (usize::BITS / 8) as usize;
18
19type SignalHandler = extern "C" fn(libc::c_int, *mut libc::siginfo_t, *mut c_void);
20
/// Kind of memory access that triggered a page fault.
///
/// Fieldless and tiny, so it is `Copy`: callers can pass it around by value without moves.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AccessType {
    /// The faulting instruction read from the page.
    Read,
    /// The faulting instruction wrote to the page.
    Write,
}
26
27pub trait PageStore {
28    /// Callback that is triggered upon a page fault. Return `Some()` if the page needs to be loaded with the given data.
29    fn page_fault(
30        &mut self, offset: usize, length: usize, access: AccessType,
31    ) -> Option<Box<dyn Iterator<Item = Box<dyn AsRef<[u8]> + '_>> + '_>>;
32}
33
34pub struct MappedMemory {
35    base: AtomicPtr<u8>,
36    length: usize,
37    unmap: bool,
38    shared: Mutex<Vec<SharedMemory>>,
39}
40
41impl MappedMemory {
42    pub fn new(base: Option<*mut u8>, mut length: usize, page_size: usize, flags: ProtFlags) -> Result<Self, Error> {
43        let rem = length & (page_size - 1);
44        match base {
45            Some(base) => {
46                if (base as usize) % page_size != 0 {
47                    return Err(Error::BaseNotAligned);
48                }
49                if rem != 0 {
50                    return Err(Error::LengthNotAligned);
51                }
52            }
53            None => {
54                if rem != 0 {
55                    length += page_size - rem
56                }
57            }
58        }
59
60        let addr = match base {
61            Some(b) => Some(NonZeroUsize::new(b as usize).ok_or(Error::NullBase)?),
62            None => None,
63        };
64        let length_nz = NonZeroUsize::new(length).ok_or(Error::LengthIsZero)?;
65        let map_flags = match base {
66            Some(_) => MapFlags::MAP_FIXED,
67            None => MapFlags::empty(),
68        } | MapFlags::MAP_PRIVATE;
69
70        let new_base = unsafe {
71            nix::sys::mman::mmap_anonymous(addr, length_nz, flags, map_flags)
72                .map_err(Error::UnixError)?
73                .cast::<u8>()
74        };
75        let new_base_ptr = new_base.as_ptr();
76
77        if let Some(base) = base {
78            if base != new_base_ptr {
79                return Err(Error::NotSupported);
80            }
81        }
82
83        Ok(Self {
84            base: AtomicPtr::new(new_base_ptr),
85            length,
86            unmap: base.is_none(),
87            shared: Mutex::new(Vec::new()),
88        })
89    }
90    #[inline(always)]
91    pub fn base(&self) -> *mut u8 {
92        unsafe { *self.base.as_ptr() }
93    }
94
95    #[inline(always)]
96    pub fn as_slice(&self) -> &mut [u8] {
97        unsafe { std::slice::from_raw_parts_mut(self.base(), self.length) }
98    }
99
100    pub fn make_shared(&self, offset: usize, shm: &SharedMemory, flags: ProtFlags) -> Result<(), Error> {
101        let len = shm.0.size;
102        if offset + len >= self.length {
103            return Err(Error::MemoryOverflow);
104        }
105        unsafe {
106            nix::sys::mman::mmap(
107                Some(NonZeroUsize::new(self.base().add(offset) as usize).unwrap()),
108                NonZeroUsize::new(len).unwrap(),
109                flags,
110                MapFlags::MAP_FIXED | MapFlags::MAP_SHARED,
111                &shm.0.fd,
112                0,
113            )
114            .map_err(Error::UnixError)?;
115        }
116        // keep a reference to the shared memory so it is not deallocated
117        self.shared.lock().push(shm.clone());
118        Ok(())
119    }
120}
121
122impl Drop for MappedMemory {
123    fn drop(&mut self) {
124        if self.unmap {
125            unsafe {
126                if let Some(ptr) = NonNull::new(self.base() as *mut c_void) {
127                    nix::sys::mman::munmap(ptr, self.length).unwrap();
128                }
129            }
130        }
131    }
132}
133
134pub struct PagedMemory<'a> {
135    mem: Arc<MappedMemory>,
136    page_size: usize,
137    _phantom: std::marker::PhantomData<&'a ()>,
138}
139
140struct PagedMemoryEntry {
141    start: usize,
142    len: usize,
143    mem: Arc<MappedMemory>,
144    store: Box<dyn PageStore + Send + 'static>,
145    page_size: usize,
146}
147
148#[derive(Debug, PartialEq, Eq)]
149pub enum Error {
150    BaseNotAligned,
151    NullBase,
152    LengthNotAligned,
153    LengthIsZero,
154    PageSizeNotAvail,
155    NotSupported,
156    UnixError(nix::errno::Errno),
157    MemoryOverlap,
158    MemoryOverflow,
159}
160
161static HANDLER_SPIN: AtomicBool = AtomicBool::new(false);
162static HANDLER_INITIALIZED: AtomicBool = AtomicBool::new(false);
163static mut TO_HANDLER: (RawFd, RawFd) = (0, 1);
164static mut FROM_HANDLER: (RawFd, RawFd) = (0, 1);
165static mut FALLBACK_SIGSEGV_HANDLER: Option<SignalHandler> = None;
166static mut FALLBACK_SIGBUS_HANDLER: Option<SignalHandler> = None;
167
168#[inline]
169fn handle_page_fault_(info: *mut libc::siginfo_t, ctx: *mut c_void) -> bool {
170    // NOTE: this function should be SIGBUS/SIGSEGV-free, another signal can't be raised during the
171    // handling of the signal.
172    let (tx, rx, addr, ctx) = unsafe {
173        let (rx, _) = TO_HANDLER;
174        let (_, tx) = FROM_HANDLER;
175        (tx, rx, (*info).si_addr() as usize, &mut *(ctx as *mut libc::ucontext_t))
176    };
177    let flag = machdep::check_page_fault_rw_flag_from_context(*ctx);
178    let mut buff = [0; ADDR_SIZE + 1];
179    buff[..ADDR_SIZE].copy_from_slice(&addr.to_le_bytes());
180    buff[ADDR_SIZE] = flag;
181    // use a spin lock to avoid ABA (another thread could interfere in-between read and write)
182    while HANDLER_SPIN.swap(true, Ordering::Acquire) {
183        std::thread::yield_now();
184    }
185    if unistd::write(unsafe { BorrowedFd::borrow_raw(tx) }, &buff).is_err() {
186        HANDLER_SPIN.swap(false, Ordering::Release);
187        return true;
188    }
189    let _ = unistd::read(unsafe { BorrowedFd::borrow_raw(rx) }, &mut buff[..1]);
190    HANDLER_SPIN.swap(false, Ordering::Release);
191    buff[0] == 1
192}
193
194// The fallback signal handling was inspired by wasmtime trap handlers:
195// https://github.com/bytecodealliance/wasmtime/blob/v22.0.0/crates/wasmtime/src/runtime/vm/sys/unix/signals.rs
196extern "C" fn handle_page_fault(signum: libc::c_int, info: *mut libc::siginfo_t, ctx: *mut c_void) {
197    if !handle_page_fault_(info, ctx) {
198        return;
199    }
200    // Otherwise, not hitting a managed memory region, fallback to previous handler
201
202    unsafe {
203        let sig = signal::Signal::try_from(signum).expect("invalid signum");
204        let fallback_handler = match sig {
205            signal::SIGSEGV => FALLBACK_SIGSEGV_HANDLER,
206            signal::SIGBUS => FALLBACK_SIGBUS_HANDLER,
207            _ => panic!("unknown signal: {}", sig),
208        };
209
210        if let Some(handler) = fallback_handler {
211            // Delegate to the fallback handler
212            handler(signum, info, ctx);
213        } else {
214            // No fallback handler (was SIG_DFL or SIG_IGN), reset to default and raise
215            let sig_action = signal::SigAction::new(
216                signal::SigHandler::SigDfl,
217                signal::SaFlags::empty(),
218                signal::SigSet::empty(),
219            );
220            signal::sigaction(sig, &sig_action).expect("fail to reset signal handler");
221            signal::raise(sig).expect("fail to raise signal");
222            unreachable!("SIG_DFL should have terminated the process");
223        }
224    }
225}
226
227unsafe fn register_signal_handlers(handler: SignalHandler) {
228    let register = |fallback_handler: *mut Option<SignalHandler>, sig: signal::Signal| {
229        // The flags here are relatively careful, and they are...
230        //
231        // SA_SIGINFO gives us access to information like the program
232        // counter from where the fault happened.
233        //
234        // SA_ONSTACK allows us to handle signals on an alternate stack,
235        // so that the handler can run in response to running out of
236        // stack space on the main stack. Rust installs an alternate
237        // stack with sigaltstack, so we rely on that.
238        //
239        // SA_NODEFER allows us to reenter the signal handler if we
240        // crash while handling the signal, and fall through to the
241        // Breakpad handler by testing handlingSegFault.
242        let sig_action = signal::SigAction::new(
243            signal::SigHandler::SigAction(handler),
244            signal::SaFlags::SA_NODEFER | signal::SaFlags::SA_SIGINFO | signal::SaFlags::SA_ONSTACK,
245            signal::SigSet::empty(),
246        );
247
248        // Extract and save the fallback handler function pointer if it's a SigAction with SA_SIGINFO
249        unsafe {
250            let sig = signal::sigaction(sig, &sig_action).expect("fail to register signal handler");
251            *fallback_handler = match sig.handler() {
252                signal::SigHandler::SigAction(h)
253                    if sig.flags() & signal::SaFlags::SA_SIGINFO == signal::SaFlags::SA_SIGINFO =>
254                {
255                    Some(h)
256                }
257                _ => None,
258            };
259        }
260    };
261
262    register(&raw mut FALLBACK_SIGSEGV_HANDLER, signal::SIGSEGV);
263    register(&raw mut FALLBACK_SIGBUS_HANDLER, signal::SIGBUS);
264}
265
266struct PagedMemoryManager {
267    entries: Vec<PagedMemoryEntry>,
268}
269
270impl PagedMemoryManager {
271    fn insert(&mut self, entry: PagedMemoryEntry) -> bool {
272        for (i, PagedMemoryEntry { start, len, .. }) in self.entries.iter().enumerate() {
273            if entry.start + entry.len <= *start {
274                // insert before this entry
275                self.entries.insert(i, entry);
276                return true;
277            }
278            if entry.start < *start + *len {
279                // overlapping space
280                return false;
281            }
282        }
283        self.entries.push(entry);
284        true
285    }
286
287    fn remove(&mut self, start_: usize, len_: usize) {
288        for (i, PagedMemoryEntry { start, len, .. }) in self.entries.iter().enumerate() {
289            if *start == start_ && *len == len_ {
290                self.entries.remove(i);
291                return;
292            }
293        }
294        panic!(
295            "failed to locate PagedMemoryEntry (start = 0x{:x}, end = 0x{:x})",
296            start_,
297            start_ + len_
298        )
299    }
300}
301
302static MANAGER: Mutex<PagedMemoryManager> = Mutex::new(PagedMemoryManager { entries: Vec::new() });
303
304fn handler_init() {
305    let (to_read, to_write) = nix::unistd::pipe().expect("fail to create pipe to the handler");
306    let (from_read, from_write) = nix::unistd::pipe().expect("fail to create pipe from the handler");
307    let from_handler = unsafe { BorrowedFd::borrow_raw(from_read.as_raw_fd()) };
308    let to_handler = unsafe { BorrowedFd::borrow_raw(to_write.as_raw_fd()) };
309    unsafe {
310        TO_HANDLER = (to_read.into_raw_fd(), to_write.into_raw_fd());
311        FROM_HANDLER = (from_read.into_raw_fd(), from_write.into_raw_fd());
312        register_signal_handlers(handle_page_fault);
313    }
314    std::sync::atomic::fence(Ordering::SeqCst);
315    std::thread::spawn(move || {
316        let mut buff = [0; ADDR_SIZE + 1];
317        loop {
318            unistd::read(&from_handler, &mut buff).unwrap();
319            let addr = usize::from_le_bytes(buff[..ADDR_SIZE].try_into().unwrap());
320            let (access_type, mprotect_flag) = match buff[ADDR_SIZE] {
321                0 => (AccessType::Read, ProtFlags::PROT_READ),
322                _ => (AccessType::Write, ProtFlags::PROT_READ | ProtFlags::PROT_WRITE),
323            };
324            let mut mgr = MANAGER.lock();
325            let mut fallback = 1;
326            for entry in mgr.entries.iter_mut() {
327                if entry.start <= addr && addr < entry.start + entry.len {
328                    let page_mask = usize::MAX ^ (entry.page_size - 1);
329                    let page_addr = addr & page_mask;
330                    let page_ptr = unsafe { NonNull::new_unchecked(page_addr as *mut c_void) };
331                    // load the page data
332                    let slice = entry.mem.as_slice();
333                    let base = slice.as_ptr() as usize;
334                    let page_offset = page_addr - base;
335                    if let Some(page) = entry.store.page_fault(page_offset, entry.page_size, access_type) {
336                        unsafe {
337                            nix::sys::mman::mprotect(
338                                page_ptr,
339                                entry.page_size,
340                                ProtFlags::PROT_READ | ProtFlags::PROT_WRITE,
341                            )
342                            .expect("mprotect failed");
343                        }
344                        let target = &mut slice[page_offset..page_offset + entry.page_size];
345                        let mut base = 0;
346                        for chunk in page {
347                            let chunk = (*chunk).as_ref();
348                            let chunk_len = chunk.len();
349                            target[base..base + chunk_len].copy_from_slice(&chunk);
350                            base += chunk_len;
351                        }
352                    }
353                    // mark as readable/writable
354                    unsafe {
355                        nix::sys::mman::mprotect(page_ptr, entry.page_size, mprotect_flag).expect("mprotect failed");
356                    }
357                    fallback = 0;
358                    break;
359                }
360            }
361            // otherwise this SIGSEGV falls through and we don't do anything about it
362            unistd::write(&to_handler, &[fallback]).unwrap();
363        }
364    });
365}
366
367impl<'a> PagedMemory<'a> {
368    /// Make the memory paged in userspace with raw base pointer and length. Note that the range of
369    /// addresses should be valid throughout the lifetime of [PagedMemory] and no access should be
370    /// made to the memory unless it is wrapped in [PagedMemory::run].
371    pub unsafe fn from_raw<S: PageStore + Send + 'static>(
372        base: *mut u8, length: usize, store: S, page_size: Option<usize>,
373    ) -> Result<PagedMemory<'static>, Error> {
374        let mem: &'static mut [u8] = unsafe { std::slice::from_raw_parts_mut(base, length) };
375        Self::new_(Some(mem.as_ptr() as *mut u8), mem.len(), store, page_size)
376    }
377
378    /// Create a slice of memory that is pagged.
379    pub fn new<S: PageStore + Send + 'static>(
380        length: usize, store: S, page_size: Option<usize>,
381    ) -> Result<PagedMemory<'static>, Error> {
382        Self::new_(None, length, store, page_size)
383    }
384
385    fn new_<'b, S: PageStore + Send + 'static>(
386        base: Option<*mut u8>, length: usize, store: S, page_size: Option<usize>,
387    ) -> Result<PagedMemory<'b>, Error> {
388        if !HANDLER_INITIALIZED.swap(true, Ordering::SeqCst) {
389            handler_init();
390        }
391        let page_size = match page_size {
392            Some(s) => s,
393            None => get_page_size()?,
394        };
395        let mem = std::sync::Arc::new(MappedMemory::new(base, length, page_size, ProtFlags::PROT_NONE)?);
396        let mut mgr = MANAGER.lock();
397        if !mgr.insert(PagedMemoryEntry {
398            start: mem.base() as usize,
399            len: length,
400            mem: mem.clone(),
401            store: Box::new(store),
402            page_size,
403        }) {
404            return Err(Error::MemoryOverlap);
405        }
406
407        Ok(PagedMemory {
408            mem,
409            page_size,
410            _phantom: std::marker::PhantomData,
411        })
412    }
413
414    /*
415    /// Run code that possibly accesses the paged memory. Because an access to the memory could
416    /// wait upon the [PageStore] to load the page in case of a page fault, to avoid dead-lock,
417    /// make sure the resources the store will acquire will not be held by this code before
418    /// accessing the paged memory. For example, using `println!("{}", mem[0])` could dead-lock the
419    /// system if the `read()` implementation of [PageStore] also uses `println`: the I/O is first
420    /// locked before the dereference of `mem[0]` which could possibly induces a page fault that
421    /// invokes `read()` to bring in the page content, then when the page store tries to invoke
422    /// `println`, it gets stuck (the dereference is imcomplete). As a good practice, always make
423    /// sure [PageStore] grabs the least resources that do not overlap with the code here.
424    pub fn run<F, T>(&mut self, f: F) -> std::thread::JoinHandle<T>
425    where
426        F: FnOnce(&mut [u8]) -> T + Send + 'static,
427        T: Send + 'static,
428    {
429        let mem = self.mem.clone();
430        std::thread::spawn(move || f(mem.as_slice()))
431    }
432    */
433
434    /// Because an access to the memory could
435    /// wait upon the [PageStore] to load the page in case of a page fault, to avoid dead-lock,
436    /// make sure the resources the store will acquire will not be held by this code before
437    /// accessing the paged memory. For example, using `println!("{}", mem.as_slice()[0])` could dead-lock the
438    /// system if the `read()` implementation of [PageStore] also uses `println`: the I/O is first
439    /// locked before the dereference of `mem[0]` which could possibly induces a page fault that
440    /// invokes `read()` to bring in the page content, then when the page store tries to invoke
441    /// `println`, it gets stuck (the dereference is imcomplete). As a good practice, always make
442    /// sure [PageStore] grabs the least resources that do not overlap with the code here.
443    pub fn as_slice_mut(&mut self) -> &mut [u8] {
444        self.mem.as_slice()
445    }
446
447    pub fn as_slice(&self) -> &[u8] {
448        self.mem.as_slice()
449    }
450
451    pub fn as_raw_parts(&self) -> (*mut u8, usize) {
452        let s = self.mem.as_slice();
453        (s.as_mut_ptr(), s.len())
454    }
455
456    /// Return the configured page size.
457    pub fn page_size(&self) -> usize {
458        self.page_size
459    }
460
461    /// Mark the entire PagedMemory to be read-only, this will trigger write-access page faults
462    /// again when write operation is made in the future.
463    pub fn mark_read_only(&self, offset: usize, length: usize) {
464        assert!(offset + length <= self.mem.length);
465        unsafe {
466            let ptr = NonNull::new_unchecked(self.mem.base().add(offset) as *mut c_void);
467            nix::sys::mman::mprotect(ptr, length, ProtFlags::PROT_READ).expect("mprotect failed");
468        }
469    }
470
471    /// Release the page content loaded from PageStore. The next access to an address within this
472    /// page will trigger a page fault. `page_offset` must be one of the offset passed in by page
473    /// fault handler.
474    pub fn release_page(&self, page_offset: usize) {
475        if page_offset & (self.page_size - 1) != 0 || page_offset >= self.mem.length {
476            panic!("invalid page offset: {:x}", page_offset);
477        }
478        let page_addr = self.mem.base() as usize + page_offset;
479        unsafe {
480            nix::sys::mman::mmap_anonymous(
481                Some(NonZeroUsize::new(page_addr).unwrap()),
482                NonZeroUsize::new(self.page_size).unwrap(),
483                ProtFlags::PROT_NONE,
484                MapFlags::MAP_FIXED | MapFlags::MAP_PRIVATE,
485            )
486            .expect("mmap failed");
487        }
488    }
489
490    pub fn release_all_pages(&self) {
491        unsafe {
492            nix::sys::mman::mmap_anonymous(
493                Some(NonZeroUsize::new(self.mem.base() as usize).unwrap()),
494                NonZeroUsize::new(self.mem.length).unwrap(),
495                ProtFlags::PROT_NONE,
496                MapFlags::MAP_FIXED | MapFlags::MAP_PRIVATE,
497            )
498            .expect("mmap failed");
499        }
500        self.mem.shared.lock().clear();
501    }
502
503    pub fn make_shared(&self, offset: usize, shm: &SharedMemory) -> Result<(), Error> {
504        self.mem.make_shared(offset, shm, ProtFlags::PROT_NONE)
505    }
506}
507
508impl<'a> Drop for PagedMemory<'a> {
509    fn drop(&mut self) {
510        let mut mgr = MANAGER.lock();
511        mgr.remove(self.mem.base() as usize, self.mem.length);
512    }
513}
514
/// A page store backed by an in-memory byte vector: byte `i` of the paged region is loaded from
/// `vec[i]`.
pub struct VecPageStore(Vec<u8>);

impl VecPageStore {
    /// Use `vec` as the backing content of a paged region.
    pub fn new(vec: Vec<u8>) -> Self {
        VecPageStore(vec)
    }
}
522
523impl PageStore for VecPageStore {
524    fn page_fault(
525        &mut self, offset: usize, length: usize, _access: AccessType,
526    ) -> Option<Box<dyn Iterator<Item = Box<dyn AsRef<[u8]> + '_>> + '_>> {
527        #[cfg(debug_assertions)]
528        println!(
529            "{:?} loading page at 0x{:x} access={:?}",
530            self as *mut Self, offset, _access,
531        );
532        Some(Box::new(std::iter::once(
533            Box::new(&self.0[offset..offset + length]) as Box<dyn AsRef<[u8]>>
534        )))
535    }
536}
537
/// A reference-counted shared-memory object that can be mapped into several paged regions.
/// Clones refer to the same underlying object.
#[derive(Clone)]
pub struct SharedMemory(Arc<SharedMemoryInner>);

/// Owned descriptor of a shared-memory object together with its size in bytes.
struct SharedMemoryInner {
    fd: std::os::fd::OwnedFd,
    size: usize,
}
545
546impl SharedMemory {
547    pub fn new(size: usize) -> Result<Self, Error> {
548        let fd = machdep::get_shared_memory()?;
549        nix::unistd::ftruncate(&fd, size as libc::off_t).map_err(Error::UnixError)?;
550        Ok(Self(Arc::new(SharedMemoryInner { fd, size })))
551    }
552}
553
554pub fn get_page_size() -> Result<usize, Error> {
555    Ok(unistd::sysconf(unistd::SysconfVar::PAGE_SIZE)
556        .map_err(Error::UnixError)?
557        .ok_or(Error::PageSizeNotAvail)? as usize)
558}
559
#[cfg(test)]
mod tests {
    use super::*;
    use parking_lot::Mutex;
    use signal::{SaFlags, SigAction, SigHandler, SigSet, Signal};
    use std::sync::LazyLock;

    /// System page size, queried once.
    static PAGE_SIZE: LazyLock<usize> =
        LazyLock::new(|| unistd::sysconf(unistd::SysconfVar::PAGE_SIZE).unwrap().unwrap() as usize);

    /// Serializes the tests: they all share the process-wide handler state (pipes, signal
    /// handlers and the global manager).
    static TEST_MUTEX: Mutex<()> = Mutex::new(());

    /// A zero-filled buffer spanning `pages` pages.
    fn zeroed_pages(pages: usize) -> Vec<u8> {
        vec![0; *PAGE_SIZE * pages]
    }

    /// A buffer spanning `pages` pages whose byte `i` holds `i as u8`.
    fn counting_pages(pages: usize) -> Vec<u8> {
        (0..*PAGE_SIZE * pages).map(|i| i as u8).collect()
    }

    #[test]
    fn test1() {
        let _guard = TEST_MUTEX.lock();
        for _ in 0..100 {
            let mut v = zeroed_pages(100);
            v[0] = 42;
            v[*PAGE_SIZE * 10 + 1] = 43;
            v[*PAGE_SIZE * 20 + 1] = 44;

            let pm = PagedMemory::new(*PAGE_SIZE * 100, VecPageStore::new(v), None).unwrap();
            let m = pm.as_slice();
            assert_eq!(m[0], 42);
            assert_eq!(m[*PAGE_SIZE * 10 + 1], 43);
            assert_eq!(m[*PAGE_SIZE * 20 + 1], 44);
        }
    }

    #[test]
    fn test2() {
        let _guard = TEST_MUTEX.lock();
        for _ in 0..100 {
            let mut v = zeroed_pages(100);
            v[0] = 1;
            v[*PAGE_SIZE * 10 + 1] = 2;
            v[*PAGE_SIZE * 20 + 1] = 3;
            let pm1 = PagedMemory::new(*PAGE_SIZE * 100, VecPageStore::new(v), None).unwrap();
            let mut pm2 = PagedMemory::new(*PAGE_SIZE * 100, VecPageStore::new(counting_pages(100)), None).unwrap();

            let m2 = pm2.as_slice_mut();
            let m1 = pm1.as_slice();

            assert_eq!(m2[100], 100);
            m2[100] = 0;
            assert_eq!(m2[100], 0);

            assert_eq!(m1[0], 1);
            assert_eq!(m1[*PAGE_SIZE * 10 + 1], 2);
            assert_eq!(m1[*PAGE_SIZE * 20 + 1], 3);
        }
    }

    #[test]
    fn test_shared_memory() {
        let _guard = TEST_MUTEX.lock();
        let mut v = zeroed_pages(100);
        v[0] = 42;
        v[*PAGE_SIZE * 10 + 1] = 43;
        v[*PAGE_SIZE * 20 + 1] = 44;

        let shm = SharedMemory::new(*PAGE_SIZE).unwrap();
        let mut pm1 = PagedMemory::new(*PAGE_SIZE * 100, VecPageStore::new(v.clone()), None).unwrap();
        let pm2 = PagedMemory::new(*PAGE_SIZE * 100, VecPageStore::new(v), None).unwrap();
        pm1.make_shared(*PAGE_SIZE * 10, &shm).unwrap();
        pm2.make_shared(*PAGE_SIZE * 10, &shm).unwrap();

        assert_eq!(pm1.as_slice()[*PAGE_SIZE * 10 + 1], 43);
        assert_eq!(pm2.as_slice()[*PAGE_SIZE * 10 + 1], 43);
        // A write through one region is visible through the other one.
        pm1.as_slice_mut()[*PAGE_SIZE * 10 + 1] = 99;
        assert_eq!(pm2.as_slice()[*PAGE_SIZE * 10 + 1], 99);
        assert_eq!(pm1.as_slice()[*PAGE_SIZE * 10 + 1], 99);

        let m = pm1.as_slice();
        assert_eq!(m[0], 42);
        assert_eq!(m[*PAGE_SIZE * 20 + 1], 44);
    }

    #[test]
    fn test_release_page() {
        let _guard = TEST_MUTEX.lock();
        let mut v = zeroed_pages(20);
        v[0] = 42;
        v[*PAGE_SIZE * 10 + 1] = 43;

        let pm = PagedMemory::new(*PAGE_SIZE * 100, VecPageStore::new(v), None).unwrap();
        let m = pm.as_slice();
        assert_eq!(m[0], 42);
        assert_eq!(m[*PAGE_SIZE * 10 + 1], 43);
        for _ in 0..5 {
            pm.release_page(0);
            pm.release_page(*PAGE_SIZE * 10);
            // Released pages are loaded from the store again on the next access.
            assert_eq!(m[0], 42);
            assert_eq!(m[*PAGE_SIZE * 10 + 1], 43);
        }
    }

    #[test]
    fn out_of_order_scan() {
        use rand::{SeedableRng, seq::SliceRandom};
        use rand_chacha::ChaChaRng;

        let _guard = TEST_MUTEX.lock();
        let pm = PagedMemory::new(*PAGE_SIZE * 100, VecPageStore::new(counting_pages(100)), None).unwrap();
        let m = pm.as_slice();

        let mut idxes: Vec<usize> = (0..m.len()).collect();
        idxes.shuffle(&mut ChaChaRng::from_seed([0; 32]));
        for i in idxes {
            // Read into a local first so that any fault is resolved before the assertion's
            // formatting machinery (which a store might also use) gets involved.
            let x = m[i];
            assert_eq!(x, i as u8, "m[0x{:08x}]", i);
        }
    }

    /// Tear down the global handler state so that the next [PagedMemory] constructor runs
    /// `handler_init` again. Closing the pipe ends is meant to make the current handler thread
    /// stop serving.
    unsafe fn handler_reset_init() {
        unsafe {
            let (to_read, to_write) = TO_HANDLER;
            let (from_read, from_write) = FROM_HANDLER;
            // (0, 1) are the "not created yet" placeholders: never close stdin/stdout.
            for (fd, placeholder) in [(to_read, 0), (to_write, 1), (from_read, 0), (from_write, 1)] {
                if fd != placeholder {
                    let _ = nix::unistd::close(fd);
                }
            }

            // Reset signal handlers to SIG_DFL so the next init sees default handlers.
            let sig_dfl = SigAction::new(SigHandler::SigDfl, SaFlags::empty(), SigSet::empty());
            let _ = signal::sigaction(Signal::SIGSEGV, &sig_dfl);
            let _ = signal::sigaction(Signal::SIGBUS, &sig_dfl);

            HANDLER_INITIALIZED.store(false, Ordering::SeqCst);
            FALLBACK_SIGSEGV_HANDLER = None;
            FALLBACK_SIGBUS_HANDLER = None;
            TO_HANDLER = (0, 1);
            FROM_HANDLER = (0, 1);
        }
    }

    static SIGSEGV_CALLED: AtomicBool = AtomicBool::new(false);
    static SIGBUS_CALLED: AtomicBool = AtomicBool::new(false);

    /// Make the page containing the faulting address readable and writable, so that the faulting
    /// instruction succeeds when it is restarted.
    fn unprotect_faulting_page(info: *mut libc::siginfo_t) {
        unsafe {
            let addr = (*info).si_addr() as usize;
            let page_size = nix::unistd::sysconf(nix::unistd::SysconfVar::PAGE_SIZE)
                .unwrap()
                .unwrap() as usize;
            let page_addr = addr & !(page_size - 1);
            nix::sys::mman::mprotect(
                NonNull::new_unchecked(page_addr as *mut c_void),
                page_size,
                ProtFlags::PROT_READ | ProtFlags::PROT_WRITE,
            )
            .expect("mprotect failed in handler");
        }
    }

    extern "C" fn test_sigsegv_handler(_signum: libc::c_int, info: *mut libc::siginfo_t, _ctx: *mut c_void) {
        SIGSEGV_CALLED.store(true, Ordering::SeqCst);
        unprotect_faulting_page(info);
    }

    extern "C" fn test_sigbus_handler(_signum: libc::c_int, info: *mut libc::siginfo_t, _ctx: *mut c_void) {
        SIGBUS_CALLED.store(true, Ordering::SeqCst);
        unprotect_faulting_page(info);
    }

    #[test]
    fn test_fallback_handlers_set_and_called() {
        let _guard = TEST_MUTEX.lock();

        unsafe {
            // Start from a clean slate so the next constructor runs `handler_init` again.
            handler_reset_init();

            // Pre-existing SA_SIGINFO handlers that `handler_init` must record as fallbacks.
            let flags = SaFlags::SA_SIGINFO | SaFlags::SA_NODEFER;
            let sigsegv_action = SigAction::new(SigHandler::SigAction(test_sigsegv_handler), flags, SigSet::empty());
            signal::sigaction(Signal::SIGSEGV, &sigsegv_action).expect("failed to set SIGSEGV handler");
            let sigbus_action = SigAction::new(SigHandler::SigAction(test_sigbus_handler), flags, SigSet::empty());
            signal::sigaction(Signal::SIGBUS, &sigbus_action).expect("failed to set SIGBUS handler");

            // The first PagedMemory triggers `handler_init`, which saves the handlers above.
            let _pm1 = PagedMemory::new(*PAGE_SIZE, VecPageStore::new(zeroed_pages(1)), None).unwrap();

            let saved_sigsegv = FALLBACK_SIGSEGV_HANDLER.map(|f| f as usize);
            let saved_sigbus = FALLBACK_SIGBUS_HANDLER.map(|f| f as usize);
            assert!(saved_sigsegv.is_some(), "SIGSEGV fallback handler should be saved");
            assert!(saved_sigbus.is_some(), "SIGBUS fallback handler should be saved");

            // A second PagedMemory must not run `handler_init` again (guarded by
            // HANDLER_INITIALIZED). Re-running it would record our own handler as the fallback.
            let _pm2 = PagedMemory::new(*PAGE_SIZE, VecPageStore::new(zeroed_pages(1)), None).unwrap();

            assert_eq!(
                FALLBACK_SIGSEGV_HANDLER.map(|f| f as usize),
                saved_sigsegv,
                "SIGSEGV fallback handler should not change"
            );
            assert_eq!(
                FALLBACK_SIGBUS_HANDLER.map(|f| f as usize),
                saved_sigbus,
                "SIGBUS fallback handler should not change"
            );

            // Touching unmanaged PROT_NONE memory must reach the fallback handler.
            SIGSEGV_CALLED.store(false, Ordering::SeqCst);
            SIGBUS_CALLED.store(false, Ordering::SeqCst);

            let test_mem = nix::sys::mman::mmap_anonymous(
                None,
                NonZeroUsize::new(*PAGE_SIZE).unwrap(),
                ProtFlags::PROT_NONE,
                MapFlags::MAP_PRIVATE | MapFlags::MAP_ANONYMOUS,
            )
            .expect("mmap failed");

            // Faults (SIGSEGV or SIGBUS depending on the platform); the fallback handler makes the
            // page accessible so the write succeeds when it is restarted.
            std::ptr::write_volatile(test_mem.cast::<u8>().as_ptr(), 42);

            assert!(
                SIGSEGV_CALLED.load(Ordering::SeqCst) || SIGBUS_CALLED.load(Ordering::SeqCst),
                "SIGSEGV or SIGBUS fallback handler should have been called"
            );

            nix::sys::mman::munmap(test_mem.cast(), *PAGE_SIZE).expect("munmap failed");

            // Leave a clean state for the other tests.
            handler_reset_init();
        }
    }
}