wraith/manipulation/hooks/
detector.rs

1//! Hook detection logic
2//!
3//! Detects various types of inline hooks by analyzing function prologues
4//! and comparing against known hook patterns or clean copies from disk.
5
6use crate::error::{Result, WraithError};
7use crate::navigation::Module;
8use crate::structures::pe::{DataDirectoryType, ExportDirectory};
9
/// Kind of inline hook recognised at the start of a function.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum HookType {
    /// `jmp rel32` (opcode `E9`) as the first instruction
    JmpRel32,
    /// indirect jump through a pointer slot (`jmp [rip+disp32]` / `jmp [addr]`, opcode `FF 25`)
    JmpIndirect,
    /// `mov rax, imm64` followed by `jmp rax` (`48 B8 <imm64> FF E0`)
    MovJmpRax,
    /// `push imm32` followed by `ret` (`68 <imm32> C3`), typical of 32-bit hooks
    PushRet,
    /// `int3` software breakpoint (`CC`) as the first byte
    Breakpoint,
    /// prologue differs from the clean copy but matches none of the patterns above
    Unknown,
}

impl std::fmt::Display for HookType {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // human-readable description of the matched instruction sequence
        let description = match *self {
            HookType::JmpRel32 => "jmp rel32",
            HookType::JmpIndirect => "jmp [addr]",
            HookType::MovJmpRax => "mov rax, addr; jmp rax",
            HookType::PushRet => "push addr; ret",
            HookType::Breakpoint => "int3 breakpoint",
            HookType::Unknown => "unknown modification",
        };
        f.write_str(description)
    }
}
39
40/// information about a detected hook
41#[derive(Debug, Clone)]
42pub struct HookInfo {
43    /// name of hooked function
44    pub function_name: String,
45    /// address of function start
46    pub function_address: usize,
47    /// type of hook detected
48    pub hook_type: HookType,
49    /// where the hook redirects to (if determinable)
50    pub hook_destination: Option<usize>,
51    /// original bytes at function start (if available from clean copy)
52    pub original_bytes: Vec<u8>,
53    /// current bytes at function start
54    pub hooked_bytes: Vec<u8>,
55    /// module containing the function
56    pub module_name: String,
57}
58
59impl HookInfo {
60    /// check if we have original bytes for restoration
61    pub fn can_restore(&self) -> bool {
62        !self.original_bytes.is_empty()
63    }
64}
65
66/// hook detection for a module
67pub struct HookDetector<'a> {
68    module: &'a Module<'a>,
69    clean_copy: Option<Vec<u8>>,
70}
71
72impl<'a> HookDetector<'a> {
73    /// create detector for module, attempting to load clean copy from disk
74    pub fn new(module: &'a Module<'a>) -> Result<Self> {
75        let clean_copy = Self::load_clean_copy(module).ok();
76        Ok(Self { module, clean_copy })
77    }
78
79    /// create detector with explicit clean copy bytes
80    pub fn with_clean_copy(module: &'a Module<'a>, clean_copy: Vec<u8>) -> Self {
81        Self {
82            module,
83            clean_copy: Some(clean_copy),
84        }
85    }
86
87    /// create detector without clean copy (pattern detection only)
88    pub fn without_clean_copy(module: &'a Module<'a>) -> Self {
89        Self {
90            module,
91            clean_copy: None,
92        }
93    }
94
95    /// load clean copy of module from disk
96    fn load_clean_copy(module: &Module) -> Result<Vec<u8>> {
97        let path = module.full_path();
98        std::fs::read(&path).map_err(|_| WraithError::CleanCopyUnavailable)
99    }
100
101    /// check if clean copy is available
102    pub fn has_clean_copy(&self) -> bool {
103        self.clean_copy.is_some()
104    }
105
106    /// scan all exports for hooks
107    pub fn scan_exports(&self) -> Result<Vec<HookInfo>> {
108        let mut hooks = Vec::new();
109
110        let nt = self.module.nt_headers()?;
111        let export_dir = match nt.data_directory(DataDirectoryType::Export.index()) {
112            Some(dir) if dir.is_present() => dir,
113            _ => return Ok(hooks),
114        };
115
116        let base = self.module.base();
117        // SAFETY: export directory is present and valid for loaded modules
118        let exports = unsafe {
119            &*((base + export_dir.virtual_address as usize) as *const ExportDirectory)
120        };
121
122        let num_names = exports.number_of_names as usize;
123        let names_va = base + exports.address_of_names as usize;
124        let ordinals_va = base + exports.address_of_name_ordinals as usize;
125        let functions_va = base + exports.address_of_functions as usize;
126
127        for i in 0..num_names {
128            // SAFETY: iterating within bounds of export arrays
129            let name_rva = unsafe { *((names_va + i * 4) as *const u32) };
130            let name_ptr = (base + name_rva as usize) as *const u8;
131
132            let name = unsafe {
133                let mut len = 0;
134                while *name_ptr.add(len) != 0 && len < 256 {
135                    len += 1;
136                }
137                String::from_utf8_lossy(core::slice::from_raw_parts(name_ptr, len)).to_string()
138            };
139
140            let ordinal = unsafe { *((ordinals_va + i * 2) as *const u16) };
141            let func_rva = unsafe { *((functions_va + ordinal as usize * 4) as *const u32) };
142
143            // check for forwarded export (RVA points into export directory)
144            if func_rva >= export_dir.virtual_address
145                && func_rva < export_dir.virtual_address + export_dir.size
146            {
147                continue;
148            }
149
150            let func_addr = base + func_rva as usize;
151
152            if let Some(hook_info) = self.check_function(&name, func_addr)? {
153                hooks.push(hook_info);
154            }
155        }
156
157        Ok(hooks)
158    }
159
160    /// check a single function for hooks
161    pub fn check_function(&self, name: &str, addr: usize) -> Result<Option<HookInfo>> {
162        const PROLOGUE_SIZE: usize = 32;
163
164        // read current bytes at function
165        // SAFETY: function address is valid for loaded export
166        let current_bytes: [u8; PROLOGUE_SIZE] = unsafe { *(addr as *const [u8; PROLOGUE_SIZE]) };
167
168        // first check for known hook patterns
169        if let Some((hook_type, destination)) = self.detect_hook_pattern(&current_bytes, addr) {
170            let original_bytes = self
171                .get_original_bytes(addr, PROLOGUE_SIZE)
172                .unwrap_or_default();
173
174            return Ok(Some(HookInfo {
175                function_name: name.to_string(),
176                function_address: addr,
177                hook_type,
178                hook_destination: destination,
179                original_bytes,
180                hooked_bytes: current_bytes.to_vec(),
181                module_name: self.module.name(),
182            }));
183        }
184
185        // if we have a clean copy, compare against it
186        if let Some(clean) = &self.clean_copy {
187            if let Some(rva) = self.module.va_to_rva(addr) {
188                if let Some(original) = self.get_bytes_from_pe(clean, rva as usize, PROLOGUE_SIZE) {
189                    if current_bytes[..] != original[..] {
190                        return Ok(Some(HookInfo {
191                            function_name: name.to_string(),
192                            function_address: addr,
193                            hook_type: HookType::Unknown,
194                            hook_destination: None,
195                            original_bytes: original,
196                            hooked_bytes: current_bytes.to_vec(),
197                            module_name: self.module.name(),
198                        }));
199                    }
200                }
201            }
202        }
203
204        Ok(None)
205    }
206
207    /// detect hook pattern in bytes
208    fn detect_hook_pattern(&self, bytes: &[u8], addr: usize) -> Option<(HookType, Option<usize>)> {
209        if bytes.len() < 5 {
210            return None;
211        }
212
213        // E9 XX XX XX XX - jmp rel32
214        if bytes[0] == 0xE9 {
215            let offset = i32::from_le_bytes([bytes[1], bytes[2], bytes[3], bytes[4]]);
216            let target = (addr as i64 + 5 + offset as i64) as usize;
217            return Some((HookType::JmpRel32, Some(target)));
218        }
219
220        // FF 25 XX XX XX XX - jmp [rip+disp32] (x64)
221        if bytes.len() >= 6 && bytes[0] == 0xFF && bytes[1] == 0x25 {
222            let offset = i32::from_le_bytes([bytes[2], bytes[3], bytes[4], bytes[5]]);
223            let ptr_addr = (addr as i64 + 6 + offset as i64) as usize;
224            // SAFETY: reading pointer from computed address
225            let target = unsafe { *(ptr_addr as *const usize) };
226            return Some((HookType::JmpIndirect, Some(target)));
227        }
228
229        // 48 B8 XX XX XX XX XX XX XX XX - mov rax, imm64
230        // FF E0 - jmp rax
231        if bytes.len() >= 12
232            && bytes[0] == 0x48
233            && bytes[1] == 0xB8
234            && bytes[10] == 0xFF
235            && bytes[11] == 0xE0
236        {
237            let target = u64::from_le_bytes([
238                bytes[2], bytes[3], bytes[4], bytes[5], bytes[6], bytes[7], bytes[8], bytes[9],
239            ]) as usize;
240            return Some((HookType::MovJmpRax, Some(target)));
241        }
242
243        // 68 XX XX XX XX - push imm32
244        // C3 - ret
245        if bytes.len() >= 6 && bytes[0] == 0x68 && bytes[5] == 0xC3 {
246            let target = u32::from_le_bytes([bytes[1], bytes[2], bytes[3], bytes[4]]) as usize;
247            return Some((HookType::PushRet, Some(target)));
248        }
249
250        // CC - int3 breakpoint
251        if bytes[0] == 0xCC {
252            return Some((HookType::Breakpoint, None));
253        }
254
255        None
256    }
257
258    /// get bytes from PE file at given RVA
259    fn get_bytes_from_pe(&self, pe_data: &[u8], rva: usize, len: usize) -> Option<Vec<u8>> {
260        // need to map RVA to file offset using section headers
261        let file_offset = self.rva_to_file_offset(pe_data, rva)?;
262
263        if file_offset + len <= pe_data.len() {
264            Some(pe_data[file_offset..file_offset + len].to_vec())
265        } else {
266            None
267        }
268    }
269
270    /// convert RVA to file offset in PE
271    fn rva_to_file_offset(&self, pe_data: &[u8], rva: usize) -> Option<usize> {
272        if pe_data.len() < 64 {
273            return None;
274        }
275
276        // check DOS signature
277        if pe_data[0] != 0x4D || pe_data[1] != 0x5A {
278            return None;
279        }
280
281        // get PE header offset
282        let pe_offset = u32::from_le_bytes([pe_data[0x3C], pe_data[0x3D], pe_data[0x3E], pe_data[0x3F]]) as usize;
283
284        if pe_offset + 24 > pe_data.len() {
285            return None;
286        }
287
288        // check PE signature
289        if pe_data[pe_offset..pe_offset + 4] != [0x50, 0x45, 0x00, 0x00] {
290            return None;
291        }
292
293        // get number of sections and optional header size
294        let num_sections =
295            u16::from_le_bytes([pe_data[pe_offset + 6], pe_data[pe_offset + 7]]) as usize;
296        let optional_header_size =
297            u16::from_le_bytes([pe_data[pe_offset + 20], pe_data[pe_offset + 21]]) as usize;
298
299        let section_table_offset = pe_offset + 24 + optional_header_size;
300
301        // iterate sections to find which contains the RVA
302        for i in 0..num_sections {
303            let section_offset = section_table_offset + i * 40;
304
305            if section_offset + 40 > pe_data.len() {
306                break;
307            }
308
309            let virtual_size = u32::from_le_bytes([
310                pe_data[section_offset + 8],
311                pe_data[section_offset + 9],
312                pe_data[section_offset + 10],
313                pe_data[section_offset + 11],
314            ]) as usize;
315
316            let virtual_address = u32::from_le_bytes([
317                pe_data[section_offset + 12],
318                pe_data[section_offset + 13],
319                pe_data[section_offset + 14],
320                pe_data[section_offset + 15],
321            ]) as usize;
322
323            let raw_data_ptr = u32::from_le_bytes([
324                pe_data[section_offset + 20],
325                pe_data[section_offset + 21],
326                pe_data[section_offset + 22],
327                pe_data[section_offset + 23],
328            ]) as usize;
329
330            // check if RVA falls within this section
331            if rva >= virtual_address && rva < virtual_address + virtual_size {
332                let offset_in_section = rva - virtual_address;
333                return Some(raw_data_ptr + offset_in_section);
334            }
335        }
336
337        None
338    }
339
340    /// get original bytes from clean copy (if available)
341    fn get_original_bytes(&self, addr: usize, len: usize) -> Option<Vec<u8>> {
342        let clean = self.clean_copy.as_ref()?;
343        let rva = self.module.va_to_rva(addr)?;
344        self.get_bytes_from_pe(clean, rva as usize, len)
345    }
346}
347
348/// check if a specific function is hooked
349pub fn is_hooked(module: &Module, function_name: &str) -> Result<bool> {
350    let addr = module.get_export(function_name)?;
351    let detector = HookDetector::new(module)?;
352    Ok(detector.check_function(function_name, addr)?.is_some())
353}
354
355/// check if a specific function is hooked (pattern detection only, no disk access)
356pub fn is_hooked_fast(module: &Module, function_name: &str) -> Result<bool> {
357    let addr = module.get_export(function_name)?;
358    let detector = HookDetector::without_clean_copy(module);
359    Ok(detector.check_function(function_name, addr)?.is_some())
360}
361
#[cfg(test)]
mod tests {
    use super::*;

