userspace_pagefault/
lib.rs

#![warn(static_mut_refs)]
pub use nix::libc;
use nix::libc::c_void;
use nix::sys::mman::MapFlags;
pub use nix::sys::mman::ProtFlags;
use nix::sys::signal;
use nix::unistd;
use parking_lot::Mutex;
use std::num::NonZeroUsize;
use std::os::fd::{AsRawFd, BorrowedFd, IntoRawFd, RawFd};
use std::ptr::NonNull;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicPtr, Ordering};

mod machdep;

const ADDR_SIZE: usize = std::mem::size_of::<usize>();

type SignalHandler = extern "C" fn(libc::c_int, *mut libc::siginfo_t, *mut c_void);

#[derive(Debug, PartialEq, Eq)]
pub enum AccessType {
    Read,
    Write,
}

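/// A backing store that supplies page contents on demand.
///
/// A minimal sketch of an implementation (illustrative only; `ZeroStore` is not part of this
/// crate) that serves zero-filled pages, assuming the crate is named `userspace_pagefault`:
///
/// ```no_run
/// use userspace_pagefault::{AccessType, PageStore};
///
/// struct ZeroStore {
///     page: Vec<u8>, // scratch buffer reused for every fault
/// }
///
/// impl PageStore for ZeroStore {
///     fn page_fault(
///         &mut self, _offset: usize, length: usize, _access: AccessType,
///     ) -> Option<Box<dyn Iterator<Item = Box<dyn AsRef<[u8]> + '_>> + '_>> {
///         // Hand back a single chunk of `length` zero bytes for the faulting page.
///         self.page.resize(length, 0);
///         Some(Box::new(std::iter::once(
///             Box::new(&self.page[..]) as Box<dyn AsRef<[u8]>>
///         )))
///     }
/// }
/// ```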
pub trait PageStore {
    /// Callback triggered upon a page fault. Return `Some(...)` with an iterator over chunks
    /// whose concatenation fills the page, or `None` if the page should not be populated.
    fn page_fault(
        &mut self, offset: usize, length: usize, access: AccessType,
    ) -> Option<Box<dyn Iterator<Item = Box<dyn AsRef<[u8]> + '_>> + '_>>;
}

pub struct MappedMemory {
    base: AtomicPtr<u8>,
    length: usize,
    unmap: bool,
    shared: Mutex<Vec<SharedMemory>>,
}

impl MappedMemory {
    pub fn new(base: Option<*mut u8>, mut length: usize, page_size: usize, flags: ProtFlags) -> Result<Self, Error> {
        let rem = length & (page_size - 1);
        match base {
            Some(base) => {
                if (base as usize) % page_size != 0 {
                    return Err(Error::BaseNotAligned);
                }
                if rem != 0 {
                    return Err(Error::LengthNotAligned);
                }
            }
            None => {
                if rem != 0 {
                    length += page_size - rem
                }
            }
        }

        let addr = match base {
            Some(b) => Some(NonZeroUsize::new(b as usize).ok_or(Error::NullBase)?),
            None => None,
        };
        let length_nz = NonZeroUsize::new(length).ok_or(Error::LengthIsZero)?;
        let map_flags = match base {
            Some(_) => MapFlags::MAP_FIXED,
            None => MapFlags::empty(),
        } | MapFlags::MAP_PRIVATE;

        let new_base = unsafe {
            nix::sys::mman::mmap_anonymous(addr, length_nz, flags, map_flags)
                .map_err(Error::UnixError)?
                .cast::<u8>()
        };
        let new_base_ptr = new_base.as_ptr();

        if let Some(base) = base {
            if base != new_base_ptr {
                return Err(Error::NotSupported);
            }
        }

        Ok(Self {
            base: AtomicPtr::new(new_base_ptr),
            length,
            unmap: base.is_none(),
            shared: Mutex::new(Vec::new()),
        })
    }

    #[inline(always)]
    pub fn base(&self) -> *mut u8 {
        self.base.load(Ordering::Relaxed)
    }

    #[inline(always)]
    pub fn as_slice(&self) -> &mut [u8] {
        unsafe { std::slice::from_raw_parts_mut(self.base(), self.length) }
    }

    pub fn make_shared(&self, offset: usize, shm: &SharedMemory, flags: ProtFlags) -> Result<(), Error> {
        let len = shm.0.size;
        if offset + len > self.length {
            return Err(Error::MemoryOverflow);
        }
        unsafe {
            nix::sys::mman::mmap(
                Some(NonZeroUsize::new(self.base().add(offset) as usize).unwrap()),
                NonZeroUsize::new(len).unwrap(),
                flags,
                MapFlags::MAP_FIXED | MapFlags::MAP_SHARED,
                &shm.0.fd,
                0,
            )
            .map_err(Error::UnixError)?;
        }
        // keep a reference to the shared memory so it is not deallocated
        self.shared.lock().push(shm.clone());
        Ok(())
    }
}

impl Drop for MappedMemory {
    fn drop(&mut self) {
        if self.unmap {
            unsafe {
                if let Some(ptr) = NonNull::new(self.base() as *mut c_void) {
                    nix::sys::mman::munmap(ptr, self.length).unwrap();
                }
            }
        }
    }
}

pub struct PagedMemory<'a> {
    mem: Arc<MappedMemory>,
    page_size: usize,
    _phantom: std::marker::PhantomData<&'a ()>,
}

struct PagedMemoryEntry {
    start: usize,
    len: usize,
    mem: Arc<MappedMemory>,
    store: Box<dyn PageStore + Send + 'static>,
    page_size: usize,
}

#[derive(Debug, PartialEq, Eq)]
pub enum Error {
    BaseNotAligned,
    NullBase,
    LengthNotAligned,
    LengthIsZero,
    PageSizeNotAvail,
    NotSupported,
    UnixError(nix::errno::Errno),
    MemoryOverlap,
    MemoryOverflow,
}

static HANDLER_SPIN: AtomicBool = AtomicBool::new(false);
static HANDLER_INITIALIZED: AtomicBool = AtomicBool::new(false);
static HANDLER_THREAD: Mutex<Option<std::thread::JoinHandle<()>>> = Mutex::new(None);
static mut TO_HANDLER: (RawFd, RawFd) = (0, 1);
static mut FROM_HANDLER: (RawFd, RawFd) = (0, 1);
static mut FALLBACK_SIGSEGV_HANDLER: Option<SignalHandler> = None;
static mut FALLBACK_SIGBUS_HANDLER: Option<SignalHandler> = None;

#[inline]
fn handle_page_fault_(info: *mut libc::siginfo_t, ctx: *mut c_void) -> bool {
    // NOTE: this function must be SIGBUS/SIGSEGV-free: a fault raised while we are already
    // handling the signal cannot be handled.
    let (tx, rx, addr, ctx) = unsafe {
        let (rx, _) = TO_HANDLER;
        let (_, tx) = FROM_HANDLER;
        (tx, rx, (*info).si_addr() as usize, &mut *(ctx as *mut libc::ucontext_t))
    };
    let flag = machdep::check_page_fault_rw_flag_from_context(*ctx);
    let mut buff = [0; ADDR_SIZE + 1];
    buff[..ADDR_SIZE].copy_from_slice(&addr.to_le_bytes());
    buff[ADDR_SIZE] = flag;
    // use a spin lock to avoid ABA (another faulting thread could otherwise interleave its
    // request between our write and read)
    while HANDLER_SPIN.swap(true, Ordering::Acquire) {
        std::thread::yield_now();
    }
    if unistd::write(unsafe { BorrowedFd::borrow_raw(tx) }, &buff).is_err() {
        HANDLER_SPIN.swap(false, Ordering::Release);
        return true;
    }
    let _ = unistd::read(unsafe { BorrowedFd::borrow_raw(rx) }, &mut buff[..1]);
    HANDLER_SPIN.swap(false, Ordering::Release);
    buff[0] == 1
}

// The fallback signal handling was inspired by the wasmtime trap handlers:
// https://github.com/bytecodealliance/wasmtime/blob/v22.0.0/crates/wasmtime/src/runtime/vm/sys/unix/signals.rs
extern "C" fn handle_page_fault(signum: libc::c_int, info: *mut libc::siginfo_t, ctx: *mut c_void) {
    if !handle_page_fault_(info, ctx) {
        return;
    }
    // The fault did not hit a managed memory region; fall back to the previous handler.

    unsafe {
        let sig = signal::Signal::try_from(signum).expect("invalid signum");
        let fallback_handler = match sig {
            signal::SIGSEGV => FALLBACK_SIGSEGV_HANDLER,
            signal::SIGBUS => FALLBACK_SIGBUS_HANDLER,
            _ => panic!("unknown signal: {}", sig),
        };

        if let Some(handler) = fallback_handler {
            // Delegate to the fallback handler
            handler(signum, info, ctx);
        } else {
            // No fallback handler (was SIG_DFL or SIG_IGN), reset to default and raise
            let sig_action = signal::SigAction::new(
                signal::SigHandler::SigDfl,
                signal::SaFlags::empty(),
                signal::SigSet::empty(),
            );
            signal::sigaction(sig, &sig_action).expect("fail to reset signal handler");
            signal::raise(sig).expect("fail to raise signal");
            unreachable!("SIG_DFL should have terminated the process");
        }
    }
}

unsafe fn register_signal_handlers(handler: SignalHandler) {
    let register = |fallback_handler: *mut Option<SignalHandler>, sig: signal::Signal| {
        // The flags here are chosen relatively carefully:
        //
        // SA_SIGINFO gives us access to information like the program
        // counter from where the fault happened.
        //
        // SA_ONSTACK allows us to handle signals on an alternate stack,
        // so that the handler can run in response to running out of
        // stack space on the main stack. Rust installs an alternate
        // stack with sigaltstack, so we rely on that.
        //
        // SA_NODEFER allows us to reenter the signal handler if we
        // fault while handling the signal, so the fallback handler
        // still gets a chance to run.
        let sig_action = signal::SigAction::new(
            signal::SigHandler::SigAction(handler),
            signal::SaFlags::SA_NODEFER | signal::SaFlags::SA_SIGINFO | signal::SaFlags::SA_ONSTACK,
            signal::SigSet::empty(),
        );

        // Save the previous handler as the fallback if it is a SigAction registered with SA_SIGINFO
        unsafe {
            let old = signal::sigaction(sig, &sig_action).expect("fail to register signal handler");
            *fallback_handler = match old.handler() {
                signal::SigHandler::SigAction(h)
                    if old.flags() & signal::SaFlags::SA_SIGINFO == signal::SaFlags::SA_SIGINFO =>
                {
                    Some(h)
                }
                _ => None,
            };
        }
    };

    register(&raw mut FALLBACK_SIGSEGV_HANDLER, signal::SIGSEGV);
    register(&raw mut FALLBACK_SIGBUS_HANDLER, signal::SIGBUS);
}

struct PagedMemoryManager {
    entries: Vec<PagedMemoryEntry>,
}

impl PagedMemoryManager {
    fn insert(&mut self, entry: PagedMemoryEntry) -> bool {
        for (i, PagedMemoryEntry { start, len, .. }) in self.entries.iter().enumerate() {
            if entry.start + entry.len <= *start {
                // insert before this entry
                self.entries.insert(i, entry);
                return true;
            }
            if entry.start < *start + *len {
                // overlapping space
                return false;
            }
        }
        self.entries.push(entry);
        true
    }

    fn remove(&mut self, start_: usize, len_: usize) {
        for (i, PagedMemoryEntry { start, len, .. }) in self.entries.iter().enumerate() {
            if *start == start_ && *len == len_ {
                self.entries.remove(i);
                return;
            }
        }
        panic!(
            "failed to locate PagedMemoryEntry (start = 0x{:x}, end = 0x{:x})",
            start_,
            start_ + len_
        )
    }
}

static MANAGER: Mutex<PagedMemoryManager> = Mutex::new(PagedMemoryManager { entries: Vec::new() });

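// How the signal handler talks to the handler thread (a summary of the code below, for
// orientation): `FROM_HANDLER` carries requests from the faulting thread's signal handler to the
// handler thread, and `TO_HANDLER` carries the one-byte reply back (each tuple is
// (read end, write end)). A request is ADDR_SIZE bytes of the faulting address in little-endian
// order followed by one access-flag byte (0 = read, non-zero = write); the reply byte is 0 if the
// fault was resolved and 1 if it should fall through to the previous handler.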
fn handler_init() {
    let (to_read, to_write) = nix::unistd::pipe().expect("fail to create pipe to the handler");
    let (from_read, from_write) = nix::unistd::pipe().expect("fail to create pipe from the handler");
    let from_handler = unsafe { BorrowedFd::borrow_raw(from_read.as_raw_fd()) };
    let to_handler = unsafe { BorrowedFd::borrow_raw(to_write.as_raw_fd()) };
    unsafe {
        TO_HANDLER = (to_read.into_raw_fd(), to_write.into_raw_fd());
        FROM_HANDLER = (from_read.into_raw_fd(), from_write.into_raw_fd());
        register_signal_handlers(handle_page_fault);
    }
    std::sync::atomic::fence(Ordering::SeqCst);
    let handle = std::thread::spawn(move || {
        let mut buff = [0; ADDR_SIZE + 1];
        loop {
            if unistd::read(&from_handler, &mut buff).is_err() {
                // Pipe closed, exit thread
                break;
            }
            let addr = usize::from_le_bytes(buff[..ADDR_SIZE].try_into().unwrap());
            let (access_type, mprotect_flag) = match buff[ADDR_SIZE] {
                0 => (AccessType::Read, ProtFlags::PROT_READ),
                _ => (AccessType::Write, ProtFlags::PROT_READ | ProtFlags::PROT_WRITE),
            };
            let mut mgr = MANAGER.lock();
            let mut fallback = 1;
            for entry in mgr.entries.iter_mut() {
                if entry.start <= addr && addr < entry.start + entry.len {
                    let page_mask = usize::MAX ^ (entry.page_size - 1);
                    let page_addr = addr & page_mask;
                    let page_ptr = unsafe { NonNull::new_unchecked(page_addr as *mut c_void) };
                    // load the page data
                    let slice = entry.mem.as_slice();
                    let base = slice.as_ptr() as usize;
                    let page_offset = page_addr - base;
                    if let Some(page) = entry.store.page_fault(page_offset, entry.page_size, access_type) {
                        unsafe {
                            nix::sys::mman::mprotect(
                                page_ptr,
                                entry.page_size,
                                ProtFlags::PROT_READ | ProtFlags::PROT_WRITE,
                            )
                            .expect("mprotect failed");
                        }
                        let target = &mut slice[page_offset..page_offset + entry.page_size];
                        let mut copied = 0;
                        for chunk in page {
                            let chunk = (*chunk).as_ref();
                            let chunk_len = chunk.len();
                            target[copied..copied + chunk_len].copy_from_slice(chunk);
                            copied += chunk_len;
                        }
                    }
                    // mark as readable/writable
                    unsafe {
                        nix::sys::mman::mprotect(page_ptr, entry.page_size, mprotect_flag).expect("mprotect failed");
                    }
                    fallback = 0;
                    break;
                }
            }
            // if no entry matched, `fallback` stays 1 and the signal handler delegates the
            // fault to the previous handler
            if unistd::write(&to_handler, &[fallback]).is_err() {
                // Pipe closed, exit thread
                break;
            }
        }
    });
    *HANDLER_THREAD.lock() = Some(handle);
}

impl<'a> PagedMemory<'a> {
    /// Make the memory paged in userspace from a raw base pointer and length. Note that the range
    /// of addresses must stay valid throughout the lifetime of [PagedMemory], and the memory
    /// should only be accessed through the accessors of [PagedMemory].
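    ///
    /// A sketch of the intended usage (illustrative; the reservation step is elided and the crate
    /// name `userspace_pagefault` is assumed):
    ///
    /// ```ignore
    /// use userspace_pagefault::{PagedMemory, VecPageStore, get_page_size};
    ///
    /// let page_size = get_page_size().unwrap();
    /// let length = page_size * 4;
    /// // `base` must be a page-aligned address range that nothing else uses,
    /// // e.g. obtained from a prior mmap reservation.
    /// let base: *mut u8 = obtain_reserved_range(length); // hypothetical helper
    /// let pm = unsafe {
    ///     PagedMemory::from_raw(base, length, VecPageStore::new(vec![0; length]), None)
    /// }.unwrap();
    /// ```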
    pub unsafe fn from_raw<S: PageStore + Send + 'static>(
        base: *mut u8, length: usize, store: S, page_size: Option<usize>,
    ) -> Result<PagedMemory<'static>, Error> {
        let mem: &'static mut [u8] = unsafe { std::slice::from_raw_parts_mut(base, length) };
        Self::new_(Some(mem.as_ptr() as *mut u8), mem.len(), store, page_size)
    }

    /// Create a slice of memory that is paged in userspace.
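    ///
    /// A minimal usage sketch (mirroring the crate's tests; the crate name
    /// `userspace_pagefault` is assumed):
    ///
    /// ```no_run
    /// use userspace_pagefault::{PagedMemory, VecPageStore, get_page_size};
    ///
    /// let page_size = get_page_size().unwrap();
    /// // Back four pages of memory with a Vec; pages are loaded lazily on first access.
    /// let backing = vec![7u8; page_size * 4];
    /// let pm = PagedMemory::new(page_size * 4, VecPageStore::new(backing), None).unwrap();
    /// assert_eq!(pm.as_slice()[page_size + 1], 7); // triggers a page fault, then reads 7
    /// ```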
    pub fn new<S: PageStore + Send + 'static>(
        length: usize, store: S, page_size: Option<usize>,
    ) -> Result<PagedMemory<'static>, Error> {
        Self::new_(None, length, store, page_size)
    }

    fn new_<'b, S: PageStore + Send + 'static>(
        base: Option<*mut u8>, length: usize, store: S, page_size: Option<usize>,
    ) -> Result<PagedMemory<'b>, Error> {
        if !HANDLER_INITIALIZED.swap(true, Ordering::SeqCst) {
            handler_init();
        }
        let page_size = match page_size {
            Some(s) => s,
            None => get_page_size()?,
        };
        let mem = std::sync::Arc::new(MappedMemory::new(base, length, page_size, ProtFlags::PROT_NONE)?);
        let mut mgr = MANAGER.lock();
        if !mgr.insert(PagedMemoryEntry {
            start: mem.base() as usize,
            len: length,
            mem: mem.clone(),
            store: Box::new(store),
            page_size,
        }) {
            return Err(Error::MemoryOverlap);
        }

        Ok(PagedMemory {
            mem,
            page_size,
            _phantom: std::marker::PhantomData,
        })
    }

    /*
    /// Run code that possibly accesses the paged memory. Because an access to the memory could
    /// wait upon the [PageStore] to load the page in case of a page fault, to avoid deadlock,
    /// make sure the resources the store will acquire are not already held by this code before
    /// accessing the paged memory. For example, using `println!("{}", mem[0])` could deadlock the
    /// system if the `read()` implementation of [PageStore] also uses `println`: the I/O is first
    /// locked before the dereference of `mem[0]`, which could induce a page fault that
    /// invokes `read()` to bring in the page content; then when the page store tries to invoke
    /// `println`, it gets stuck (the dereference is incomplete). As a good practice, always make
    /// sure [PageStore] grabs the fewest resources possible and none that overlap with the code here.
    pub fn run<F, T>(&mut self, f: F) -> std::thread::JoinHandle<T>
    where
        F: FnOnce(&mut [u8]) -> T + Send + 'static,
        T: Send + 'static,
    {
        let mem = self.mem.clone();
        std::thread::spawn(move || f(mem.as_slice()))
    }
    */

    /// Because an access to the memory could wait upon the [PageStore] to load the page in case
    /// of a page fault, to avoid deadlock, make sure the resources the store will acquire are not
    /// already held by this code before accessing the paged memory. For example, using
    /// `println!("{}", mem.as_slice()[0])` could deadlock the system if the `read()`
    /// implementation of [PageStore] also uses `println`: the I/O is first locked before the
    /// dereference of `mem[0]`, which could induce a page fault that invokes `read()` to bring in
    /// the page content; then when the page store tries to invoke `println`, it gets stuck (the
    /// dereference is incomplete). As a good practice, always make sure [PageStore] grabs the
    /// fewest resources possible and none that overlap with the code here.
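    ///
    /// One way to honor this (an illustrative sketch) is to finish the fault-inducing access
    /// before taking any lock the store might also need:
    ///
    /// ```ignore
    /// let x = pm.as_slice_mut()[0]; // page fault (if any) is fully resolved here
    /// println!("{}", x);            // stdout lock is taken only afterwards
    /// ```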
    pub fn as_slice_mut(&mut self) -> &mut [u8] {
        self.mem.as_slice()
    }

    pub fn as_slice(&self) -> &[u8] {
        self.mem.as_slice()
    }

    pub fn as_raw_parts(&self) -> (*mut u8, usize) {
        let s = self.mem.as_slice();
        (s.as_mut_ptr(), s.len())
    }

    /// Return the configured page size.
    pub fn page_size(&self) -> usize {
        self.page_size
    }

    /// Mark a range of the PagedMemory as read-only; this will trigger write-access page faults
    /// again when a write is made to the range in the future.
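    ///
    /// For example (illustrative), to start tracking writes to the first page again:
    ///
    /// ```ignore
    /// pm.mark_read_only(0, pm.page_size()); // the next write to page 0 faults again
    /// ```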
    pub fn mark_read_only(&self, offset: usize, length: usize) {
        assert!(offset + length <= self.mem.length);
        unsafe {
            let ptr = NonNull::new_unchecked(self.mem.base().add(offset) as *mut c_void);
            nix::sys::mman::mprotect(ptr, length, ProtFlags::PROT_READ).expect("mprotect failed");
        }
    }

    /// Release the page content loaded from the PageStore. The next access to an address within
    /// this page will trigger a page fault. `page_offset` must be a page-aligned offset, such as
    /// one passed to the page-fault handler.
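    ///
    /// A sketch of the intended round trip (illustrative):
    ///
    /// ```ignore
    /// let first = pm.as_slice()[0]; // faults the first page in
    /// pm.release_page(0);           // drops the loaded page content
    /// let again = pm.as_slice()[0]; // faults again, reloading from the store
    /// assert_eq!(first, again);
    /// ```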
    pub fn release_page(&self, page_offset: usize) {
        if page_offset & (self.page_size - 1) != 0 || page_offset >= self.mem.length {
            panic!("invalid page offset: {:x}", page_offset);
        }
        let page_addr = self.mem.base() as usize + page_offset;
        unsafe {
            nix::sys::mman::mmap_anonymous(
                Some(NonZeroUsize::new(page_addr).unwrap()),
                NonZeroUsize::new(self.page_size).unwrap(),
                ProtFlags::PROT_NONE,
                MapFlags::MAP_FIXED | MapFlags::MAP_PRIVATE,
            )
            .expect("mmap failed");
        }
    }

    pub fn release_all_pages(&self) {
        unsafe {
            nix::sys::mman::mmap_anonymous(
                Some(NonZeroUsize::new(self.mem.base() as usize).unwrap()),
                NonZeroUsize::new(self.mem.length).unwrap(),
                ProtFlags::PROT_NONE,
                MapFlags::MAP_FIXED | MapFlags::MAP_PRIVATE,
            )
            .expect("mmap failed");
        }
        self.mem.shared.lock().clear();
    }
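
    /// Map `shm` into this memory at `offset` (with MAP_SHARED), so every [PagedMemory] sharing
    /// `shm` sees the same pages. A sketch based on the crate's own tests (illustrative):
    ///
    /// ```ignore
    /// let shm = SharedMemory::new(page_size).unwrap();
    /// pm1.make_shared(page_size * 10, &shm).unwrap();
    /// pm2.make_shared(page_size * 10, &shm).unwrap();
    /// pm1.as_slice_mut()[page_size * 10 + 1] = 99;        // a write through one mapping...
    /// assert_eq!(pm2.as_slice()[page_size * 10 + 1], 99); // ...is visible through the other
    /// ```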
    pub fn make_shared(&self, offset: usize, shm: &SharedMemory) -> Result<(), Error> {
        self.mem.make_shared(offset, shm, ProtFlags::PROT_NONE)
    }
}

impl<'a> Drop for PagedMemory<'a> {
    fn drop(&mut self) {
        let mut mgr = MANAGER.lock();
        mgr.remove(self.mem.base() as usize, self.mem.length);
    }
}

pub struct VecPageStore(Vec<u8>);

impl VecPageStore {
    pub fn new(vec: Vec<u8>) -> Self {
        Self(vec)
    }
}

impl PageStore for VecPageStore {
    fn page_fault(
        &mut self, offset: usize, length: usize, _access: AccessType,
    ) -> Option<Box<dyn Iterator<Item = Box<dyn AsRef<[u8]> + '_>> + '_>> {
        #[cfg(debug_assertions)]
        println!(
            "{:?} loading page at 0x{:x} access={:?}",
            self as *mut Self, offset, _access,
        );
        Some(Box::new(std::iter::once(
            Box::new(&self.0[offset..offset + length]) as Box<dyn AsRef<[u8]>>
        )))
    }
}

#[derive(Clone)]
pub struct SharedMemory(Arc<SharedMemoryInner>);

struct SharedMemoryInner {
    fd: std::os::fd::OwnedFd,
    size: usize,
}

impl SharedMemory {
    pub fn new(size: usize) -> Result<Self, Error> {
        let fd = machdep::get_shared_memory()?;
        nix::unistd::ftruncate(&fd, size as libc::off_t).map_err(Error::UnixError)?;
        Ok(Self(Arc::new(SharedMemoryInner { fd, size })))
    }
}

pub fn get_page_size() -> Result<usize, Error> {
    Ok(unistd::sysconf(unistd::SysconfVar::PAGE_SIZE)
        .map_err(Error::UnixError)?
        .ok_or(Error::PageSizeNotAvail)? as usize)
}

#[cfg(test)]
mod tests {
    use super::*;
    use lazy_static::lazy_static;
    use parking_lot::Mutex;

    lazy_static! {
        static ref PAGE_SIZE: usize = unistd::sysconf(unistd::SysconfVar::PAGE_SIZE).unwrap().unwrap() as usize;
    }

    static TEST_MUTEX: Mutex<()> = Mutex::new(());

    #[test]
    fn test1() {
        let _guard = TEST_MUTEX.lock();
        for _ in 0..100 {
            let mut v = Vec::new();
            v.resize(*PAGE_SIZE * 100, 0);
            v[0] = 42;
            v[*PAGE_SIZE * 10 + 1] = 43;
            v[*PAGE_SIZE * 20 + 1] = 44;

            let pm = PagedMemory::new(*PAGE_SIZE * 100, VecPageStore::new(v), None).unwrap();
            let m = pm.as_slice();
            assert_eq!(m[0], 42);
            assert_eq!(m[*PAGE_SIZE * 10 + 1], 43);
            assert_eq!(m[*PAGE_SIZE * 20 + 1], 44);
        }
    }

    #[test]
    fn test2() {
        let _guard = TEST_MUTEX.lock();
        for _ in 0..100 {
            let mut v = Vec::new();
            v.resize(*PAGE_SIZE * 100, 0);
            v[0] = 1;
            v[*PAGE_SIZE * 10 + 1] = 2;
            v[*PAGE_SIZE * 20 + 1] = 3;

            let pm1 = PagedMemory::new(*PAGE_SIZE * 100, VecPageStore::new(v), None).unwrap();

            let mut v = Vec::new();
            v.resize(*PAGE_SIZE * 100, 0);
            for (i, v) in v.iter_mut().enumerate() {
                *v = i as u8;
            }
            let mut pm2 = PagedMemory::new(*PAGE_SIZE * 100, VecPageStore::new(v), None).unwrap();

            let m2 = pm2.as_slice_mut();
            let m1 = pm1.as_slice();

            assert_eq!(m2[100], 100);
            m2[100] = 0;
            assert_eq!(m2[100], 0);

            assert_eq!(m1[0], 1);
            assert_eq!(m1[*PAGE_SIZE * 10 + 1], 2);
            assert_eq!(m1[*PAGE_SIZE * 20 + 1], 3);
        }
    }

    #[test]
    fn test_shared_memory() {
        let _guard = TEST_MUTEX.lock();
        let mut v = Vec::new();
        v.resize(*PAGE_SIZE * 100, 0);
        v[0] = 42;
        v[*PAGE_SIZE * 10 + 1] = 43;
        v[*PAGE_SIZE * 20 + 1] = 44;

        let shm = SharedMemory::new(*PAGE_SIZE).unwrap();
        let mut pm1 = PagedMemory::new(*PAGE_SIZE * 100, VecPageStore::new(v.clone()), None).unwrap();
        let pm2 = PagedMemory::new(*PAGE_SIZE * 100, VecPageStore::new(v), None).unwrap();
        pm1.make_shared(*PAGE_SIZE * 10, &shm).unwrap();
        pm2.make_shared(*PAGE_SIZE * 10, &shm).unwrap();

        assert_eq!(pm1.as_slice()[*PAGE_SIZE * 10 + 1], 43);
        assert_eq!(pm2.as_slice()[*PAGE_SIZE * 10 + 1], 43);
        pm1.as_slice_mut()[*PAGE_SIZE * 10 + 1] = 99;
        assert_eq!(pm2.as_slice()[*PAGE_SIZE * 10 + 1], 99);
        assert_eq!(pm1.as_slice()[*PAGE_SIZE * 10 + 1], 99);

        let m = pm1.as_slice();
        assert_eq!(m[0], 42);
        assert_eq!(m[*PAGE_SIZE * 20 + 1], 44);
    }

    #[test]
    fn test_release_page() {
        let _guard = TEST_MUTEX.lock();
        let mut v = Vec::new();
        v.resize(*PAGE_SIZE * 20, 0);
        v[0] = 42;
        v[*PAGE_SIZE * 10 + 1] = 43;

        let pm = PagedMemory::new(*PAGE_SIZE * 100, VecPageStore::new(v), None).unwrap();
        let m = pm.as_slice();
        assert_eq!(m[0], 42);
        assert_eq!(m[*PAGE_SIZE * 10 + 1], 43);
        for _ in 0..5 {
            pm.release_page(0);
            pm.release_page(*PAGE_SIZE * 10);
            assert_eq!(m[0], 42);
            assert_eq!(m[*PAGE_SIZE * 10 + 1], 43);
        }
    }

    #[test]
    fn out_of_order_scan() {
        let _guard = TEST_MUTEX.lock();
        let mut v = Vec::new();
        v.resize(*PAGE_SIZE * 100, 0);
        for (i, v) in v.iter_mut().enumerate() {
            *v = i as u8;
        }
        let store = VecPageStore::new(v);
        let pm = PagedMemory::new(*PAGE_SIZE * 100, store, None).unwrap();
        use rand::{SeedableRng, seq::SliceRandom};
        use rand_chacha::ChaChaRng;
        let seed = [0; 32];
        let mut rng = ChaChaRng::from_seed(seed);

        let m = pm.as_slice();
        let mut idxes: Vec<usize> = (0..m.len()).collect();
        idxes.shuffle(&mut rng);
        for i in idxes.into_iter() {
            #[cfg(debug_assertions)]
            {
                let x = m[i];
                println!("m[0x{:08x}] = {}", i, x);
            }
            assert_eq!(m[i], i as u8);
        }
    }

    use signal::{SaFlags, SigAction, SigHandler, SigSet, Signal};

    /// Reset the handler state for testing
    unsafe fn handler_reset_init() {
        unsafe {
            // Close pipe file descriptors to cause the handler thread to exit
            let (to_read, to_write) = TO_HANDLER;
            let (from_read, from_write) = FROM_HANDLER;

            if to_read != 0 {
                let _ = nix::unistd::close(to_read);
            }
            if to_write != 1 {
                let _ = nix::unistd::close(to_write);
            }
            if from_read != 0 {
                let _ = nix::unistd::close(from_read);
            }
            if from_write != 1 {
                let _ = nix::unistd::close(from_write);
            }

            // Wait for the handler thread to exit
            if let Some(handle) = HANDLER_THREAD.lock().take() {
                let _ = handle.join();
            }

            // Reset signal handlers to SIG_DFL so the next init sees default handlers
            let sig_dfl = SigAction::new(SigHandler::SigDfl, SaFlags::empty(), SigSet::empty());
            let _ = signal::sigaction(Signal::SIGSEGV, &sig_dfl);
            let _ = signal::sigaction(Signal::SIGBUS, &sig_dfl);

            // Reset the init flag so handler_init can run again
            HANDLER_INITIALIZED.store(false, Ordering::SeqCst);

            // Clear fallback handlers
            FALLBACK_SIGSEGV_HANDLER = None;
            FALLBACK_SIGBUS_HANDLER = None;

            // Reset pipe fds to their initial values
            TO_HANDLER = (0, 1);
            FROM_HANDLER = (0, 1);
        }
    }

    static SIGSEGV_CALLED: AtomicBool = AtomicBool::new(false);
    static SIGBUS_CALLED: AtomicBool = AtomicBool::new(false);

    extern "C" fn test_sigsegv_handler(_signum: libc::c_int, info: *mut libc::siginfo_t, _ctx: *mut c_void) {
        SIGSEGV_CALLED.store(true, Ordering::SeqCst);
        // Make the memory accessible so the instruction can succeed on retry
        unsafe {
            let addr = (*info).si_addr();
            let page_size = nix::unistd::sysconf(nix::unistd::SysconfVar::PAGE_SIZE)
                .unwrap()
                .unwrap() as usize;
            let page_addr = (addr as usize) & !(page_size - 1);
            nix::sys::mman::mprotect(
                NonNull::new_unchecked(page_addr as *mut c_void),
                page_size,
                ProtFlags::PROT_READ | ProtFlags::PROT_WRITE,
            )
            .expect("mprotect failed in handler");
        }
    }

    extern "C" fn test_sigbus_handler(_signum: libc::c_int, info: *mut libc::siginfo_t, _ctx: *mut c_void) {
        SIGBUS_CALLED.store(true, Ordering::SeqCst);
        unsafe {
            let addr = (*info).si_addr();
            let page_size = nix::unistd::sysconf(nix::unistd::SysconfVar::PAGE_SIZE)
                .unwrap()
                .unwrap() as usize;
            let page_addr = (addr as usize) & !(page_size - 1);
            nix::sys::mman::mprotect(
                NonNull::new_unchecked(page_addr as *mut c_void),
                page_size,
                ProtFlags::PROT_READ | ProtFlags::PROT_WRITE,
            )
            .expect("mprotect failed in handler");
        }
    }

    #[test]
    fn test_fallback_handlers_set_and_called() {
        let _guard = TEST_MUTEX.lock();

        unsafe {
            // Reset handler state before test
            handler_reset_init();

            // Register the fallback SIGSEGV handler
            let sigsegv_action = SigAction::new(
                SigHandler::SigAction(test_sigsegv_handler),
                SaFlags::SA_SIGINFO | SaFlags::SA_NODEFER,
                SigSet::empty(),
            );
            signal::sigaction(Signal::SIGSEGV, &sigsegv_action).expect("failed to set SIGSEGV handler");

            // Register the fallback SIGBUS handler
            let sigbus_action = SigAction::new(
                SigHandler::SigAction(test_sigbus_handler),
                SaFlags::SA_SIGINFO | SaFlags::SA_NODEFER,
                SigSet::empty(),
            );
            signal::sigaction(Signal::SIGBUS, &sigbus_action).expect("failed to set SIGBUS handler");

            // Create a PagedMemory - this triggers handler_init(), which saves the current
            // handlers as fallbacks.
            let _pm1 = PagedMemory::new(*PAGE_SIZE, VecPageStore::new(vec![0u8; *PAGE_SIZE]), None).unwrap();

            // Save the handler pointers to verify they don't change
            let saved_sigsegv = FALLBACK_SIGSEGV_HANDLER.map(|f| f as usize);
            let saved_sigbus = FALLBACK_SIGBUS_HANDLER.map(|f| f as usize);

            // Verify that the fallback handlers were saved
            assert!(saved_sigsegv.is_some(), "SIGSEGV fallback handler should be saved");
            assert!(saved_sigbus.is_some(), "SIGBUS fallback handler should be saved");

            // Create another PagedMemory - handler_init() should NOT run again due to the
            // HANDLER_INITIALIZED guard
            let _pm2 = PagedMemory::new(*PAGE_SIZE, VecPageStore::new(vec![0u8; *PAGE_SIZE]), None).unwrap();

            // Verify the handler pointers haven't changed (the guard prevented re-registration)
            let current_sigsegv = FALLBACK_SIGSEGV_HANDLER.map(|f| f as usize);
            let current_sigbus = FALLBACK_SIGBUS_HANDLER.map(|f| f as usize);
            assert_eq!(
                current_sigsegv, saved_sigsegv,
                "SIGSEGV fallback handler should not change"
            );
            assert_eq!(
                current_sigbus, saved_sigbus,
                "SIGBUS fallback handler should not change"
            );

            // Test SIGSEGV/SIGBUS fallback by accessing memory protected with PROT_NONE
            SIGSEGV_CALLED.store(false, Ordering::SeqCst);
            SIGBUS_CALLED.store(false, Ordering::SeqCst);

            // Allocate memory with PROT_NONE to trigger SIGSEGV or SIGBUS when accessed
            let test_mem = nix::sys::mman::mmap_anonymous(
                None,
                NonZeroUsize::new(*PAGE_SIZE).unwrap(),
                ProtFlags::PROT_NONE,
                MapFlags::MAP_PRIVATE,
            )
            .expect("mmap failed");

            // Access the protected memory - this triggers SIGSEGV or SIGBUS, and the handler
            // makes it accessible
            std::ptr::write_volatile(test_mem.cast::<u8>().as_ptr(), 42);

            // Verify at least one fallback handler was called (platform dependent)
            assert!(
                SIGSEGV_CALLED.load(Ordering::SeqCst) || SIGBUS_CALLED.load(Ordering::SeqCst),
                "SIGSEGV or SIGBUS fallback handler should have been called"
            );

            // Clean up
            nix::sys::mman::munmap(test_mem.cast(), *PAGE_SIZE).expect("munmap failed");

            // Reset handler state after test
            handler_reset_init();
        }
    }
}