verity_memory/
utils.rs

1use std::ptr::null_mut;
2
3use winapi::um::libloaderapi::GetModuleHandleW;
4
5use crate::w;
6
/// Reports whether `ptr` is non-null and correctly aligned for `T`.
///
/// A null pointer always yields `false`. Any other pointer yields `true`
/// exactly when its address is a multiple of `std::mem::align_of::<T>()`.
/// The pointer is only inspected, never dereferenced.
pub(crate) fn check_alignment<T>(ptr: *const T) -> bool {
    let required = std::mem::align_of::<T>();
    let address = ptr as usize;
    !ptr.is_null() && address % required == 0
}
14
15pub unsafe fn import_function<'a, F>(
16    dll_name: &str,
17    proc_name: &str,
18) -> Option<(
19    libloading::os::windows::Library,
20    libloading::os::windows::Symbol<F>,
21)>
22where
23    F: Sized,
24{
25    match libloading::os::windows::Library::new(dll_name) {
26        Ok(lib) => {
27            let proc_name_c = std::ffi::CString::new(proc_name).unwrap();
28            match lib.get::<F>(proc_name_c.as_bytes_with_nul()) {
29                Ok(symbol) => Some((lib, symbol)),
30                Err(_) => {
31                    println!("Failed to get function address for {}", proc_name);
32                    None
33                }
34            }
35        }
36        Err(_) => {
37            println!("Failed to load DLL: {}", dll_name);
38            None
39        }
40    }
41}
42
43pub fn module_base(module_name: Option<&str>) -> *mut u8 {
44    unsafe {
45
46        let handle = match module_name {
47            Some(name) => {
48                GetModuleHandleW(w!(name))
49            }
50            None => GetModuleHandleW(null_mut()),
51        };
52
53        if handle.is_null() {
54            panic!("Failed to get module handle");
55        }
56        handle as *mut u8
57    }
58}
59
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_check_alignment() {
        let x: u8 = 42;
        let ptr = &x as *const u8;
        assert!(check_alignment(ptr));

        let y: i32 = 42;
        let ptr = &y as *const i32;
        assert!(check_alignment(ptr));
    }

    #[test]
    fn test_check_alignment_unaligned() {
        // A `[u32; 2]` is aligned for `u32` (more than 1 byte on Windows targets),
        // so one byte past its start is never a valid `u32` address.
        // `wrapping_add` avoids `unsafe`; the pointer is inspected, never dereferenced.
        let buf = [0u32; 2];
        let ptr = (buf.as_ptr() as *const u8).wrapping_add(1) as *const u32;
        assert!(!check_alignment(ptr));
    }

    #[test]
    fn test_check_alignment_null() {
        assert!(!check_alignment(std::ptr::null::<u32>()));
    }

    // In the tests below, `fn()` is only a placeholder symbol type; the resolved
    // symbol is never called.

    #[test]
    fn test_import_function_fail_load() {
        let result = unsafe { import_function::<fn()>("non_existent_dll.dll", "non_existent_function") };
        assert!(result.is_none());
    }

    #[test]
    fn test_import_function_fail_get() {
        let result = unsafe { import_function::<fn()>("kernel32.dll", "non_existent_function") };
        assert!(result.is_none());
    }

    #[test]
    fn test_import_function_success() {
        let result = unsafe { import_function::<fn()>("kernel32.dll", "GetCurrentProcess") };
        assert!(result.is_some());
    }
}