1use std::ffi::c_void;
4use std::path::Path;
5use std::{
6 io,
7 ptr,
8};
9
10use windows::Win32::Foundation::{
11 FreeLibrary,
12 HINSTANCE,
13 HMODULE,
14};
15use windows::Win32::System::LibraryLoader::{
16 GetModuleHandleExW,
17 GetProcAddress,
18 LOAD_LIBRARY_AS_DATAFILE,
19 LOAD_LIBRARY_AS_IMAGE_RESOURCE,
20 LOAD_LIBRARY_FLAGS,
21 LoadLibraryExW,
22};
23use windows::core::PCSTR;
24
25use crate::internal::ResultExt;
26use crate::string::{
27 ZeroTerminatedString,
28 ZeroTerminatedWideString,
29};
30
31#[derive(Eq, PartialEq, Debug)]
33pub struct ExecutableModule {
34 raw_handle: HMODULE,
35}
36
37impl ExecutableModule {
38 pub fn from_current_process_exe() -> io::Result<Self> {
40 Self::get_loaded_internal(None::<&Path>)
41 }
42
43 pub fn from_loaded<A: AsRef<Path>>(name: A) -> io::Result<Self> {
45 Self::get_loaded_internal(Some(name))
46 }
47
48 fn get_loaded_internal(name: Option<impl AsRef<Path>>) -> io::Result<Self> {
49 let name_wide = name.map(|x| ZeroTerminatedWideString::from_os_str(x.as_ref()));
50 let name_param = name_wide
51 .as_ref()
52 .map(ZeroTerminatedWideString::as_raw_pcwstr);
53 let mut raw_handle: HMODULE = Default::default();
54 unsafe { GetModuleHandleExW(0, name_param.as_ref(), &raw mut raw_handle) }?;
55 Ok(ExecutableModule { raw_handle })
56 }
57
58 pub fn load_module_as_data_file<P: AsRef<Path>>(file_name: P) -> io::Result<Self> {
60 Self::load_module_internal(
61 file_name,
62 LOAD_LIBRARY_AS_DATAFILE | LOAD_LIBRARY_AS_IMAGE_RESOURCE,
63 )
64 }
65
66 pub fn load_module<P: AsRef<Path>>(file_name: P) -> io::Result<Self> {
68 Self::load_module_internal(file_name, Default::default())
69 }
70
71 fn load_module_internal(
72 file_name: impl AsRef<Path>,
73 flags: LOAD_LIBRARY_FLAGS,
74 ) -> io::Result<Self> {
75 let file_name = ZeroTerminatedWideString::from_os_str(file_name.as_ref());
76 let raw_handle: HMODULE =
77 unsafe { LoadLibraryExW(file_name.as_raw_pcwstr(), None, flags) }?;
78 Ok(ExecutableModule { raw_handle })
79 }
80
81 pub fn get_symbol_ptr_by_ordinal(&self, ordinal: u16) -> io::Result<*const c_void> {
82 self.get_symbol_ptr(&SymbolIdentifier::from(ordinal))
83 }
84
85 pub fn get_symbol_ptr_by_name<S: AsRef<str>>(&self, name: S) -> io::Result<*const c_void> {
86 self.get_symbol_ptr(&SymbolIdentifier::from(name.as_ref()))
87 }
88
89 fn get_symbol_ptr(&self, symbol: &SymbolIdentifier) -> io::Result<*const c_void> {
90 let symbol_ptr = unsafe { GetProcAddress(self.as_hmodule(), symbol.as_param()) }
91 .ok_or_else(io::Error::last_os_error)?;
92 Ok(ptr::with_exposed_provenance(symbol_ptr as usize))
93 }
94
95 pub(crate) fn as_hmodule(&self) -> HMODULE {
96 self.raw_handle
97 }
98
99 #[allow(dead_code)]
100 pub(crate) fn as_hinstance(&self) -> HINSTANCE {
101 self.as_hmodule().into()
102 }
103}
104
105impl Drop for ExecutableModule {
106 fn drop(&mut self) {
107 unsafe { FreeLibrary(self.as_hmodule()) }.unwrap_or_default_and_print_error();
108 }
109}
110
111#[derive(Clone, PartialEq, Eq, Debug)]
112enum SymbolIdentifier {
113 Ordinal(u16),
114 Name(ZeroTerminatedString),
115}
116
117impl SymbolIdentifier {
118 fn as_param(&self) -> PCSTR {
119 match self {
120 SymbolIdentifier::Ordinal(ordinal) => PCSTR(usize::from(*ordinal) as *const u8),
121 SymbolIdentifier::Name(name) => name.as_raw_pcstr(),
122 }
123 }
124}
125
126impl From<u16> for SymbolIdentifier {
127 fn from(value: u16) -> Self {
128 Self::Ordinal(value)
129 }
130}
131
132impl From<&str> for SymbolIdentifier {
133 fn from(value: &str) -> Self {
134 Self::Name(ZeroTerminatedString::from(value))
135 }
136}
137
#[cfg(test)]
mod tests {
    use super::*;

    /// Looking up the current process executable yields a valid handle.
    #[test]
    fn get_current_exe_module() -> io::Result<()> {
        let exe = ExecutableModule::from_current_process_exe()?;
        let handle = exe.as_hmodule();
        assert!(!handle.is_invalid());
        Ok(())
    }

    /// Loading `shell32.dll` as a data file yields a valid handle.
    #[test]
    fn load_shell32_module() -> io::Result<()> {
        let shell32 = ExecutableModule::load_module_as_data_file("shell32.dll")?;
        let handle = shell32.as_hmodule();
        assert!(!handle.is_invalid());
        Ok(())
    }

    /// A symbol looked up by name in `kernel32.dll` resolves to a non-null
    /// pointer.
    #[test]
    fn get_symbol_ptr() -> io::Result<()> {
        let kernel32 = ExecutableModule::from_loaded("kernel32.dll")?;
        let address = kernel32.get_symbol_ptr_by_name("GetProcAddress")?;
        assert!(!address.is_null());
        Ok(())
    }
}