wraith/manipulation/manual_map/
allocator.rs

1//! Memory allocation for manual mapping
2
3use crate::error::{Result, WraithError};
4
// Win32 flag values handed to VirtualAlloc / VirtualFree / VirtualProtect below.

/// Allocation-type flag: commit the pages in the requested range.
const MEM_COMMIT: u32 = 0x0000_1000;
/// Allocation-type flag: reserve the address range.
const MEM_RESERVE: u32 = 0x0000_2000;
/// Free-type flag: release the whole allocation.
const MEM_RELEASE: u32 = 0x0000_8000;
/// Page-protection flag: pages are readable and writable.
const PAGE_READWRITE: u32 = 0x0000_0004;
10
/// Owned region of virtual memory that a PE image gets mapped into.
///
/// The region is released when the value is dropped.
pub struct MappedMemory {
    /// Address of the first byte of the region.
    base: *mut u8,
    /// Length of the region in bytes.
    size: usize,
}
16
17impl MappedMemory {
18    /// get base address
19    pub fn base(&self) -> usize {
20        self.base as usize
21    }
22
23    /// get allocated size
24    pub fn size(&self) -> usize {
25        self.size
26    }
27
28    /// get mutable slice to entire region
29    pub fn as_slice_mut(&mut self) -> &mut [u8] {
30        // SAFETY: base is valid for size bytes, we own the memory
31        unsafe { core::slice::from_raw_parts_mut(self.base, self.size) }
32    }
33
34    /// get immutable slice to entire region
35    pub fn as_slice(&self) -> &[u8] {
36        // SAFETY: base is valid for size bytes, we own the memory
37        unsafe { core::slice::from_raw_parts(self.base, self.size) }
38    }
39
40    /// write data at offset
41    pub fn write_at(&mut self, offset: usize, data: &[u8]) -> Result<()> {
42        if offset + data.len() > self.size {
43            return Err(WraithError::WriteFailed {
44                address: (self.base as usize + offset) as u64,
45                size: data.len(),
46            });
47        }
48
49        // SAFETY: bounds checked, we own the memory
50        unsafe {
51            core::ptr::copy_nonoverlapping(data.as_ptr(), self.base.add(offset), data.len());
52        }
53        Ok(())
54    }
55
56    /// read value at offset
57    pub fn read_at<T: Copy>(&self, offset: usize) -> Result<T> {
58        if offset + core::mem::size_of::<T>() > self.size {
59            return Err(WraithError::ReadFailed {
60                address: (self.base as usize + offset) as u64,
61                size: core::mem::size_of::<T>(),
62            });
63        }
64
65        // SAFETY: bounds checked, read_unaligned handles alignment
66        Ok(unsafe { (self.base.add(offset) as *const T).read_unaligned() })
67    }
68
69    /// write value at offset
70    pub fn write_value_at<T>(&mut self, offset: usize, value: T) -> Result<()> {
71        if offset + core::mem::size_of::<T>() > self.size {
72            return Err(WraithError::WriteFailed {
73                address: (self.base as usize + offset) as u64,
74                size: core::mem::size_of::<T>(),
75            });
76        }
77
78        // SAFETY: bounds checked, write_unaligned handles alignment
79        unsafe {
80            (self.base.add(offset) as *mut T).write_unaligned(value);
81        }
82        Ok(())
83    }
84
85    /// set memory protection for a region
86    pub fn protect(&self, offset: usize, size: usize, protection: u32) -> Result<u32> {
87        if offset + size > self.size {
88            return Err(WraithError::ProtectionChangeFailed {
89                address: (self.base as usize + offset) as u64,
90                size,
91            });
92        }
93
94        let mut old_protect: u32 = 0;
95
96        // SAFETY: address is within our allocated range
97        let result = unsafe {
98            VirtualProtect(
99                self.base.add(offset) as *mut _,
100                size,
101                protection,
102                &mut old_protect,
103            )
104        };
105
106        if result == 0 {
107            return Err(WraithError::ProtectionChangeFailed {
108                address: (self.base as usize + offset) as u64,
109                size,
110            });
111        }
112
113        Ok(old_protect)
114    }
115
116    /// free the allocated memory
117    pub fn free(self) -> Result<()> {
118        // SAFETY: self.base was allocated with VirtualAlloc
119        let result = unsafe { VirtualFree(self.base as *mut _, 0, MEM_RELEASE) };
120
121        if result == 0 {
122            return Err(WraithError::from_last_error("VirtualFree"));
123        }
124
125        // prevent Drop from double-freeing
126        core::mem::forget(self);
127        Ok(())
128    }
129
130    /// get pointer at offset
131    pub fn ptr_at(&self, offset: usize) -> *mut u8 {
132        // SAFETY: caller responsible for bounds
133        unsafe { self.base.add(offset) }
134    }
135}
136
137impl Drop for MappedMemory {
138    fn drop(&mut self) {
139        // SAFETY: self.base was allocated with VirtualAlloc
140        unsafe {
141            VirtualFree(self.base as *mut _, 0, MEM_RELEASE);
142        }
143    }
144}
145
146// SAFETY: we own the memory, safe to move between threads
147unsafe impl Send for MappedMemory {}
148unsafe impl Sync for MappedMemory {}
149
150/// allocate memory for PE image, trying preferred base first
151pub fn allocate_image(size: usize, preferred_base: usize) -> Result<MappedMemory> {
152    // try preferred base first
153    let mut base = unsafe {
154        VirtualAlloc(
155            preferred_base as *mut _,
156            size,
157            MEM_COMMIT | MEM_RESERVE,
158            PAGE_READWRITE,
159        )
160    };
161
162    // fall back to any available address
163    if base.is_null() {
164        base = unsafe {
165            VirtualAlloc(
166                core::ptr::null_mut(),
167                size,
168                MEM_COMMIT | MEM_RESERVE,
169                PAGE_READWRITE,
170            )
171        };
172    }
173
174    if base.is_null() {
175        return Err(WraithError::AllocationFailed {
176            size,
177            protection: PAGE_READWRITE,
178        });
179    }
180
181    // zero the memory
182    // SAFETY: base is valid for size bytes
183    unsafe {
184        core::ptr::write_bytes(base, 0, size);
185    }
186
187    Ok(MappedMemory {
188        base: base as *mut u8,
189        size,
190    })
191}
192
193/// allocate memory at specific address (fails if not available)
194pub fn allocate_at(base: usize, size: usize) -> Result<MappedMemory> {
195    let ptr = unsafe {
196        VirtualAlloc(
197            base as *mut _,
198            size,
199            MEM_COMMIT | MEM_RESERVE,
200            PAGE_READWRITE,
201        )
202    };
203
204    // must get exact address requested
205    if ptr.is_null() || ptr as usize != base {
206        if !ptr.is_null() {
207            // got wrong address, free it
208            unsafe {
209                VirtualFree(ptr, 0, MEM_RELEASE);
210            }
211        }
212        return Err(WraithError::AllocationFailed {
213            size,
214            protection: PAGE_READWRITE,
215        });
216    }
217
218    // zero the memory
219    // SAFETY: ptr is valid for size bytes
220    unsafe {
221        core::ptr::write_bytes(ptr, 0, size);
222    }
223
224    Ok(MappedMemory {
225        base: ptr as *mut u8,
226        size,
227    })
228}
229
230/// allocate memory anywhere (no preference)
231pub fn allocate_anywhere(size: usize) -> Result<MappedMemory> {
232    let base = unsafe {
233        VirtualAlloc(
234            core::ptr::null_mut(),
235            size,
236            MEM_COMMIT | MEM_RESERVE,
237            PAGE_READWRITE,
238        )
239    };
240
241    if base.is_null() {
242        return Err(WraithError::AllocationFailed {
243            size,
244            protection: PAGE_READWRITE,
245        });
246    }
247
248    // zero the memory
249    // SAFETY: base is valid for size bytes
250    unsafe {
251        core::ptr::write_bytes(base, 0, size);
252    }
253
254    Ok(MappedMemory {
255        base: base as *mut u8,
256        size,
257    })
258}
259
// Raw kernel32 imports used by the allocator above.
#[link(name = "kernel32")]
extern "system" {
    /// Reserves and/or commits pages; callers treat a null return as failure.
    fn VirtualAlloc(
        lp_address: *mut core::ffi::c_void,
        dw_size: usize,
        fl_allocation_type: u32,
        fl_protect: u32,
    ) -> *mut core::ffi::c_void;

    /// Frees pages obtained from `VirtualAlloc`; callers treat 0 as failure.
    fn VirtualFree(lp_address: *mut core::ffi::c_void, dw_size: usize, dw_free_type: u32) -> i32;

    /// Changes page protection and stores the previous value through
    /// `lpfl_old_protect`; callers treat 0 as failure.
    fn VirtualProtect(
        lp_address: *mut core::ffi::c_void,
        dw_size: usize,
        fl_new_protect: u32,
        lpfl_old_protect: *mut u32,
    ) -> i32;
}
278
#[cfg(test)]
mod tests {
    use super::*;

