vmi_core/ctx/prober.rs

use std::cell::RefCell;

use indexmap::IndexSet;

use crate::{AddressContext, PageFaults, VmiError};

/// Prober for safely handling page faults during memory access operations.
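///
/// Instead of propagating [`VmiError::Translation`] immediately, the prober
/// converts it into `Ok(None)` and records the faulting addresses, so a batch
/// of accesses can be attempted and the accumulated faults inspected
/// afterwards via [`VmiProber::error_for_page_faults`].
///
/// A minimal usage sketch; the `vmi.read_u64(ctx)` accessor is hypothetical:
///
/// ```ignore
/// let prober = VmiProber::new(&IndexSet::new());
///
/// // A faulting access yields `Ok(None)` instead of an error.
/// let value = prober.probe(|| vmi.read_u64(ctx))?;
///
/// // Fail only if an unexpected page fault was recorded.
/// prober.error_for_page_faults()?;
/// ```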
pub struct VmiProber {
    /// The set of page faults that are allowed to occur and are therefore
    /// neither recorded nor reported.
    restricted: IndexSet<AddressContext>,

    /// The set of unexpected page faults that have occurred so far.
    page_faults: RefCell<IndexSet<AddressContext>>,
}

impl VmiProber {
    /// Creates a new prober.
    pub fn new(restricted: &IndexSet<AddressContext>) -> Self {
        Self {
            restricted: restricted.clone(),
            page_faults: RefCell::new(IndexSet::new()),
        }
    }

    /// Runs the given operation, safely handling page faults during memory
    /// access. Page faults are recorded rather than propagated as errors.
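    ///
    /// A minimal sketch; `vmi` and `ctx` are hypothetical:
    ///
    /// ```ignore
    /// // `Ok(None)` means a page fault occurred and was recorded.
    /// if let Some(value) = prober.probe(|| vmi.read_u64(ctx))? {
    ///     println!("read: {value:#x}");
    /// }
    /// ```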
    pub fn probe<T, F>(&self, f: F) -> Result<Option<T>, VmiError>
    where
        F: FnOnce() -> Result<T, VmiError>,
    {
        self.check_result(f())
    }

    /// Handles a result that may contain page faults: returns `Ok(Some(value))`
    /// on success, records any unexpected page faults and returns `Ok(None)` on
    /// a translation error, and propagates any other error unchanged.
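    ///
    /// For example, assuming a hypothetical accessor returning
    /// `Result<u64, VmiError>`:
    ///
    /// ```ignore
    /// match prober.check_result(vmi.read_u64(ctx))? {
    ///     Some(value) => println!("read: {value:#x}"),
    ///     None => println!("page fault recorded; value unavailable"),
    /// }
    /// ```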
    pub fn check_result<T>(&self, result: Result<T, VmiError>) -> Result<Option<T>, VmiError> {
        match result {
            Ok(value) => Ok(Some(value)),
            Err(VmiError::Translation(pfs)) => {
                self.check_restricted(pfs);
                Ok(None)
            }
            Err(err) => Err(err),
        }
    }

    /*
    /// Handles a result that may contain page faults over a memory range,
    /// returning the value if successful.
    fn check_result_range<T>(
        &self,
        result: Result<T, VmiError>,
        ctx: AccessContext,
        length: usize,
    ) -> Result<Option<T>, VmiError> {
        match result {
            Ok(value) => Ok(Some(value)),
            Err(VmiError::Translation(pfs)) => {
                debug_assert_eq!(pfs.len(), 1);
                self.check_restricted_range(pfs[0], ctx, length);
                Ok(None)
            }
            Err(err) => Err(err),
        }
    }
    */

    /// Records any page faults that are not in the restricted set.
    fn check_restricted(&self, pfs: PageFaults) {
        let mut page_faults = self.page_faults.borrow_mut();
        for pf in pfs {
            if !self.restricted.contains(&pf) {
                tracing::trace!(va = %pf.va, "page fault");
                page_faults.insert(pf);
            } else {
                tracing::trace!(va = %pf.va, "page fault (restricted)");
            }
        }
    }

    /*
    /// Records any page faults that are not in the restricted set over
    /// a memory range.
    fn check_restricted_range(&self, pf: PageFault, ctx: AccessContext, mut length: usize) {
        let mut page_faults = self.page_faults.borrow_mut();

        if length == 0 {
            length = 1;
        }

        //
        // Generate page faults for the range of addresses that would be accessed by the read.
        // Start at the page containing the faulting address and end at the page containing the
        // last byte of the read.
        //

        let pf_page = pf.address.0 >> Driver::Architecture::PAGE_SHIFT;
        let last_page = (ctx.address + length as u64 - 1) >> Driver::Architecture::PAGE_SHIFT;
        let number_of_pages = last_page.saturating_sub(pf_page) + 1;

        let pf_address_aligned = Va(pf_page << Driver::Architecture::PAGE_SHIFT);
        let last_address_aligned = Va(last_page << Driver::Architecture::PAGE_SHIFT);

        if number_of_pages > 1 {
            tracing::debug!(
                from = %pf_address_aligned,
                to = %last_address_aligned,
                number_of_pages,
                "page fault (range)"
            );

            if number_of_pages >= 4096 {
                tracing::warn!(
                    from = %pf_address_aligned,
                    to = %last_address_aligned,
                    number_of_pages,
                    "page fault range too large"
                );
            }
        }

        for i in 0..number_of_pages {
            //
            // Ensure that the page fault is for the root that we are tracking.
            //

            debug_assert_eq!(
                pf.root,
                match ctx.mechanism {
                    TranslationMechanism::Paging { root: Some(root) } => root,
                    _ => panic!("page fault root doesn't match the context root"),
                }
            );

            let pf = PageFault {
                address: pf_address_aligned + i * Driver::Architecture::PAGE_SIZE,
                root: pf.root,
            };

            if !self.restricted.contains(&pf) {
                tracing::trace!(va = %pf.address, "page fault");
                page_faults.insert(pf);
            } else {
                tracing::trace!(va = %pf.address, "page fault (restricted)");
            }
        }
    }
    */

    /// Checks for any unexpected page faults that have occurred and returns
    /// an error if any are present.
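    ///
    /// Typically called after a batch of probed accesses; the access below is
    /// hypothetical:
    ///
    /// ```ignore
    /// let _ = prober.probe(|| vmi.read_u64(ctx))?;
    /// prober.error_for_page_faults()?; // propagates any recorded faults
    /// ```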
    #[tracing::instrument(skip_all, err)]
    pub fn error_for_page_faults(&self) -> Result<(), VmiError> {
        let pfs = self.page_faults.borrow();
        let new_pfs = &*pfs - &self.restricted;
        if !new_pfs.is_empty() {
            tracing::trace!(?new_pfs);
            return Err(VmiError::page_faults(new_pfs));
        }

        Ok(())
    }
}

/// Evaluates an expression, safely handling page faults during memory access
/// operations, by delegating to [`VmiProber::check_result`].
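///
/// The expression must evaluate to a `Result<_, VmiError>`. A usage sketch
/// with a hypothetical accessor:
///
/// ```ignore
/// let value = vmi_probe!(prober, vmi.read_u64(ctx))?;
/// ```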
#[macro_export]
macro_rules! vmi_probe {
    ($prober:expr, $expr:expr) => {
        $prober.check_result(|| -> Result<_, $crate::VmiError> { $expr }())
    };
}