    /// Detector for exercising `detect_hook_pattern`, which never touches the module.
    ///
    /// NOTE(review): the `&Module` is fabricated from a dangling address. It is never read,
    /// but materialising a dangling reference is still undefined behaviour in Rust; turning
    /// the pattern matcher into an associated function would remove the need for this.
    fn pattern_only_detector() -> HookDetector<'static> {
        HookDetector {
            // SAFETY: see the note above - the reference is never dereferenced
            module: unsafe { &*(0x1000 as *const Module) },
            clean_copy: None,
        }
    }

    #[test]
    fn test_jmp_rel32_detection() {
        let detector = pattern_only_detector();
        // E9 01 00 00 00 = jmp +1
        let bytes = [0xE9, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00];
        let result = detector.detect_hook_pattern(&bytes, 0x1000);
        // 0x1000 + 5 (instruction length) + 1
        assert_eq!(result, Some((HookType::JmpRel32, Some(0x1006))));
    }

    #[test]
    fn test_jmp_rel32_backward_detection() {
        let detector = pattern_only_detector();
        // E9 FB FF FF FF = jmp -5, i.e. back to its own first byte
        let bytes = [0xE9, 0xFB, 0xFF, 0xFF, 0xFF];
        let result = detector.detect_hook_pattern(&bytes, 0x2000);
        assert_eq!(result, Some((HookType::JmpRel32, Some(0x2000))));
    }

    #[cfg(target_pointer_width = "64")]
    #[test]
    fn test_jmp_indirect_detection() {
        // jmp [rip+2]; two filler bytes; then the 8-byte pointer slot the jump reads
        #[repr(C, align(8))]
        struct Stub {
            code: [u8; 8],
            #[allow(dead_code)] // only read through a raw pointer by the matcher
            slot: usize,
        }
        let stub = Stub {
            code: [0xFF, 0x25, 0x02, 0x00, 0x00, 0x00, 0xCC, 0xCC],
            slot: 0x7FF0_1234_5678,
        };
        let detector = pattern_only_detector();
        // the slot lies 6 (instruction length) + 2 (displacement) bytes past the code start
        let addr = core::ptr::addr_of!(stub) as usize;
        let result = detector.detect_hook_pattern(&stub.code, addr);
        assert_eq!(result, Some((HookType::JmpIndirect, Some(0x7FF0_1234_5678))));
    }

    #[test]
    fn test_mov_jmp_rax_detection() {
        let detector = pattern_only_detector();
        // mov rax, 0x12345678; jmp rax
        let bytes = [0x48, 0xB8, 0x78, 0x56, 0x34, 0x12, 0x00, 0x00, 0x00, 0x00, 0xFF, 0xE0];
        let result = detector.detect_hook_pattern(&bytes, 0x1000);
        assert_eq!(result, Some((HookType::MovJmpRax, Some(0x1234_5678))));
    }

    #[test]
    fn test_push_ret_detection() {
        let detector = pattern_only_detector();
        // push 0x12345678; ret
        let bytes = [0x68, 0x78, 0x56, 0x34, 0x12, 0xC3];
        let result = detector.detect_hook_pattern(&bytes, 0x1000);
        assert_eq!(result, Some((HookType::PushRet, Some(0x1234_5678))));
    }

    #[test]
    fn test_breakpoint_detection() {
        let detector = pattern_only_detector();
        let bytes = [0xCC, 0x00, 0x00, 0x00, 0x00];
        let result = detector.detect_hook_pattern(&bytes, 0x1000);
        assert_eq!(result, Some((HookType::Breakpoint, None)));
    }

    #[test]
    fn test_ordinary_prologue_not_flagged() {
        let detector = pattern_only_detector();
        // mov [rsp+8], rbx; push rdi
        let bytes = [0x48, 0x89, 0x5C, 0x24, 0x08, 0x57];
        assert_eq!(detector.detect_hook_pattern(&bytes, 0x1000), None);
    }

    #[test]
    fn test_too_short_input_ignored() {
        let detector = pattern_only_detector();
        // a jmp opcode without its full 4-byte displacement
        let bytes = [0xE9, 0x00, 0x00, 0x00];
        assert_eq!(detector.detect_hook_pattern(&bytes, 0x1000), None);
    }
}