1use std::ptr::null_mut;
2
3use winapi::um::libloaderapi::GetModuleHandleW;
4
5use crate::w;
6
/// Reports whether `ptr` is non-null and its address is a multiple of
/// `align_of::<T>()`.
///
/// The pointer is only inspected as an integer address; it is never
/// dereferenced. A null pointer always yields `false`, even though address
/// zero would trivially satisfy the modulo check.
pub(crate) fn check_alignment<T>(ptr: *const T) -> bool {
    let address = ptr as usize;
    let required = std::mem::align_of::<T>();
    !ptr.is_null() && address % required == 0
}
14
15pub unsafe fn import_function<'a, F>(
16 dll_name: &str,
17 proc_name: &str,
18) -> Option<(
19 libloading::os::windows::Library,
20 libloading::os::windows::Symbol<F>,
21)>
22where
23 F: Sized,
24{
25 match libloading::os::windows::Library::new(dll_name) {
26 Ok(lib) => {
27 let proc_name_c = std::ffi::CString::new(proc_name).unwrap();
28 match lib.get::<F>(proc_name_c.as_bytes_with_nul()) {
29 Ok(symbol) => Some((lib, symbol)),
30 Err(_) => {
31 println!("Failed to get function address for {}", proc_name);
32 None
33 }
34 }
35 }
36 Err(_) => {
37 println!("Failed to load DLL: {}", dll_name);
38 None
39 }
40 }
41}
42
43pub fn module_base(module_name: Option<&str>) -> *mut u8 {
44 unsafe {
45
46 let handle = match module_name {
47 Some(name) => {
48 GetModuleHandleW(w!(name))
49 }
50 None => GetModuleHandleW(null_mut()),
51 };
52
53 if handle.is_null() {
54 panic!("Failed to get module handle");
55 }
56 handle as *mut u8
57 }
58}
59
#[cfg(test)]
mod tests {
    use super::*;

    /// Pointers taken from references to live values are properly aligned.
    #[test]
    fn test_check_alignment() {
        let x: u8 = 42;
        let ptr = &x as *const u8;
        assert!(check_alignment(ptr));

        let y: i32 = 42;
        let ptr = &y as *const i32;
        assert!(check_alignment(ptr));
    }

    /// A `u32` pointer shifted by one byte must be rejected.
    #[test]
    fn test_check_alignment_unaligned() {
        let words = [0u32; 2];
        // Offset an aligned address by a single byte. The pointer is only
        // inspected as an address and never dereferenced.
        let skewed = words.as_ptr().cast::<u8>().wrapping_add(1).cast::<u32>();
        assert!(!check_alignment(skewed));
    }

    /// Null is rejected regardless of the pointee type.
    #[test]
    fn test_check_alignment_null() {
        assert!(!check_alignment(std::ptr::null::<u8>()));
        assert!(!check_alignment(std::ptr::null::<u32>()));
    }

    #[test]
    fn test_import_function_fail_load() {
        let result = unsafe {
            import_function::<fn()>("non_existent_dll.dll", "non_existent_function")
        };
        assert!(result.is_none());
    }

    #[test]
    fn test_import_function_fail_get() {
        let result = unsafe { import_function::<fn()>("kernel32.dll", "non_existent_function") };
        assert!(result.is_none());
    }

    // The symbol is only resolved here, never called, so the placeholder
    // `fn()` type is sufficient for this check.
    #[test]
    fn test_import_function_success() {
        let result = unsafe { import_function::<fn()>("kernel32.dll", "GetCurrentProcess") };
        assert!(result.is_some());
    }
}