wraith/manipulation/hooks/
detector.rs

1//! Hook detection logic
2//!
3//! Detects various types of inline hooks by analyzing function prologues
4//! and comparing against known hook patterns or clean copies from disk.
5
6use core::fmt;
7
8#[cfg(all(not(feature = "std"), feature = "alloc"))]
9use alloc::{string::String, vec::Vec};
10
11#[cfg(feature = "std")]
12use std::{string::String, vec::Vec};
13
14use crate::error::{Result, WraithError};
15use crate::navigation::Module;
16use crate::structures::pe::{DataDirectoryType, ExportDirectory};
17
/// type of detected hook
///
/// derives `Hash` alongside `Eq` so findings can be grouped or counted by hook type in hashed
/// collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum HookType {
    /// direct jump at function start (jmp rel32, opcode E9)
    JmpRel32,
    /// indirect jump through a pointer slot (opcode FF 25: jmp [rip+disp32] on x64,
    /// jmp [disp32] on x86)
    JmpIndirect,
    /// mov rax, imm64; jmp rax pattern (x64, bytes 48 B8 <imm64> FF E0)
    MovJmpRax,
    /// push imm32; ret pattern (32-bit, bytes 68 <imm32> C3)
    PushRet,
    /// int3 breakpoint (opcode CC) at function start
    Breakpoint,
    /// bytes differ from clean copy but no recognized pattern
    Unknown,
}
34
35impl fmt::Display for HookType {
36    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
37        match self {
38            Self::JmpRel32 => write!(f, "jmp rel32"),
39            Self::JmpIndirect => write!(f, "jmp [addr]"),
40            Self::MovJmpRax => write!(f, "mov rax, addr; jmp rax"),
41            Self::PushRet => write!(f, "push addr; ret"),
42            Self::Breakpoint => write!(f, "int3 breakpoint"),
43            Self::Unknown => write!(f, "unknown modification"),
44        }
45    }
46}
47
48/// information about a detected hook
49#[derive(Debug, Clone)]
50pub struct HookInfo {
51    /// name of hooked function
52    pub function_name: String,
53    /// address of function start
54    pub function_address: usize,
55    /// type of hook detected
56    pub hook_type: HookType,
57    /// where the hook redirects to (if determinable)
58    pub hook_destination: Option<usize>,
59    /// original bytes at function start (if available from clean copy)
60    pub original_bytes: Vec<u8>,
61    /// current bytes at function start
62    pub hooked_bytes: Vec<u8>,
63    /// module containing the function
64    pub module_name: String,
65}
66
67impl HookInfo {
68    /// check if we have original bytes for restoration
69    pub fn can_restore(&self) -> bool {
70        !self.original_bytes.is_empty()
71    }
72}
73
74/// hook detection for a module
75pub struct HookDetector<'a> {
76    module: &'a Module<'a>,
77    clean_copy: Option<Vec<u8>>,
78}
79
80impl<'a> HookDetector<'a> {
81    /// create detector for module, attempting to load clean copy from disk
82    pub fn new(module: &'a Module<'a>) -> Result<Self> {
83        let clean_copy = Self::load_clean_copy(module).ok();
84        Ok(Self { module, clean_copy })
85    }
86
87    /// create detector with explicit clean copy bytes
88    pub fn with_clean_copy(module: &'a Module<'a>, clean_copy: Vec<u8>) -> Self {
89        Self {
90            module,
91            clean_copy: Some(clean_copy),
92        }
93    }
94
95    /// create detector without clean copy (pattern detection only)
96    pub fn without_clean_copy(module: &'a Module<'a>) -> Self {
97        Self {
98            module,
99            clean_copy: None,
100        }
101    }
102
103    /// load clean copy of module from disk
104    #[cfg(feature = "std")]
105    fn load_clean_copy(module: &Module) -> Result<Vec<u8>> {
106        let path = module.full_path();
107        std::fs::read(&path).map_err(|_| WraithError::CleanCopyUnavailable)
108    }
109
110    /// load clean copy of module from disk (no_std stub)
111    #[cfg(not(feature = "std"))]
112    fn load_clean_copy(_module: &Module) -> Result<Vec<u8>> {
113        Err(WraithError::CleanCopyUnavailable)
114    }
115
116    /// check if clean copy is available
117    pub fn has_clean_copy(&self) -> bool {
118        self.clean_copy.is_some()
119    }
120
121    /// scan all exports for hooks
122    pub fn scan_exports(&self) -> Result<Vec<HookInfo>> {
123        let mut hooks = Vec::new();
124
125        let nt = self.module.nt_headers()?;
126        let export_dir = match nt.data_directory(DataDirectoryType::Export.index()) {
127            Some(dir) if dir.is_present() => dir,
128            _ => return Ok(hooks),
129        };
130
131        let base = self.module.base();
132        // SAFETY: export directory is present and valid for loaded modules
133        let exports = unsafe {
134            &*((base + export_dir.virtual_address as usize) as *const ExportDirectory)
135        };
136
137        let num_names = exports.number_of_names as usize;
138        let names_va = base + exports.address_of_names as usize;
139        let ordinals_va = base + exports.address_of_name_ordinals as usize;
140        let functions_va = base + exports.address_of_functions as usize;
141
142        for i in 0..num_names {
143            // SAFETY: iterating within bounds of export arrays
144            let name_rva = unsafe { *((names_va + i * 4) as *const u32) };
145            let name_ptr = (base + name_rva as usize) as *const u8;
146
147            let name = unsafe {
148                let mut len = 0;
149                while *name_ptr.add(len) != 0 && len < 256 {
150                    len += 1;
151                }
152                String::from_utf8_lossy(core::slice::from_raw_parts(name_ptr, len)).to_string()
153            };
154
155            let ordinal = unsafe { *((ordinals_va + i * 2) as *const u16) };
156            let func_rva = unsafe { *((functions_va + ordinal as usize * 4) as *const u32) };
157
158            // check for forwarded export (RVA points into export directory)
159            if func_rva >= export_dir.virtual_address
160                && func_rva < export_dir.virtual_address + export_dir.size
161            {
162                continue;
163            }
164
165            let func_addr = base + func_rva as usize;
166
167            if let Some(hook_info) = self.check_function(&name, func_addr)? {
168                hooks.push(hook_info);
169            }
170        }
171
172        Ok(hooks)
173    }
174
175    /// check a single function for hooks
176    pub fn check_function(&self, name: &str, addr: usize) -> Result<Option<HookInfo>> {
177        const PROLOGUE_SIZE: usize = 32;
178
179        // read current bytes at function
180        // SAFETY: function address is valid for loaded export
181        let current_bytes: [u8; PROLOGUE_SIZE] = unsafe { *(addr as *const [u8; PROLOGUE_SIZE]) };
182
183        // first check for known hook patterns
184        if let Some((hook_type, destination)) = self.detect_hook_pattern(&current_bytes, addr) {
185            let original_bytes = self
186                .get_original_bytes(addr, PROLOGUE_SIZE)
187                .unwrap_or_default();
188
189            return Ok(Some(HookInfo {
190                function_name: name.to_string(),
191                function_address: addr,
192                hook_type,
193                hook_destination: destination,
194                original_bytes,
195                hooked_bytes: current_bytes.to_vec(),
196                module_name: self.module.name(),
197            }));
198        }
199
200        // if we have a clean copy, compare against it
201        if let Some(clean) = &self.clean_copy {
202            if let Some(rva) = self.module.va_to_rva(addr) {
203                if let Some(original) = self.get_bytes_from_pe(clean, rva as usize, PROLOGUE_SIZE) {
204                    if current_bytes[..] != original[..] {
205                        return Ok(Some(HookInfo {
206                            function_name: name.to_string(),
207                            function_address: addr,
208                            hook_type: HookType::Unknown,
209                            hook_destination: None,
210                            original_bytes: original,
211                            hooked_bytes: current_bytes.to_vec(),
212                            module_name: self.module.name(),
213                        }));
214                    }
215                }
216            }
217        }
218
219        Ok(None)
220    }
221
222    /// detect hook pattern in bytes
223    fn detect_hook_pattern(&self, bytes: &[u8], addr: usize) -> Option<(HookType, Option<usize>)> {
224        if bytes.len() < 5 {
225            return None;
226        }
227
228        // E9 XX XX XX XX - jmp rel32
229        if bytes[0] == 0xE9 {
230            let offset = i32::from_le_bytes([bytes[1], bytes[2], bytes[3], bytes[4]]);
231            let target = (addr as i64 + 5 + offset as i64) as usize;
232            return Some((HookType::JmpRel32, Some(target)));
233        }
234
235        // FF 25 XX XX XX XX - jmp [rip+disp32] (x64)
236        if bytes.len() >= 6 && bytes[0] == 0xFF && bytes[1] == 0x25 {
237            let offset = i32::from_le_bytes([bytes[2], bytes[3], bytes[4], bytes[5]]);
238            let ptr_addr = (addr as i64 + 6 + offset as i64) as usize;
239            // SAFETY: reading pointer from computed address
240            let target = unsafe { *(ptr_addr as *const usize) };
241            return Some((HookType::JmpIndirect, Some(target)));
242        }
243
244        // 48 B8 XX XX XX XX XX XX XX XX - mov rax, imm64
245        // FF E0 - jmp rax
246        if bytes.len() >= 12
247            && bytes[0] == 0x48
248            && bytes[1] == 0xB8
249            && bytes[10] == 0xFF
250            && bytes[11] == 0xE0
251        {
252            let target = u64::from_le_bytes([
253                bytes[2], bytes[3], bytes[4], bytes[5], bytes[6], bytes[7], bytes[8], bytes[9],
254            ]) as usize;
255            return Some((HookType::MovJmpRax, Some(target)));
256        }
257
258        // 68 XX XX XX XX - push imm32
259        // C3 - ret
260        if bytes.len() >= 6 && bytes[0] == 0x68 && bytes[5] == 0xC3 {
261            let target = u32::from_le_bytes([bytes[1], bytes[2], bytes[3], bytes[4]]) as usize;
262            return Some((HookType::PushRet, Some(target)));
263        }
264
265        // CC - int3 breakpoint
266        if bytes[0] == 0xCC {
267            return Some((HookType::Breakpoint, None));
268        }
269
270        None
271    }
272
273    /// get bytes from PE file at given RVA
274    fn get_bytes_from_pe(&self, pe_data: &[u8], rva: usize, len: usize) -> Option<Vec<u8>> {
275        // need to map RVA to file offset using section headers
276        let file_offset = self.rva_to_file_offset(pe_data, rva)?;
277
278        if file_offset + len <= pe_data.len() {
279            Some(pe_data[file_offset..file_offset + len].to_vec())
280        } else {
281            None
282        }
283    }
284
285    /// convert RVA to file offset in PE
286    fn rva_to_file_offset(&self, pe_data: &[u8], rva: usize) -> Option<usize> {
287        if pe_data.len() < 64 {
288            return None;
289        }
290
291        // check DOS signature
292        if pe_data[0] != 0x4D || pe_data[1] != 0x5A {
293            return None;
294        }
295
296        // get PE header offset
297        let pe_offset = u32::from_le_bytes([pe_data[0x3C], pe_data[0x3D], pe_data[0x3E], pe_data[0x3F]]) as usize;
298
299        if pe_offset + 24 > pe_data.len() {
300            return None;
301        }
302
303        // check PE signature
304        if pe_data[pe_offset..pe_offset + 4] != [0x50, 0x45, 0x00, 0x00] {
305            return None;
306        }
307
308        // get number of sections and optional header size
309        let num_sections =
310            u16::from_le_bytes([pe_data[pe_offset + 6], pe_data[pe_offset + 7]]) as usize;
311        let optional_header_size =
312            u16::from_le_bytes([pe_data[pe_offset + 20], pe_data[pe_offset + 21]]) as usize;
313
314        let section_table_offset = pe_offset + 24 + optional_header_size;
315
316        // iterate sections to find which contains the RVA
317        for i in 0..num_sections {
318            let section_offset = section_table_offset + i * 40;
319
320            if section_offset + 40 > pe_data.len() {
321                break;
322            }
323
324            let virtual_size = u32::from_le_bytes([
325                pe_data[section_offset + 8],
326                pe_data[section_offset + 9],
327                pe_data[section_offset + 10],
328                pe_data[section_offset + 11],
329            ]) as usize;
330
331            let virtual_address = u32::from_le_bytes([
332                pe_data[section_offset + 12],
333                pe_data[section_offset + 13],
334                pe_data[section_offset + 14],
335                pe_data[section_offset + 15],
336            ]) as usize;
337
338            let raw_data_ptr = u32::from_le_bytes([
339                pe_data[section_offset + 20],
340                pe_data[section_offset + 21],
341                pe_data[section_offset + 22],
342                pe_data[section_offset + 23],
343            ]) as usize;
344
345            // check if RVA falls within this section
346            if rva >= virtual_address && rva < virtual_address + virtual_size {
347                let offset_in_section = rva - virtual_address;
348                return Some(raw_data_ptr + offset_in_section);
349            }
350        }
351
352        None
353    }
354
355    /// get original bytes from clean copy (if available)
356    fn get_original_bytes(&self, addr: usize, len: usize) -> Option<Vec<u8>> {
357        let clean = self.clean_copy.as_ref()?;
358        let rva = self.module.va_to_rva(addr)?;
359        self.get_bytes_from_pe(clean, rva as usize, len)
360    }
361}
362
363/// check if a specific function is hooked
364pub fn is_hooked(module: &Module, function_name: &str) -> Result<bool> {
365    let addr = module.get_export(function_name)?;
366    let detector = HookDetector::new(module)?;
367    Ok(detector.check_function(function_name, addr)?.is_some())
368}
369
370/// check if a specific function is hooked (pattern detection only, no disk access)
371pub fn is_hooked_fast(module: &Module, function_name: &str) -> Result<bool> {
372    let addr = module.get_export(function_name)?;
373    let detector = HookDetector::without_clean_copy(module);
374    Ok(detector.check_function(function_name, addr)?.is_some())
375}
376
#[cfg(test)]
mod tests {
    use super::*;

