userspace_paging/
lib.rs

pub use nix::libc;
use nix::libc::c_void;
use nix::sys::mman::{MapFlags, ProtFlags};
use nix::sys::signal;
use nix::unistd;
use parking_lot::Mutex;
use std::num::NonZeroUsize;
use std::os::fd::RawFd;
use std::sync::atomic::{AtomicBool, AtomicPtr, Ordering};
use std::sync::Arc;

const ADDR_SIZE: usize = std::mem::size_of::<usize>();

#[derive(Debug, PartialEq, Eq)]
pub enum AccessType {
    Read,
    Write,
}

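/// Backing store that supplies page contents on demand.
///
/// A minimal sketch of an implementation (a hypothetical `ZeroStore` that zero-fills every
/// faulting page; assumes the crate is used as `userspace_paging`):
/// ```
/// use userspace_paging::{AccessType, PageStore};
///
/// struct ZeroStore;
///
/// impl PageStore for ZeroStore {
///     fn page_fault(&mut self, _offset: usize, length: usize, _access: AccessType) -> Option<Vec<u8>> {
///         // populate the faulting page with zeroes
///         Some(vec![0; length])
///     }
/// }
/// ```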
pub trait PageStore {
    /// Callback invoked upon a page fault. Return `Some(data)` if the page should be populated
    /// with the given data; return `None` to leave the page contents as-is and only have its
    /// protection adjusted.
    fn page_fault(&mut self, offset: usize, length: usize, access: AccessType) -> Option<Vec<u8>>;
}

struct MappedMemory {
    base: AtomicPtr<u8>,
    length: usize,
}

impl MappedMemory {
    #[inline(always)]
    fn base(&self) -> *mut u8 {
        unsafe { *self.base.as_ptr() }
    }

    #[inline(always)]
    fn as_slice(&self) -> &mut [u8] {
        // NOTE: this hands out a mutable slice from a shared reference; callers must ensure
        // their accesses to the mapped region do not conflict.
        unsafe { std::slice::from_raw_parts_mut(self.base(), self.length) }
    }
}

pub struct PagedMemory<'a> {
    mem: Arc<MappedMemory>,
    unmap: bool,
    page_size: usize,
    _phantom: std::marker::PhantomData<&'a ()>,
}

struct PagedMemoryEntry {
    start: usize,
    len: usize,
    mem: Arc<MappedMemory>,
    store: Box<dyn PageStore + Send + 'static>,
    page_size: usize,
}

#[derive(Debug, PartialEq, Eq)]
pub enum Error {
    BaseNotAligned,
    NullBase,
    LengthNotAligned,
    LengthIsZero,
    PageSizeNotAvail,
    NotSupported,
    UnixError(nix::errno::Errno),
    MemoryOverlap,
}

static HANDLER_SPIN: AtomicBool = AtomicBool::new(false);
static mut TO_HANDLER: (RawFd, RawFd) = (0, 1);
static mut FROM_HANDLER: (RawFd, RawFd) = (0, 1);

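// Fault-handling protocol: on SIGSEGV/SIGBUS, the signal handler sends ADDR_SIZE + 1 bytes over
// FROM_HANDLER (the faulting address in little-endian byte order, followed by one access byte:
// 0 = read, non-zero = write), then blocks on TO_HANDLER for a one-byte verdict from the pager
// thread (0 = handled, 1 = fall back). Only `write`/`read` are used inside the signal handler
// because they are async-signal-safe. A verdict of 1 means the address is not in any managed
// region; the fallback build then forwards the signal to the user-provided handler declared below.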
extern "C" {
    fn userspace_paging_fallback(signum: libc::c_int, info: *mut libc::siginfo_t, ctx: *mut c_void);
}

#[inline]
fn handle_page_fault_(info: *mut libc::siginfo_t, ctx: *mut c_void) -> bool {
    let (tx, rx, addr, ctx) = unsafe {
        let (rx, _) = TO_HANDLER;
        let (_, tx) = FROM_HANDLER;
        (tx, rx, (*info).si_addr() as usize, &mut *(ctx as *mut libc::ucontext_t))
    };
    // on x86-64, bit 1 of the page-fault error code is set for a write access
    #[cfg(target_arch = "x86_64")]
    let flag = ((ctx.uc_mcontext.gregs[libc::REG_ERR as usize] & 0x2) >> 1) as u8;
    // on other architectures, conservatively assume a write access
    #[cfg(not(target_arch = "x86_64"))]
    let flag = 1;
    let mut buff = [0; ADDR_SIZE + 1];
    buff[..ADDR_SIZE].copy_from_slice(&addr.to_le_bytes());
    buff[ADDR_SIZE] = flag;
    // use a spin lock so the write/read round-trip stays paired: without it, another faulting
    // thread could interleave between our write and the matching read of the verdict
    while HANDLER_SPIN.swap(true, Ordering::Acquire) {
        std::thread::yield_now();
    }
    if unistd::write(tx, &buff).is_err() {
        HANDLER_SPIN.swap(false, Ordering::Release);
        return true
    }
    unistd::read(rx, &mut buff[..1]).ok();
    HANDLER_SPIN.swap(false, Ordering::Release);
    buff[0] == 1
}

extern "C" fn handle_page_fault(_: i32, info: *mut libc::siginfo_t, ctx: *mut c_void) {
    handle_page_fault_(info, ctx);
}

extern "C" fn handle_page_fault_with_fallback(signum: libc::c_int, info: *mut libc::siginfo_t, ctx: *mut c_void) {
    if handle_page_fault_(info, ctx) {
        // not a managed memory region: fall back to the user-provided handler
        unsafe {
            userspace_paging_fallback(signum, info, ctx);
        }
    }
}

struct PagedMemoryManager {
    entries: Vec<PagedMemoryEntry>,
}

impl PagedMemoryManager {
    /// Insert an entry, keeping `entries` sorted by start address. Returns `false` if the new
    /// entry overlaps an existing one.
    fn insert(&mut self, entry: PagedMemoryEntry) -> bool {
        for (i, PagedMemoryEntry { start, len, .. }) in self.entries.iter().enumerate() {
            if entry.start + entry.len <= *start {
                // insert before this entry
                self.entries.insert(i, entry);
                return true
            }
            if entry.start < *start + *len {
                // overlapping space
                return false
            }
        }
        self.entries.push(entry);
        true
    }

    fn remove(&mut self, start_: usize, len_: usize) {
        for (i, PagedMemoryEntry { start, len, .. }) in self.entries.iter().enumerate() {
            if *start == start_ && *len == len_ {
                self.entries.remove(i);
                return
            }
        }
        panic!(
            "failed to locate PagedMemoryEntry (start = 0x{:x}, end = 0x{:x})",
            start_,
            start_ + len_
        )
    }
}

/// Global registry of managed memory regions, consulted by the pager thread on every fault.
static MANAGER: Mutex<PagedMemoryManager> = Mutex::new(PagedMemoryManager { entries: Vec::new() });

unsafe fn register_signal_handlers_(handler: extern "C" fn(i32, *mut libc::siginfo_t, *mut c_void)) {
    let sig_action = signal::SigAction::new(
        signal::SigHandler::SigAction(handler),
        signal::SaFlags::SA_NODEFER | signal::SaFlags::SA_SIGINFO,
        signal::SigSet::empty(),
    );
    signal::sigaction(signal::SIGSEGV, &sig_action).expect("failed to register SIGSEGV handler");
    signal::sigaction(signal::SIGBUS, &sig_action).expect("failed to register SIGBUS handler");
}

pub unsafe fn register_signal_handlers() {
    register_signal_handlers_(handle_page_fault);
}

pub unsafe fn register_signal_handlers_with_fallback() {
    register_signal_handlers_(handle_page_fault_with_fallback);
}
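
// With the fallback variant, the final binary must define the `userspace_paging_fallback`
// symbol declared above. A minimal sketch of a user-side definition (the body is illustrative;
// only the symbol name and signature are fixed):
//
//     #[no_mangle]
//     pub extern "C" fn userspace_paging_fallback(
//         _signum: libc::c_int,
//         _info: *mut libc::siginfo_t,
//         _ctx: *mut libc::c_void,
//     ) {
//         // e.g. chain to a previously saved sigaction, or abort with diagnostics
//         std::process::abort();
//     }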

// The SIGSEGV/SIGBUS handler cannot run the store callback itself: loading a page may allocate,
// take locks, or block, none of which is async-signal-safe. Instead, a dedicated pager thread is
// spawned here, connected to the handler by a pair of pipes, and it performs the page loading and
// `mprotect` calls on the handler's behalf.
fn handler_init() {
    let to_handler = nix::unistd::pipe().expect("failed to create pipe to the handler");
    let from_handler = nix::unistd::pipe().expect("failed to create pipe from the handler");
    unsafe {
        TO_HANDLER = to_handler;
        FROM_HANDLER = from_handler;
        register_signal_handlers();
    }
    std::sync::atomic::fence(Ordering::SeqCst);
    std::thread::spawn(move || {
        let from_handler = from_handler.0;
        let to_handler = to_handler.1;
        let mut buff = [0; ADDR_SIZE + 1];
        loop {
            unistd::read(from_handler, &mut buff).unwrap();
            let addr = usize::from_le_bytes(buff[..ADDR_SIZE].try_into().unwrap());
            let (access_type, mprotect_flag) = match buff[ADDR_SIZE] {
                0 => (AccessType::Read, ProtFlags::PROT_READ),
                _ => (AccessType::Write, ProtFlags::PROT_READ | ProtFlags::PROT_WRITE),
            };
            let mut mgr = MANAGER.lock();
            let mut fallback = 1;
            for entry in mgr.entries.iter_mut() {
                if entry.start <= addr && addr < entry.start + entry.len {
                    // round the faulting address down to its page boundary
                    let page_mask = usize::MAX ^ (entry.page_size - 1);
                    let page_addr = addr & page_mask;
                    // load the page data
                    let slice = entry.mem.as_slice();
                    let base = slice.as_ptr() as usize;
                    let page_offset = page_addr - base;
                    if let Some(page) = entry.store.page_fault(page_offset, entry.page_size, access_type) {
                        unsafe {
                            nix::sys::mman::mprotect(
                                page_addr as *mut c_void,
                                entry.page_size,
                                ProtFlags::PROT_READ | ProtFlags::PROT_WRITE,
                            )
                            .expect("mprotect failed");
                        }
                        slice[page_offset..page_offset + entry.page_size].copy_from_slice(&page);
                    }
                    // set the final protection for the access type (read or read/write)
                    unsafe {
                        nix::sys::mman::mprotect(page_addr as *mut c_void, entry.page_size, mprotect_flag)
                            .expect("mprotect failed");
                    }
                    fallback = 0;
                    break
                }
            }
            // report the verdict back to the signal handler; fallback == 1 means the address is
            // not in any managed region and the fault falls through to the fallback path
            unistd::write(to_handler, &[fallback]).unwrap();
        }
    });
}

impl<'a> PagedMemory<'a> {
    /// Make the memory paged in userspace, given a raw base pointer and length. Note that the
    /// range of addresses must remain valid throughout the lifetime of the [PagedMemory], and the
    /// memory must only be accessed through it (e.g. via [PagedMemory::as_slice] or
    /// [PagedMemory::as_slice_mut]).
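    ///
    /// A minimal sketch (assuming the crate is used as `userspace_paging` and a 4 KiB page
    /// size): reserve a page-aligned region with `mmap` first, then hand it over:
    /// ```no_run
    /// use userspace_paging::{PagedMemory, VecPageStore};
    /// use nix::sys::mman::{MapFlags, ProtFlags};
    /// use std::num::NonZeroUsize;
    ///
    /// let len = 4096 * 4;
    /// let base = unsafe {
    ///     nix::sys::mman::mmap(
    ///         None,
    ///         NonZeroUsize::new(len).unwrap(),
    ///         ProtFlags::PROT_NONE,
    ///         MapFlags::MAP_PRIVATE | MapFlags::MAP_ANONYMOUS,
    ///         Option::<std::fs::File>::None,
    ///         0,
    ///     )
    ///     .unwrap()
    /// } as *mut u8;
    /// let pm = unsafe {
    ///     PagedMemory::from_raw(base, len, VecPageStore::new(vec![0u8; len]), None).unwrap()
    /// };
    /// ```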
    pub unsafe fn from_raw<S: PageStore + Send + 'static>(
        base: *mut u8, length: usize, store: S, page_size: Option<usize>,
    ) -> Result<PagedMemory<'static>, Error> {
        let mem: &'static mut [u8] = std::slice::from_raw_parts_mut(base, length);
        Self::new_(Some(mem.as_ptr() as *mut u8), mem.len(), store, page_size)
    }

    /// Create a slice of memory that is paged in userspace.
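    ///
    /// A minimal usage sketch (assuming the crate is used as `userspace_paging` and a 4 KiB
    /// system page size):
    /// ```no_run
    /// use userspace_paging::{PagedMemory, VecPageStore};
    ///
    /// let data = vec![7u8; 4096 * 10];
    /// let pm = PagedMemory::new(data.len(), VecPageStore::new(data), None).unwrap();
    /// // the first touch of each page faults and loads it from the store
    /// assert_eq!(pm.as_slice()[0], 7);
    /// ```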
    pub fn new<S: PageStore + Send + 'static>(
        length: usize, store: S, page_size: Option<usize>,
    ) -> Result<PagedMemory<'static>, Error> {
        Self::new_(None, length, store, page_size)
    }

    fn new_<'b, S: PageStore + Send + 'static>(
        base: Option<*mut u8>, mut length: usize, store: S, page_size: Option<usize>,
    ) -> Result<PagedMemory<'b>, Error> {
        static INIT: std::sync::Once = std::sync::Once::new();
        INIT.call_once(|| handler_init());

        // page_size is assumed to be a power of two (the system page size always is)
        let page_size = match page_size {
            Some(s) => s,
            None => unistd::sysconf(unistd::SysconfVar::PAGE_SIZE)
                .map_err(Error::UnixError)?
                .ok_or(Error::PageSizeNotAvail)? as usize,
        };
        let rem = length & (page_size - 1);
        match base {
            Some(base) => {
                if (base as usize) % page_size != 0 {
                    return Err(Error::BaseNotAligned)
                }
                if rem != 0 {
                    return Err(Error::LengthNotAligned)
                }
            }
            None => {
                // no fixed base: round the length up to a whole number of pages
                if rem != 0 {
                    length += page_size - rem
                }
            }
        }

        let new_base = unsafe {
            nix::sys::mman::mmap(
                match base {
                    Some(b) => Some(NonZeroUsize::new(b as usize).ok_or(Error::NullBase)?),
                    None => None,
                },
                NonZeroUsize::new(length).ok_or(Error::LengthIsZero)?,
                ProtFlags::PROT_NONE,
                match base {
                    Some(_) => MapFlags::MAP_FIXED,
                    None => MapFlags::empty(),
                } | MapFlags::MAP_PRIVATE |
                    MapFlags::MAP_ANONYMOUS,
                Option::<std::fs::File>::None,
                0,
            )
            .map_err(Error::UnixError)?
        } as *mut u8;

        if let Some(base) = base {
            if base != new_base {
                return Err(Error::NotSupported)
            }
        }

        let mem = std::sync::Arc::new(MappedMemory {
            base: new_base.into(),
            length,
        });
        let mem_clone = mem.clone();
        let mut mgr = MANAGER.lock();
        if !mgr.insert(PagedMemoryEntry {
            start: new_base as usize,
            len: length,
            mem: mem_clone,
            store: Box::new(store),
            page_size,
        }) {
            return Err(Error::MemoryOverlap)
        }

        Ok(PagedMemory {
            mem,
            unmap: base.is_none(),
            page_size,
            _phantom: std::marker::PhantomData,
        })
    }

    /*
    /// Run code that possibly accesses the paged memory. Because an access to the memory could
    /// wait upon the [PageStore] to load the page in case of a page fault, to avoid deadlock,
    /// make sure the resources the store will acquire are not already held by this code when
    /// accessing the paged memory. For example, `println!("{}", mem[0])` could deadlock if the
    /// `page_fault()` implementation of [PageStore] also uses `println!`: the I/O lock is taken
    /// before the dereference of `mem[0]`, which may trigger a page fault that invokes
    /// `page_fault()` to bring in the page content; when the page store then calls `println!`,
    /// it gets stuck (the dereference is incomplete). As a good practice, make sure the
    /// [PageStore] acquires as few resources as possible, and none that overlap with the code here.
    pub fn run<F, T>(&mut self, f: F) -> std::thread::JoinHandle<T>
    where
        F: FnOnce(&mut [u8]) -> T + Send + 'static,
        T: Send + 'static,
    {
        let mem = self.mem.clone();
        std::thread::spawn(move || f(mem.as_slice()))
    }
    */

    /// Get a mutable view of the paged memory. Because an access to the memory could wait upon
    /// the [PageStore] to load a page in case of a page fault, to avoid deadlock, make sure the
    /// resources the store will acquire are not already held by this code when accessing the
    /// paged memory. For example, `println!("{}", mem.as_slice()[0])` could deadlock if the
    /// `page_fault()` implementation of [PageStore] also uses `println!`: the I/O lock is taken
    /// before the dereference of `mem[0]`, which may trigger a page fault that invokes
    /// `page_fault()` to bring in the page content; when the page store then calls `println!`,
    /// it gets stuck (the dereference is incomplete). As a good practice, make sure the
    /// [PageStore] acquires as few resources as possible, and none that overlap with the code here.
    pub fn as_slice_mut(&mut self) -> &mut [u8] {
        self.mem.as_slice()
    }

    pub fn as_slice(&self) -> &[u8] {
        self.mem.as_slice()
    }

    pub fn as_raw_parts(&self) -> (*mut u8, usize) {
        let s = self.mem.as_slice();
        (s.as_mut_ptr(), s.len())
    }

    /// Return the configured page size.
    pub fn page_size(&self) -> usize {
        self.page_size
    }

    /// Release the page content loaded from the [PageStore]. The next access to an address within
    /// this page will trigger a page fault again. `page_offset` must be page-aligned, i.e. one of
    /// the offsets passed to [PageStore::page_fault].
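    ///
    /// A minimal sketch (assuming the crate is used as `userspace_paging` and a 4 KiB page size):
    /// ```no_run
    /// use userspace_paging::{PagedMemory, VecPageStore};
    ///
    /// let pm = PagedMemory::new(4096 * 2, VecPageStore::new(vec![1u8; 4096 * 2]), None).unwrap();
    /// assert_eq!(pm.as_slice()[0], 1); // loads page 0
    /// pm.release_page(0);              // drops it again
    /// assert_eq!(pm.as_slice()[0], 1); // faults and reloads page 0
    /// ```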
    pub fn release_page(&self, page_offset: usize) {
        if page_offset & (self.page_size - 1) != 0 || page_offset >= self.mem.length {
            panic!("invalid page offset: {:x}", page_offset);
        }
        let page_addr = self.mem.base() as usize + page_offset;
        unsafe {
            // map a fresh PROT_NONE page over the old one, discarding its contents
            nix::sys::mman::mmap(
                Some(NonZeroUsize::new(page_addr).unwrap()),
                NonZeroUsize::new(self.page_size).unwrap(),
                ProtFlags::PROT_NONE,
                MapFlags::MAP_FIXED | MapFlags::MAP_PRIVATE | MapFlags::MAP_ANONYMOUS,
                Option::<std::fs::File>::None,
                0,
            )
            .expect("mmap failed");
        }
    }

    /// Release every page in the region, as [PagedMemory::release_page] does for a single page.
    pub fn release_all_pages(&self) {
        unsafe {
            nix::sys::mman::mmap(
                Some(NonZeroUsize::new(self.mem.base() as usize).unwrap()),
                NonZeroUsize::new(self.mem.length).unwrap(),
                ProtFlags::PROT_NONE,
                MapFlags::MAP_FIXED | MapFlags::MAP_PRIVATE | MapFlags::MAP_ANONYMOUS,
                Option::<std::fs::File>::None,
                0,
            )
            .expect("mmap failed");
        }
    }
}

impl<'a> Drop for PagedMemory<'a> {
    fn drop(&mut self) {
        let mut mgr = MANAGER.lock();
        mgr.remove(self.mem.base() as usize, self.mem.length);
        if self.unmap {
            unsafe {
                nix::sys::mman::munmap(self.mem.base() as *mut c_void, self.mem.length).unwrap();
            }
        }
    }
}

/// A [PageStore] backed by an in-memory `Vec<u8>`, mainly useful for tests and examples.
pub struct VecPageStore {
    vec: Vec<u8>,
    // one bit per page, recording whether the page has been served before (0x4000 * 64 pages max)
    bitmask: [u64; 0x4000],
}

impl VecPageStore {
    pub fn new(vec: Vec<u8>) -> Self {
        Self {
            vec,
            bitmask: [0; 0x4000],
        }
    }

    fn get_page_bit(&self, page_num: usize) -> bool {
        ((self.bitmask[page_num >> 6] >> (page_num & 63)) & 1) == 1
    }

    fn set_page_bit(&mut self, page_num: usize) {
        self.bitmask[page_num >> 6] |= 1 << (page_num & 63)
    }
}

impl PageStore for VecPageStore {
    fn page_fault(&mut self, offset: usize, length: usize, _access: AccessType) -> Option<Vec<u8>> {
        let page_num = offset / length;
        if self.get_page_bit(page_num) {
            #[cfg(debug_assertions)]
            println!(
                "{:?} page fault at 0x{:x} access={:?}",
                self as *mut Self, offset, _access,
            );
        } else {
            #[cfg(debug_assertions)]
            println!(
                "{:?} loading page at 0x{:x} access={:?}",
                self as *mut Self, offset, _access,
            );
            self.set_page_bit(page_num);
        }
        Some(self.vec[offset..offset + length].to_vec())
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    const PAGE_SIZE: usize = 4096;

    #[test]
    fn test1() {
        for _ in 0..100 {
            let mut v = Vec::new();
            v.resize(PAGE_SIZE * 100, 0);
            v[0] = 42;
            v[PAGE_SIZE * 10 + 1] = 43;
            v[PAGE_SIZE * 20 + 1] = 44;

            let pm = PagedMemory::new(PAGE_SIZE * 100, VecPageStore::new(v), None).unwrap();
            let m = pm.as_slice();
            assert_eq!(m[0], 42);
            assert_eq!(m[PAGE_SIZE * 10 + 1], 43);
            assert_eq!(m[PAGE_SIZE * 20 + 1], 44);
        }
    }

    #[test]
    fn test2() {
        for _ in 0..100 {
            let mut v = Vec::new();
            v.resize(PAGE_SIZE * 100, 0);
            v[0] = 1;
            v[PAGE_SIZE * 10 + 1] = 2;
            v[PAGE_SIZE * 20 + 1] = 3;

            let pm1 = PagedMemory::new(PAGE_SIZE * 100, VecPageStore::new(v), None).unwrap();

            let mut v = Vec::new();
            v.resize(PAGE_SIZE * 100, 0);
            for (i, v) in v.iter_mut().enumerate() {
                *v = i as u8;
            }
            let mut pm2 = PagedMemory::new(PAGE_SIZE * 100, VecPageStore::new(v), None).unwrap();

            let m2 = pm2.as_slice_mut();
            let m1 = pm1.as_slice();

            assert_eq!(m2[100], 100);
            m2[100] = 0;
            assert_eq!(m2[100], 0);

            assert_eq!(m1[0], 1);
            assert_eq!(m1[PAGE_SIZE * 10 + 1], 2);
            assert_eq!(m1[PAGE_SIZE * 20 + 1], 3);
        }
    }

    #[test]
    fn test_release_page() {
        let mut v = Vec::new();
        v.resize(PAGE_SIZE * 20, 0);
        v[0] = 42;
        v[PAGE_SIZE * 10 + 1] = 43;

        // map exactly the 20 pages that the backing vector provides
        let pm = PagedMemory::new(PAGE_SIZE * 20, VecPageStore::new(v), None).unwrap();
        let m = pm.as_slice();
        assert_eq!(m[0], 42);
        assert_eq!(m[PAGE_SIZE * 10 + 1], 43);
        for _ in 0..5 {
            pm.release_page(0);
            pm.release_page(PAGE_SIZE * 10);
            assert_eq!(m[0], 42);
            assert_eq!(m[PAGE_SIZE * 10 + 1], 43);
        }
    }

    #[test]
    fn out_of_order_scan() {
        let mut v = Vec::new();
        v.resize(PAGE_SIZE * 100, 0);
        for (i, v) in v.iter_mut().enumerate() {
            *v = i as u8;
        }
        let store = VecPageStore::new(v);
        let pm = PagedMemory::new(PAGE_SIZE * 100, store, None).unwrap();
        use rand::{seq::SliceRandom, SeedableRng};
        use rand_chacha::ChaChaRng;
        let seed = [0; 32];
        let mut rng = ChaChaRng::from_seed(seed);

        let m = pm.as_slice();
        let mut idxes = Vec::new();
        for i in 0..m.len() {
            idxes.push(i);
        }
        idxes.shuffle(&mut rng);
        for i in idxes.into_iter() {
            #[cfg(debug_assertions)]
            {
                let x = m[i];
                println!("m[0x{:08x}] = {}", i, x);
            }
            assert_eq!(m[i], i as u8);
        }
    }
}