wraith/navigation/memory_regions.rs

1//! Memory region enumeration using VirtualQuery
2
3#[cfg(all(not(feature = "std"), feature = "alloc"))]
4use alloc::vec::Vec;
5
6#[cfg(feature = "std")]
7use std::vec::Vec;
8
9use crate::error::{Result, WraithError};
10
11/// memory region information
12#[derive(Debug, Clone)]
13pub struct MemoryRegion {
14    pub base_address: usize,
15    pub allocation_base: usize,
16    pub allocation_protect: u32,
17    pub region_size: usize,
18    pub state: MemoryState,
19    pub protect: u32,
20    pub memory_type: MemoryType,
21}
22
/// Allocation state of a region's pages.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MemoryState {
    /// Pages are committed (`MEM_COMMIT`).
    Commit,
    /// Address range is reserved but not committed (`MEM_RESERVE`).
    Reserve,
    /// Address range is free (`MEM_FREE`); this module also uses it for unrecognised
    /// state values.
    Free,
}
30
/// Kind of storage backing a region.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MemoryType {
    /// Pages mapped from an executable image (`MEM_IMAGE`).
    Image,
    /// Pages mapped from a memory-mapped file / section (`MEM_MAPPED`).
    Mapped,
    /// Pages private to the process (`MEM_PRIVATE`).
    Private,
    /// Any type value this module does not recognise.
    Unknown,
}
39
/// Generates a `#[must_use]` predicate method for a type that has a `protect: u32` field.
///
/// The generated method returns `true` when `self.protect` has at least one of the listed
/// flag bits set. Any attributes (e.g. doc comments) given before the method name are
/// forwarded onto the generated method.
macro_rules! define_protection_check {
    ($(#[$meta:meta])* $method:ident, $($bit:ident)|+) => {
        $(#[$meta])*
        #[must_use]
        pub fn $method(&self) -> bool {
            // OR the flags into one mask: "any bit of the mask is set" is the same as
            // "any individual flag is set".
            (self.protect & ($($bit)|+)) != 0
        }
    };
}
50
51impl MemoryRegion {
52    define_protection_check!(
53        /// check if region is executable
54        is_executable,
55        PAGE_EXECUTE | PAGE_EXECUTE_READ | PAGE_EXECUTE_READWRITE | PAGE_EXECUTE_WRITECOPY
56    );
57
58    define_protection_check!(
59        /// check if region is readable
60        is_readable,
61        PAGE_READONLY | PAGE_READWRITE | PAGE_WRITECOPY |
62        PAGE_EXECUTE_READ | PAGE_EXECUTE_READWRITE | PAGE_EXECUTE_WRITECOPY
63    );
64
65    define_protection_check!(
66        /// check if region is writable
67        is_writable,
68        PAGE_READWRITE | PAGE_WRITECOPY | PAGE_EXECUTE_READWRITE | PAGE_EXECUTE_WRITECOPY
69    );
70
71    /// check if region is committed (accessible)
72    #[must_use]
73    pub fn is_committed(&self) -> bool {
74        self.state == MemoryState::Commit
75    }
76
77    /// check if this is part of an image
78    #[must_use]
79    pub fn is_image(&self) -> bool {
80        self.memory_type == MemoryType::Image
81    }
82
83    /// check if region is private memory
84    #[must_use]
85    pub fn is_private(&self) -> bool {
86        self.memory_type == MemoryType::Private
87    }
88
89    /// check if region is reserved (not yet committed)
90    #[must_use]
91    pub fn is_reserved(&self) -> bool {
92        self.state == MemoryState::Reserve
93    }
94
95    /// check if region is free (not allocated)
96    #[must_use]
97    pub fn is_free(&self) -> bool {
98        self.state == MemoryState::Free
99    }
100
101    /// get protection string (e.g., "RWX", "R--", etc.)
102    #[must_use]
103    pub fn protection_string(&self) -> &'static str {
104        match self.protect {
105            PAGE_NOACCESS => "---",
106            PAGE_READONLY => "R--",
107            PAGE_READWRITE => "RW-",
108            PAGE_WRITECOPY => "RC-",
109            PAGE_EXECUTE => "--X",
110            PAGE_EXECUTE_READ => "R-X",
111            PAGE_EXECUTE_READWRITE => "RWX",
112            PAGE_EXECUTE_WRITECOPY => "RCX",
113            _ => "???",
114        }
115    }
116}
117
/// Lazily walks the current process's address space one region at a time.
///
/// Each step queries the region at the cursor with `VirtualQuery` and then moves the cursor
/// to the end of that region. Iteration stops once the cursor reaches `max_address` or the
/// query fails.
#[derive(Debug, Clone)]
#[must_use = "iterators are lazy and do nothing unless consumed"]
pub struct MemoryRegionIterator {
    /// Next address to query.
    current_address: usize,
    /// Exclusive upper bound for the cursor; no query is issued at or beyond it.
    max_address: usize,
}
123
124impl MemoryRegionIterator {
125    /// create new iterator starting from address 0
126    pub fn new() -> Self {
127        Self {
128            current_address: 0,
129            max_address: Self::max_user_address(),
130        }
131    }
132
133    /// create iterator starting from specific address
134    pub fn from_address(address: usize) -> Self {
135        Self {
136            current_address: address,
137            max_address: Self::max_user_address(),
138        }
139    }
140
141    fn max_user_address() -> usize {
142        #[cfg(target_arch = "x86_64")]
143        {
144            0x7FFFFFFFFFFF // typical x64 user space limit
145        }
146        #[cfg(target_arch = "x86")]
147        {
148            0x7FFFFFFF // typical x86 user space limit
149        }
150    }
151}
152
153impl Default for MemoryRegionIterator {
154    fn default() -> Self {
155        Self::new()
156    }
157}
158
159impl Iterator for MemoryRegionIterator {
160    type Item = MemoryRegion;
161
162    fn next(&mut self) -> Option<Self::Item> {
163        if self.current_address >= self.max_address {
164            return None;
165        }
166
167        let mut mbi = MemoryBasicInformation::default();
168        // SAFETY: VirtualQuery is safe to call with valid buffer
169        let result = unsafe {
170            VirtualQuery(
171                self.current_address as *const _,
172                &mut mbi,
173                core::mem::size_of::<MemoryBasicInformation>(),
174            )
175        };
176
177        if result == 0 {
178            return None;
179        }
180
181        // advance to next region
182        self.current_address = mbi.base_address + mbi.region_size;
183
184        let state = match mbi.state {
185            MEM_COMMIT => MemoryState::Commit,
186            MEM_RESERVE => MemoryState::Reserve,
187            MEM_FREE => MemoryState::Free,
188            _ => MemoryState::Free,
189        };
190
191        let memory_type = match mbi.memory_type {
192            MEM_IMAGE => MemoryType::Image,
193            MEM_MAPPED => MemoryType::Mapped,
194            MEM_PRIVATE => MemoryType::Private,
195            _ => MemoryType::Unknown,
196        };
197
198        Some(MemoryRegion {
199            base_address: mbi.base_address,
200            allocation_base: mbi.allocation_base,
201            allocation_protect: mbi.allocation_protect,
202            region_size: mbi.region_size,
203            state,
204            protect: mbi.protect,
205            memory_type,
206        })
207    }
208}
209
210/// find all executable memory regions
211pub fn find_executable_regions() -> Vec<MemoryRegion> {
212    MemoryRegionIterator::new()
213        .filter(|r| r.is_committed() && r.is_executable())
214        .collect()
215}
216
217/// find all image (module) regions
218pub fn find_image_regions() -> Vec<MemoryRegion> {
219    MemoryRegionIterator::new()
220        .filter(|r| r.is_committed() && r.is_image())
221        .collect()
222}
223
224/// find all private memory regions
225pub fn find_private_regions() -> Vec<MemoryRegion> {
226    MemoryRegionIterator::new()
227        .filter(|r| r.is_committed() && r.memory_type == MemoryType::Private)
228        .collect()
229}
230
231/// query single memory region at address
232pub fn query_region(address: usize) -> Result<MemoryRegion> {
233    let mut mbi = MemoryBasicInformation::default();
234    // SAFETY: VirtualQuery is safe to call with valid buffer
235    let result = unsafe {
236        VirtualQuery(
237            address as *const _,
238            &mut mbi,
239            core::mem::size_of::<MemoryBasicInformation>(),
240        )
241    };
242
243    if result == 0 {
244        return Err(WraithError::ReadFailed {
245            address: u64::try_from(address).unwrap_or(u64::MAX),
246            size: 0,
247        });
248    }
249
250    let state = match mbi.state {
251        MEM_COMMIT => MemoryState::Commit,
252        MEM_RESERVE => MemoryState::Reserve,
253        _ => MemoryState::Free,
254    };
255
256    let memory_type = match mbi.memory_type {
257        MEM_IMAGE => MemoryType::Image,
258        MEM_MAPPED => MemoryType::Mapped,
259        MEM_PRIVATE => MemoryType::Private,
260        _ => MemoryType::Unknown,
261    };
262
263    Ok(MemoryRegion {
264        base_address: mbi.base_address,
265        allocation_base: mbi.allocation_base,
266        allocation_protect: mbi.allocation_protect,
267        region_size: mbi.region_size,
268        state,
269        protect: mbi.protect,
270        memory_type,
271    })
272}
273
/// Rust mirror of the Win32 `MEMORY_BASIC_INFORMATION` structure filled in by `VirtualQuery`.
///
/// Field order and types must match the C definition exactly. `repr(C)` reproduces its
/// padding: 48 bytes on 64-bit targets and 28 bytes on 32-bit targets.
#[repr(C)]
#[derive(Default)]
struct MemoryBasicInformation {
    base_address: usize,
    allocation_base: usize,
    allocation_protect: u32,
    // The Windows header includes `PartitionId` under `_WIN64`, i.e. on every 64-bit target
    // (ARM64 as well as x64). Gating on pointer width rather than `target_arch = "x86_64"`
    // keeps the declared layout faithful to the C struct on all of them.
    #[cfg(target_pointer_width = "64")]
    #[allow(dead_code)] // written by VirtualQuery, never read by this module
    partition_id: u16,
    region_size: usize,
    state: u32,
    protect: u32,
    memory_type: u32,
}
288
// Values of `MEMORY_BASIC_INFORMATION::State`.
/// Pages are committed.
const MEM_COMMIT: u32 = 0x0000_1000;
/// Address range is reserved.
const MEM_RESERVE: u32 = 0x0000_2000;
/// Address range is free.
const MEM_FREE: u32 = 0x0001_0000;
293
// Values of `MEMORY_BASIC_INFORMATION::Type`, listed in ascending order.
/// Private memory.
const MEM_PRIVATE: u32 = 0x0002_0000;
/// Memory-mapped file / section view.
const MEM_MAPPED: u32 = 0x0004_0000;
/// Mapped executable image.
const MEM_IMAGE: u32 = 0x0100_0000;
298
// Win32 base page-protection values (`PAGE_*`); each occupies one bit of the low byte.
const PAGE_NOACCESS: u32 = 1 << 0;
const PAGE_READONLY: u32 = 1 << 1;
const PAGE_READWRITE: u32 = 1 << 2;
const PAGE_WRITECOPY: u32 = 1 << 3;
const PAGE_EXECUTE: u32 = 1 << 4;
const PAGE_EXECUTE_READ: u32 = 1 << 5;
const PAGE_EXECUTE_READWRITE: u32 = 1 << 6;
const PAGE_EXECUTE_WRITECOPY: u32 = 1 << 7;
308
309#[link(name = "kernel32")]
310extern "system" {
311    fn VirtualQuery(
312        address: *const core::ffi::c_void,
313        buffer: *mut MemoryBasicInformation,
314        length: usize,
315    ) -> usize;
316}
317
#[cfg(test)]
mod tests {
    use super::*;

    /// Walking from address 0 must yield at least one region.
    #[test]
    fn test_memory_iterator() {
        let first_ten = MemoryRegionIterator::new().take(10).collect::<Vec<_>>();
        assert!(!first_ten.is_empty());
    }

    /// The test binary's own code lives in at least one executable region.
    #[test]
    fn test_find_executable() {
        let found = find_executable_regions();
        assert!(!found.is_empty());
    }

    /// The region holding this function's code is committed and executable.
    #[test]
    fn test_query_region() {
        let own_code = test_query_region as usize;
        let region = query_region(own_code).expect("should query region");
        assert!(region.is_executable());
        assert!(region.is_committed());
    }

    /// Code pages render with an execute marker in their protection string.
    #[test]
    fn test_protection_string() {
        let here = test_protection_string as usize;
        let region = query_region(here).expect("should query");
        assert!(region.protection_string().contains('X'));
    }
}