    /// build a detector for exercising `detect_hook_pattern`, which never reads `self.module`
    ///
    /// NOTE(review): creating a `&Module` from an arbitrary address is undefined behaviour even
    /// when the reference is never dereferenced; swap in a real `Module` once one can be
    /// constructed in tests.
    fn pattern_only_detector() -> HookDetector<'static> {
        HookDetector {
            // dummy, never read by detect_hook_pattern (see NOTE above)
            module: unsafe { &*(0x1000 as *const Module) },
            clean_copy: None,
        }
    }

    #[test]
    fn test_jmp_rel32_detection() {
        // E9 01 00 00 00 = jmp +1
        let bytes = [0xE9, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00];
        let result = pattern_only_detector().detect_hook_pattern(&bytes, 0x1000);

        // 0x1000 + 5 + 1
        assert_eq!(result, Some((HookType::JmpRel32, Some(0x1006))));
    }

    #[test]
    fn test_jmp_rel32_negative_displacement() {
        // E9 F0 FF FF FF = jmp -16
        let bytes = [0xE9, 0xF0, 0xFF, 0xFF, 0xFF];
        let result = pattern_only_detector().detect_hook_pattern(&bytes, 0x1000);

        // 0x1000 + 5 - 16
        assert_eq!(result, Some((HookType::JmpRel32, Some(0x0FF5))));
    }

    #[cfg(target_arch = "x86_64")]
    #[test]
    fn test_jmp_indirect_rip_relative_detection() {
        // with disp32 = 0 the pointer slot immediately follows the 6-byte instruction, so
        // pretend the instruction sits 6 bytes before `slot`
        let slot: usize = 0xDEAD_BEEF;
        let addr = (&slot as *const usize as usize).wrapping_sub(6);
        let bytes = [0xFF, 0x25, 0x00, 0x00, 0x00, 0x00];
        let result = pattern_only_detector().detect_hook_pattern(&bytes, addr);

        assert_eq!(result, Some((HookType::JmpIndirect, Some(0xDEAD_BEEF))));
    }

    #[test]
    fn test_mov_jmp_rax_detection() {
        // 48 B8 <imm64> FF E0 = mov rax, 0x1122334455667788; jmp rax
        let bytes = [
            0x48, 0xB8, 0x88, 0x77, 0x66, 0x55, 0x44, 0x33, 0x22, 0x11, 0xFF, 0xE0,
        ];
        let result = pattern_only_detector().detect_hook_pattern(&bytes, 0x1000);

        let expected = 0x1122_3344_5566_7788_u64 as usize;
        assert_eq!(result, Some((HookType::MovJmpRax, Some(expected))));
    }

    #[test]
    fn test_push_ret_detection() {
        // 68 <imm32> C3 = push 0x12345678; ret
        let bytes = [0x68, 0x78, 0x56, 0x34, 0x12, 0xC3];
        let result = pattern_only_detector().detect_hook_pattern(&bytes, 0x1000);

        assert_eq!(result, Some((HookType::PushRet, Some(0x1234_5678))));
    }

    #[test]
    fn test_breakpoint_detection() {
        let bytes = [0xCC, 0x00, 0x00, 0x00, 0x00];
        let result = pattern_only_detector().detect_hook_pattern(&bytes, 0x1000);

        assert_eq!(result, Some((HookType::Breakpoint, None)));
    }

    #[test]
    fn test_clean_prologue_not_flagged() {
        // 48 89 5C 24 08 57 = mov [rsp+8], rbx; push rdi
        let bytes = [0x48, 0x89, 0x5C, 0x24, 0x08, 0x57];
        let result = pattern_only_detector().detect_hook_pattern(&bytes, 0x1000);

        assert_eq!(result, None);
    }

    #[test]
    fn test_short_input_not_flagged() {
        // fewer than 5 bytes is never classified, even when it starts with int3
        let bytes = [0xCC, 0x90, 0x90];
        let result = pattern_only_detector().detect_hook_pattern(&bytes, 0x1000);

        assert_eq!(result, None);
    }
}