    /// Size in bytes of the scratch region every test allocates (4 KiB).
    const REGION_SIZE: usize = 0x1000;

    /// A fresh allocation has a non-null base, the requested size, and frees cleanly.
    #[test]
    fn test_allocate_anywhere() {
        let region = allocate_anywhere(REGION_SIZE).expect("should allocate");
        assert_ne!(region.base(), 0);
        assert_eq!(region.size(), REGION_SIZE);
        region.free().expect("should free");
    }

    /// Typed values and raw byte runs written into the region read back unchanged.
    #[test]
    fn test_read_write() {
        let mut region = allocate_anywhere(REGION_SIZE).expect("should allocate");

        // typed round trip at the start of the region
        region
            .write_value_at(0, 0xDEADBEEFu32)
            .expect("should write");
        let round_trip: u32 = region.read_at(0).expect("should read");
        assert_eq!(round_trip, 0xDEADBEEF);

        // raw bytes at an interior offset, checked through the slice view
        let payload = [1u8, 2, 3, 4];
        region.write_at(0x100, &payload).expect("should write bytes");
        let bytes = region.as_slice();
        assert_eq!(&bytes[0x100..0x104], &payload);
    }

    /// Changing protection reports the protection the region was allocated with.
    #[test]
    fn test_protect() {
        const PAGE_READONLY: u32 = 0x02;

        let region = allocate_anywhere(REGION_SIZE).expect("should allocate");
        let previous = region
            .protect(0, REGION_SIZE, PAGE_READONLY)
            .expect("should protect");
        assert_eq!(previous, PAGE_READWRITE);
    }
}