vmi_core/ctx/prober.rs
1use std::cell::RefCell;
2
3use indexmap::IndexSet;
4
5use crate::{AddressContext, PageFaults, VmiError};
6
7/// Prober for safely handling page faults during memory access operations.
8pub struct VmiProber {
9 /// The set of restricted page faults that are allowed to occur.
10 restricted: IndexSet<AddressContext>,
11
12 /// The set of page faults that have occurred.
13 page_faults: RefCell<IndexSet<AddressContext>>,
14}
15
16impl VmiProber {
17 /// Creates a new prober.
18 pub fn new(restricted: &IndexSet<AddressContext>) -> Self {
19 Self {
20 restricted: restricted.clone(),
21 page_faults: RefCell::new(IndexSet::new()),
22 }
23 }
24
25 /// Probes for safely handling page faults during memory access operations.
26 pub fn probe<T, F>(&self, f: F) -> Result<Option<T>, VmiError>
27 where
28 F: FnOnce() -> Result<T, VmiError>,
29 {
30 self.check_result(f())
31 }
32
33 /// Handles a result that may contain page faults, returning the value
34 /// if successful.
35 pub fn check_result<T>(&self, result: Result<T, VmiError>) -> Result<Option<T>, VmiError> {
36 match result {
37 Ok(value) => Ok(Some(value)),
38 Err(VmiError::Translation(pfs)) => {
39 self.check_restricted(pfs);
40 Ok(None)
41 }
42 Err(err) => Err(err),
43 }
44 }
45
46 /*
47 /// Handles a result that may contain page faults over a memory range,
48 /// returning the value if successful.
49 fn check_result_range<T>(
50 &self,
51 result: Result<T, VmiError>,
52 ctx: AccessContext,
53 length: usize,
54 ) -> Result<Option<T>, VmiError> {
55 match result {
56 Ok(value) => Ok(Some(value)),
57 Err(VmiError::Translation(pfs)) => {
58 debug_assert_eq!(pfs.len(), 1);
59 self.check_restricted_range(pfs[0], ctx, length);
60 Ok(None)
61 }
62 Err(err) => Err(err),
63 }
64 }
65 */
66
67 /// Records any page faults that are not in the restricted set.
68 fn check_restricted(&self, pfs: PageFaults) {
69 let mut page_faults = self.page_faults.borrow_mut();
70 for pf in pfs {
71 if !self.restricted.contains(&pf) {
72 tracing::trace!(va = %pf.va, "page fault");
73 page_faults.insert(pf);
74 }
75 else {
76 tracing::trace!(va = %pf.va, "page fault (restricted)");
77 }
78 }
79 }
80
81 /*
82 /// Records any page faults that are not in the restricted set over
83 /// a memory range.
84 fn check_restricted_range(&self, pf: PageFault, ctx: AccessContext, mut length: usize) {
85 let mut page_faults = self.page_faults.borrow_mut();
86
87 if length == 0 {
88 length = 1;
89 }
90
91 //
92 // Generate page faults for the range of addresses that would be accessed by the read.
93 // Start at the page containing the faulting address and end at the page containing the
94 // last byte of the read.
95 //
96
97 let pf_page = pf.address.0 >> Driver::Architecture::PAGE_SHIFT;
98 let last_page = (ctx.address + length as u64 - 1) >> Driver::Architecture::PAGE_SHIFT;
99 let number_of_pages = last_page.saturating_sub(pf_page) + 1;
100
101 let pf_address_aligned = Va(pf_page << Driver::Architecture::PAGE_SHIFT);
102 let last_address_aligned = Va(last_page << Driver::Architecture::PAGE_SHIFT);
103
104 if number_of_pages > 1 {
105 tracing::debug!(
106 from = %pf_address_aligned,
107 to = %last_address_aligned,
108 number_of_pages,
109 "page fault (range)"
110 );
111
112 if number_of_pages >= 4096 {
113 tracing::warn!(
114 from = %pf_address_aligned,
115 to = %last_address_aligned,
116 number_of_pages,
117 "page fault range too large"
118 );
119 }
120 }
121
122 for i in 0..number_of_pages {
123 //
124 // Ensure that the page fault is for the root that we are tracking.
125 //
126
127 debug_assert_eq!(
128 pf.root,
129 match ctx.mechanism {
130 TranslationMechanism::Paging { root: Some(root) } => root,
131 _ => panic!("page fault root doesn't match the context root"),
132 }
133 );
134
135 let pf = PageFault {
136 address: pf_address_aligned + i * Driver::Architecture::PAGE_SIZE,
137 root: pf.root,
138 };
139
140 if !self.restricted.contains(&pf) {
141 tracing::trace!(va = %pf.address, "page fault");
142 page_faults.insert(pf);
143 }
144 else {
145 tracing::trace!(va = %pf.address, "page fault (restricted)");
146 }
147 }
148 }
149 */
150
151 /// Checks for any unexpected page faults that have occurred and returns
152 /// an error if any are present.
153 #[tracing::instrument(skip_all, err)]
154 pub fn error_for_page_faults(&self) -> Result<(), VmiError> {
155 let pfs = self.page_faults.borrow();
156 let new_pfs = &*pfs - &self.restricted;
157 if !new_pfs.is_empty() {
158 tracing::trace!(?new_pfs);
159 return Err(VmiError::page_faults(new_pfs));
160 }
161
162 Ok(())
163 }
164}
165
/// Probes an expression, absorbing translation (page-fault) failures.
///
/// `$expr` is evaluated inside an immediately-invoked closure returning
/// `Result<_, VmiError>`. As a result, the `?` operator used within `$expr`
/// returns from that closure, not from the enclosing function.
///
/// The closure's result is handed to `$prober.check_result(..)`. For
/// [`VmiProber`] this yields:
/// - `Ok(Some(value))` on success;
/// - `Ok(None)` after recording the page faults of a
///   [`VmiError::Translation`] error;
/// - any other error, propagated unchanged.
///
/// Because the macro is exported, all paths it introduces are fully
/// qualified (`$crate::VmiError`, `::std::result::Result`). It therefore
/// works at call sites that do not import `VmiError`, or that shadow
/// `Result` with a type alias.
#[macro_export]
macro_rules! vmi_probe {
    ($prober:expr, $expr:expr) => {
        $prober.check_result(
            (|| -> ::std::result::Result<_, $crate::VmiError> { $expr })(),
        )
    };
}