wraith/manipulation/inline_hook/hook/
iat.rs

1//! IAT (Import Address Table) hooking
2//!
3//! IAT hooks work by modifying entries in a module's Import Address Table.
4//! When a module imports a function from another DLL, the loader fills the IAT
5//! with the actual function addresses. By replacing an IAT entry with a detour
6//! address, all calls through that import are redirected.
7//!
8//! # Advantages
9//! - No code modification (safer for integrity checks)
10//! - Easy to install and remove
11//! - Works on any imported function
12//!
13//! # Limitations
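//! - Only affects calls made through the IAT (not direct calls, nor calls through
//!   addresses obtained via GetProcAddress)
//! - Module-specific (each module has its own IAT)
//! - Does not affect function pointers copied out of the IAT before the hook was installed
//! - Does not cover delay-loaded imports (only the regular import directory is walked)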
17
18#[cfg(all(not(feature = "std"), feature = "alloc"))]
19use alloc::{format, string::String, vec::Vec};
20
21#[cfg(feature = "std")]
22use std::{format, string::String, vec::Vec};
23
24use crate::error::{Result, WraithError};
25use crate::navigation::{Module, ModuleQuery};
26use crate::structures::pe::{DataDirectoryType, ImportDescriptor, ImportByName};
27use crate::structures::Peb;
28use crate::util::memory::ProtectionGuard;
29
/// Win32 `PAGE_READWRITE` page-protection value, applied while an IAT slot is rewritten.
const PAGE_READWRITE: u32 = 0x0000_0004;
31
32#[cfg(target_arch = "x86_64")]
33use crate::structures::pe::{ThunkData64 as ThunkData, IMAGE_ORDINAL_FLAG64 as IMAGE_ORDINAL_FLAG};
34#[cfg(target_arch = "x86")]
35use crate::structures::pe::{ThunkData32 as ThunkData, IMAGE_ORDINAL_FLAG32 as IMAGE_ORDINAL_FLAG};
36
/// A single slot of a module's Import Address Table, as found by walking its imports.
///
/// Entries are plain data snapshots: `current_value` is the value read when the
/// entry was enumerated and is not updated afterwards.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct IatEntry {
    /// Address of the IAT slot itself (a pointer to the function pointer).
    pub entry_address: usize,
    /// Value stored in the slot at enumeration time (the address calls currently go to).
    pub current_value: usize,
    /// Name of the imported function, when it is imported by name and the name could be read.
    pub function_name: Option<String>,
    /// Ordinal of the imported function, when it is imported by ordinal.
    pub ordinal: Option<u16>,
    /// Name of the DLL the function is imported from, as written in the import descriptor.
    pub dll_name: String,
}
51
/// A hook on a single IAT slot: the slot is redirected to `detour`, and the
/// previous value is remembered so it can be written back.
///
/// With the default settings the original value is restored when the hook is
/// dropped; `set_auto_restore`, `leak` and `restore` change that behavior.
#[derive(Debug)]
#[must_use = "dropping an IatHook restores the original IAT entry unless auto-restore is disabled"]
pub struct IatHook {
    /// address of the IAT slot being hooked
    iat_entry: usize,
    /// value that was in the slot before the hook was created
    original: usize,
    /// value written into the slot while the hook is active
    detour: usize,
    /// whether the slot currently holds `detour`
    active: bool,
    /// whether `Drop` should write `original` back
    auto_restore: bool,
}
65
66impl IatHook {
67    /// create and install an IAT hook
68    ///
69    /// # Arguments
70    /// * `target_module` - the module whose IAT to modify
71    /// * `import_dll` - the DLL name containing the function to hook (e.g., "kernel32.dll")
72    /// * `function_name` - the function name to hook
73    /// * `detour` - address of the detour function
74    ///
75    /// # Example
76    /// ```ignore
77    /// let hook = IatHook::new("myapp.exe", "kernel32.dll", "CreateFileW", my_detour as usize)?;
78    /// // calls to CreateFileW from myapp.exe now go to my_detour
79    /// ```
80    pub fn new(
81        target_module: &str,
82        import_dll: &str,
83        function_name: &str,
84        detour: usize,
85    ) -> Result<Self> {
86        let peb = Peb::current()?;
87        let query = ModuleQuery::new(&peb);
88        let module = query.find_by_name(target_module)?;
89
90        Self::new_in_module(&module, import_dll, function_name, detour)
91    }
92
93    /// create and install an IAT hook in a specific module
94    pub fn new_in_module(
95        module: &Module,
96        import_dll: &str,
97        function_name: &str,
98        detour: usize,
99    ) -> Result<Self> {
100        let iat_entry = find_iat_entry(module, import_dll, function_name)?;
101        Self::new_at_address(iat_entry.entry_address, detour)
102    }
103
104    /// create and install an IAT hook at a specific IAT entry address
105    ///
106    /// use this when you already know the IAT entry address
107    pub fn new_at_address(iat_entry: usize, detour: usize) -> Result<Self> {
108        if iat_entry == 0 {
109            return Err(WraithError::NullPointer { context: "iat_entry" });
110        }
111
112        // read original value
113        // SAFETY: iat_entry points to valid IAT entry
114        let original = unsafe { *(iat_entry as *const usize) };
115
116        let mut hook = Self {
117            iat_entry,
118            original,
119            detour,
120            active: false,
121            auto_restore: true,
122        };
123
124        hook.install()?;
125        Ok(hook)
126    }
127
128    /// install the hook (write detour address to IAT)
129    pub fn install(&mut self) -> Result<()> {
130        if self.active {
131            return Ok(());
132        }
133
134        write_iat_entry(self.iat_entry, self.detour)?;
135        self.active = true;
136
137        Ok(())
138    }
139
140    /// remove the hook (restore original address)
141    pub fn uninstall(&mut self) -> Result<()> {
142        if !self.active {
143            return Ok(());
144        }
145
146        write_iat_entry(self.iat_entry, self.original)?;
147        self.active = false;
148
149        Ok(())
150    }
151
152    /// check if hook is active
153    pub fn is_active(&self) -> bool {
154        self.active
155    }
156
157    /// get the original function address
158    pub fn original(&self) -> usize {
159        self.original
160    }
161
162    /// get the detour function address
163    pub fn detour(&self) -> usize {
164        self.detour
165    }
166
167    /// get the IAT entry address
168    pub fn iat_entry(&self) -> usize {
169        self.iat_entry
170    }
171
172    /// set whether to auto-restore on drop
173    pub fn set_auto_restore(&mut self, restore: bool) {
174        self.auto_restore = restore;
175    }
176
177    /// leak the hook (keep active after drop)
178    pub fn leak(mut self) {
179        self.auto_restore = false;
180        core::mem::forget(self);
181    }
182
183    /// consume the hook and restore the original
184    pub fn restore(mut self) -> Result<()> {
185        self.uninstall()?;
186        self.auto_restore = false;
187        Ok(())
188    }
189}
190
191impl Drop for IatHook {
192    fn drop(&mut self) {
193        if self.auto_restore && self.active {
194            let _ = self.uninstall();
195        }
196    }
197}
198
199// SAFETY: IAT hook operates on process-wide memory
200unsafe impl Send for IatHook {}
201unsafe impl Sync for IatHook {}
202
203/// RAII guard for an IAT hook
204pub type IatHookGuard = IatHook;
205
206/// find an IAT entry for a specific import
207pub fn find_iat_entry(
208    module: &Module,
209    import_dll: &str,
210    function_name: &str,
211) -> Result<IatEntry> {
212    let entries = enumerate_iat_entries(module)?;
213    let import_dll_lower = import_dll.to_lowercase();
214    let function_name_lower = function_name.to_lowercase();
215
216    for entry in entries {
217        let dll_matches = entry.dll_name.to_lowercase() == import_dll_lower
218            || entry.dll_name.to_lowercase().trim_end_matches(".dll")
219                == import_dll_lower.trim_end_matches(".dll");
220
221        if dll_matches {
222            if let Some(ref name) = entry.function_name {
223                if name.to_lowercase() == function_name_lower {
224                    return Ok(entry);
225                }
226            }
227        }
228    }
229
230    Err(WraithError::ModuleNotFound {
231        name: format!("IAT entry for {}!{}", import_dll, function_name),
232    })
233}
234
235/// enumerate all IAT entries in a module
236pub fn enumerate_iat_entries(module: &Module) -> Result<Vec<IatEntry>> {
237    let nt = module.nt_headers()?;
238    let import_dir = nt
239        .data_directory(DataDirectoryType::Import.index())
240        .ok_or_else(|| WraithError::InvalidPeFormat {
241            reason: "no import directory".into(),
242        })?;
243
244    if !import_dir.is_present() {
245        return Ok(Vec::new());
246    }
247
248    let base = module.base();
249    let mut entries = Vec::new();
250
251    // iterate import descriptors
252    let mut desc_va = base + import_dir.virtual_address as usize;
253    loop {
254        // SAFETY: desc_va points to valid import descriptor in loaded module
255        let desc = unsafe { &*(desc_va as *const ImportDescriptor) };
256
257        if desc.is_null() {
258            break;
259        }
260
261        // get DLL name
262        let dll_name_va = base + desc.name as usize;
263        let dll_name = read_cstring(dll_name_va, 256)?;
264
265        // get IAT and INT (Import Name Table)
266        let iat_va = base + desc.first_thunk as usize;
267        let int_va = if desc.original_first_thunk != 0 {
268            base + desc.original_first_thunk as usize
269        } else {
270            iat_va // use IAT if INT is not present
271        };
272
273        // iterate thunks
274        let mut thunk_idx = 0usize;
275        loop {
276            let thunk_size = core::mem::size_of::<ThunkData>();
277            let iat_entry_addr = iat_va + thunk_idx * thunk_size;
278            let int_entry_addr = int_va + thunk_idx * thunk_size;
279
280            // SAFETY: reading thunk data from loaded module
281            let iat_thunk = unsafe { *(iat_entry_addr as *const usize) };
282            if iat_thunk == 0 {
283                break;
284            }
285
286            let int_thunk = unsafe { *(int_entry_addr as *const usize) };
287
288            let (function_name, ordinal) = if is_ordinal_import(int_thunk) {
289                (None, Some(get_ordinal(int_thunk)))
290            } else {
291                // import by name
292                let hint_name_va = base + (int_thunk & !IMAGE_ORDINAL_FLAG as usize);
293                // SAFETY: hint_name_va points to valid IMAGE_IMPORT_BY_NAME
294                let hint_name = unsafe { &*(hint_name_va as *const ImportByName) };
295                let name_ptr = hint_name.name.as_ptr();
296                let name = read_cstring(name_ptr as usize, 256).ok();
297                (name, None)
298            };
299
300            entries.push(IatEntry {
301                entry_address: iat_entry_addr,
302                current_value: iat_thunk,
303                function_name,
304                ordinal,
305                dll_name: dll_name.clone(),
306            });
307
308            thunk_idx += 1;
309        }
310
311        desc_va += core::mem::size_of::<ImportDescriptor>();
312    }
313
314    Ok(entries)
315}
316
317/// hook an import in the current module
318pub fn hook_import(
319    import_dll: &str,
320    function_name: &str,
321    detour: usize,
322) -> Result<IatHook> {
323    let peb = Peb::current()?;
324    let query = ModuleQuery::new(&peb);
325    let current = query.current_module()?;
326
327    IatHook::new_in_module(&current, import_dll, function_name, detour)
328}
329
330/// hook an import in any module that imports it
331pub fn hook_import_all(
332    import_dll: &str,
333    function_name: &str,
334    detour: usize,
335) -> Result<Vec<IatHook>> {
336    let peb = Peb::current()?;
337    let mut hooks = Vec::new();
338
339    for module in crate::navigation::ModuleIterator::new(&peb, crate::navigation::ModuleListType::InLoadOrder)? {
340        if let Ok(hook) = IatHook::new_in_module(&module, import_dll, function_name, detour) {
341            hooks.push(hook);
342        }
343    }
344
345    if hooks.is_empty() {
346        Err(WraithError::ModuleNotFound {
347            name: format!("IAT entry for {}!{} in any module", import_dll, function_name),
348        })
349    } else {
350        Ok(hooks)
351    }
352}
353
354/// write a value to an IAT entry
355fn write_iat_entry(entry: usize, value: usize) -> Result<()> {
356    let _guard = ProtectionGuard::new(entry, core::mem::size_of::<usize>(), PAGE_READWRITE)?;
357
358    // SAFETY: entry is valid IAT address, protection changed to RW
359    unsafe {
360        *(entry as *mut usize) = value;
361    }
362
363    Ok(())
364}
365
366/// read a null-terminated C string
367fn read_cstring(addr: usize, max_len: usize) -> Result<String> {
368    let mut bytes = Vec::new();
369
370    for i in 0..max_len {
371        // SAFETY: reading bytes within max_len
372        let byte = unsafe { *((addr + i) as *const u8) };
373        if byte == 0 {
374            break;
375        }
376        bytes.push(byte);
377    }
378
379    String::from_utf8(bytes).map_err(|_| WraithError::InvalidPeFormat {
380        reason: "invalid string encoding".into(),
381    })
382}
383
384/// check if a thunk is an ordinal import
385#[cfg(target_arch = "x86_64")]
386fn is_ordinal_import(thunk: usize) -> bool {
387    (thunk as u64 & IMAGE_ORDINAL_FLAG) != 0
388}
389
390#[cfg(target_arch = "x86")]
391fn is_ordinal_import(thunk: usize) -> bool {
392    (thunk as u32 & IMAGE_ORDINAL_FLAG) != 0
393}
394
/// Extracts the 16-bit ordinal held in the low word of an ordinal-import thunk.
fn get_ordinal(thunk: usize) -> u16 {
    // truncation is intended: only the low 16 bits carry the ordinal; the upper
    // bits (including IMAGE_ORDINAL_FLAG) are discarded
    thunk as u16
}
399
#[cfg(test)]
mod tests {
    use super::*;

    /// Enumerates the IAT of the module reported as current for the test binary.
    fn own_iat_entries() -> Vec<IatEntry> {
        let peb = Peb::current().expect("should get PEB");
        let query = ModuleQuery::new(&peb);
        let current = query.current_module().expect("should get current module");
        enumerate_iat_entries(&current).expect("should enumerate IAT")
    }

    #[test]
    fn test_enumerate_iat() {
        let entries = own_iat_entries();
        assert!(!entries.is_empty(), "should have imports");
    }

    #[test]
    fn test_find_kernel32_import() {
        // most programs import something from kernel32
        let has_kernel32 = own_iat_entries()
            .iter()
            .any(|entry| entry.dll_name.to_lowercase().contains("kernel32"));
        assert!(has_kernel32, "should import from kernel32");
    }
}