wraith/manipulation/inline_hook/hook/
veh.rs

1//! VEH (Vectored Exception Handler) hooking
2//!
3//! VEH hooks use Windows Vectored Exception Handling to intercept execution.
4//! By placing a hardware breakpoint (debug register) or INT3 at the target,
5//! execution triggers an exception that our handler catches and redirects.
6//!
7//! # Advantages
//! - Minimal code modification at the hook site: hardware breakpoints patch nothing, INT3 patches a single byte
//! - Harder to detect than inline hooks
//! - Hardware breakpoints work even on read-only memory (INT3 temporarily changes page protection to write its byte)
11//! - Hardware breakpoints are invisible to code integrity checks
12//!
13//! # Limitations
14//! - Only 4 hardware breakpoints available per thread
15//! - Performance overhead from exception handling
16//! - Must manage debug registers carefully
17//! - VEH handler is visible to GetVectoredExceptionHandlerCount
18
19#[cfg(all(not(feature = "std"), feature = "alloc"))]
20use alloc::string::String;
21
22#[cfg(feature = "std")]
23use std::string::String;
24
25use crate::error::{Result, WraithError};
26use core::cell::UnsafeCell;
27use core::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
28
// values a vectored exception handler hands back to the dispatcher

/// resume the faulting thread using the (possibly modified) context
const EXCEPTION_CONTINUE_EXECUTION: i32 = -1;
/// not ours: let the next handler look at the exception
const EXCEPTION_CONTINUE_SEARCH: i32 = 0;

// exception codes the handler reacts to

/// raised when an INT3 hook fires
const EXCEPTION_BREAKPOINT: u32 = 0x8000_0003;
/// raised when a hardware (debug-register) hook fires
const EXCEPTION_SINGLE_STEP: u32 = 0x8000_0004;
36
/// one of the four debug address registers usable for hardware breakpoints
///
/// The discriminant is the register index (0-3); it is also used to locate the
/// register's control bits inside DR7.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DebugRegister {
    /// DR0 (index 0)
    Dr0 = 0x0,
    /// DR1 (index 1)
    Dr1 = 0x1,
    /// DR2 (index 2)
    Dr2 = 0x2,
    /// DR3 (index 3)
    Dr3 = 0x3,
}
46
/// DR7 R/W field: which kind of access trips the breakpoint
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BreakCondition {
    /// instruction execution at the address
    Execute = 0,
    /// data write to the address
    Write = 1,
    /// I/O access (typically not supported)
    Io = 2,
    /// data read or write at the address
    ReadWrite = 3,
}
60
/// DR7 LEN field: size of the watched region
///
/// The encoding is not ordered by size: 8 bytes is `0b10`, 4 bytes is `0b11`.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BreakLength {
    /// 1 byte
    Byte = 0,
    /// 2 bytes
    Word = 1,
    /// 8 bytes (undefined on 32-bit)
    Qword = 2,
    /// 4 bytes
    Dword = 3,
}
70
71/// VEH hook type
72#[derive(Debug, Clone, Copy, PartialEq, Eq)]
73pub enum VehHookType {
74    /// use hardware breakpoint (debug register)
75    Hardware(DebugRegister),
76    /// use INT3 software breakpoint
77    Int3,
78}
79
80/// information about a VEH hook target
81struct VehHookTarget {
82    /// address of the hooked function
83    target: usize,
84    /// address of the detour function
85    detour: usize,
86    /// original byte at target (for INT3 hooks)
87    original_byte: u8,
88    /// hook type
89    hook_type: VehHookType,
90    /// whether this slot is active
91    active: bool,
92}
93
94/// maximum number of VEH hooks we support
95const MAX_VEH_HOOKS: usize = 64;
96
97/// global VEH hook registry
98static VEH_HOOKS: VehHookRegistry = VehHookRegistry::new();
99
100/// VEH handler handle
101static VEH_HANDLER: AtomicUsize = AtomicUsize::new(0);
102
103/// flag indicating VEH handler is installed
104static VEH_INSTALLED: AtomicBool = AtomicBool::new(false);
105
106/// thread-safe hook registry
107struct VehHookRegistry {
108    hooks: UnsafeCell<[Option<VehHookTarget>; MAX_VEH_HOOKS]>,
109}
110
111impl VehHookRegistry {
112    const fn new() -> Self {
113        // initialize with None values
114        const INIT: Option<VehHookTarget> = None;
115        Self {
116            hooks: UnsafeCell::new([INIT; MAX_VEH_HOOKS]),
117        }
118    }
119
120    fn find_hook(&self, address: usize) -> Option<(usize, usize)> {
121        // SAFETY: single-threaded access during exception handling
122        let hooks = unsafe { &*self.hooks.get() };
123        for hook in hooks.iter().flatten() {
124            if hook.active && hook.target == address {
125                return Some((hook.target, hook.detour));
126            }
127        }
128        None
129    }
130
131    fn register(&self, target: VehHookTarget) -> Result<usize> {
132        // SAFETY: we maintain proper synchronization
133        let hooks = unsafe { &mut *self.hooks.get() };
134        for (i, slot) in hooks.iter_mut().enumerate() {
135            if slot.is_none() {
136                *slot = Some(target);
137                return Ok(i);
138            }
139        }
140        Err(WraithError::HookInstallFailed {
141            target: 0,
142            reason: "VEH hook registry full".into(),
143        })
144    }
145
146    fn unregister(&self, index: usize) -> Option<VehHookTarget> {
147        // SAFETY: we maintain proper synchronization
148        let hooks = unsafe { &mut *self.hooks.get() };
149        if index < MAX_VEH_HOOKS {
150            hooks[index].take()
151        } else {
152            None
153        }
154    }
155
156    fn get(&self, index: usize) -> Option<&VehHookTarget> {
157        // SAFETY: read-only access
158        let hooks = unsafe { &*self.hooks.get() };
159        hooks.get(index).and_then(|h| h.as_ref())
160    }
161}
162
163// SAFETY: we use atomic operations for synchronization
164unsafe impl Sync for VehHookRegistry {}
165
166/// VEH hook instance using hardware breakpoints
167pub struct VehHook {
168    /// registry index
169    index: usize,
170    /// hook type
171    hook_type: VehHookType,
172    /// whether to restore on drop
173    auto_restore: bool,
174}
175
176impl VehHook {
177    /// create a VEH hook using a hardware breakpoint
178    ///
179    /// # Arguments
180    /// * `target` - address of the function to hook
181    /// * `detour` - address of the detour function
182    /// * `dr` - which debug register to use (Dr0-Dr3)
183    ///
184    /// # Example
185    /// ```ignore
186    /// let hook = VehHook::hardware(target_addr, my_detour as usize, DebugRegister::Dr0)?;
187    /// ```
188    pub fn hardware(target: usize, detour: usize, dr: DebugRegister) -> Result<Self> {
189        ensure_veh_handler()?;
190
191        // set hardware breakpoint in debug registers
192        set_hardware_breakpoint(dr, target, BreakCondition::Execute, BreakLength::Byte)?;
193
194        let hook = VehHookTarget {
195            target,
196            detour,
197            original_byte: 0,
198            hook_type: VehHookType::Hardware(dr),
199            active: true,
200        };
201
202        let index = VEH_HOOKS.register(hook)?;
203
204        Ok(Self {
205            index,
206            hook_type: VehHookType::Hardware(dr),
207            auto_restore: true,
208        })
209    }
210
211    /// create a VEH hook using INT3 software breakpoint
212    ///
213    /// # Arguments
214    /// * `target` - address of the function to hook
215    /// * `detour` - address of the detour function
216    ///
217    /// # Example
218    /// ```ignore
219    /// let hook = VehHook::int3(target_addr, my_detour as usize)?;
220    /// ```
221    pub fn int3(target: usize, detour: usize) -> Result<Self> {
222        ensure_veh_handler()?;
223
224        // read original byte and write INT3
225        // SAFETY: target is valid function address
226        let original_byte = unsafe { *(target as *const u8) };
227
228        // change protection and write INT3
229        let _guard = crate::util::memory::ProtectionGuard::new(target, 1, 0x40)?;
230        unsafe {
231            *(target as *mut u8) = 0xCC;
232        }
233
234        let hook = VehHookTarget {
235            target,
236            detour,
237            original_byte,
238            hook_type: VehHookType::Int3,
239            active: true,
240        };
241
242        let index = VEH_HOOKS.register(hook)?;
243
244        Ok(Self {
245            index,
246            hook_type: VehHookType::Int3,
247            auto_restore: true,
248        })
249    }
250
251    /// check if hook is active
252    pub fn is_active(&self) -> bool {
253        VEH_HOOKS.get(self.index).map_or(false, |h| h.active)
254    }
255
256    /// get the target address
257    pub fn target(&self) -> Option<usize> {
258        VEH_HOOKS.get(self.index).map(|h| h.target)
259    }
260
261    /// get the detour address
262    pub fn detour(&self) -> Option<usize> {
263        VEH_HOOKS.get(self.index).map(|h| h.detour)
264    }
265
266    /// get the hook type
267    pub fn hook_type(&self) -> VehHookType {
268        self.hook_type
269    }
270
271    /// set whether to auto-restore on drop
272    pub fn set_auto_restore(&mut self, restore: bool) {
273        self.auto_restore = restore;
274    }
275
276    /// leak the hook (keep active after drop)
277    pub fn leak(mut self) {
278        self.auto_restore = false;
279        core::mem::forget(self);
280    }
281
282    /// restore the hook
283    pub fn restore(self) -> Result<()> {
284        if let Some(hook) = VEH_HOOKS.unregister(self.index) {
285            match hook.hook_type {
286                VehHookType::Hardware(dr) => {
287                    clear_hardware_breakpoint(dr)?;
288                }
289                VehHookType::Int3 => {
290                    // restore original byte
291                    let _guard = crate::util::memory::ProtectionGuard::new(
292                        hook.target, 1, 0x40,
293                    )?;
294                    unsafe {
295                        *(hook.target as *mut u8) = hook.original_byte;
296                    }
297                }
298            }
299        }
300
301        // prevent drop from running
302        core::mem::forget(self);
303        Ok(())
304    }
305}
306
307impl Drop for VehHook {
308    fn drop(&mut self) {
309        if self.auto_restore {
310            if let Some(hook) = VEH_HOOKS.unregister(self.index) {
311                match hook.hook_type {
312                    VehHookType::Hardware(dr) => {
313                        let _ = clear_hardware_breakpoint(dr);
314                    }
315                    VehHookType::Int3 => {
316                        if let Ok(_guard) = crate::util::memory::ProtectionGuard::new(
317                            hook.target, 1, 0x40,
318                        ) {
319                            unsafe {
320                                *(hook.target as *mut u8) = hook.original_byte;
321                            }
322                        }
323                    }
324                }
325            }
326        }
327    }
328}
329
330// SAFETY: VehHook operates on process-wide exception handling
331unsafe impl Send for VehHook {}
332unsafe impl Sync for VehHook {}
333
334/// ensure VEH handler is installed
335fn ensure_veh_handler() -> Result<()> {
336    if VEH_INSTALLED.load(Ordering::Acquire) {
337        return Ok(());
338    }
339
340    let handler = unsafe {
341        AddVectoredExceptionHandler(1, Some(veh_handler))
342    };
343
344    if handler.is_null() {
345        return Err(WraithError::from_last_error("AddVectoredExceptionHandler"));
346    }
347
348    VEH_HANDLER.store(handler as usize, Ordering::Release);
349    VEH_INSTALLED.store(true, Ordering::Release);
350
351    Ok(())
352}
353
354/// the vectored exception handler
355extern "system" fn veh_handler(exception_info: *mut ExceptionPointers) -> i32 {
356    if exception_info.is_null() {
357        return EXCEPTION_CONTINUE_SEARCH;
358    }
359
360    // SAFETY: exception_info is valid during exception handling
361    let info = unsafe { &*exception_info };
362    let record = unsafe { &*info.exception_record };
363    let context = unsafe { &mut *info.context_record };
364
365    let exception_code = record.exception_code;
366
367    // handle breakpoint exceptions
368    if exception_code == EXCEPTION_BREAKPOINT || exception_code == EXCEPTION_SINGLE_STEP {
369        #[cfg(target_arch = "x86_64")]
370        let exception_address = context.rip as usize;
371        #[cfg(target_arch = "x86")]
372        let exception_address = context.eip as usize;
373
374        // check if this is one of our hooks
375        if let Some((target, detour)) = VEH_HOOKS.find_hook(exception_address) {
376            // for INT3, the exception address is after the INT3
377            let adjusted_addr = if exception_code == EXCEPTION_BREAKPOINT {
378                exception_address.saturating_sub(1)
379            } else {
380                exception_address
381            };
382
383            if adjusted_addr == target || exception_address == target {
384                // redirect to detour
385                #[cfg(target_arch = "x86_64")]
386                {
387                    context.rip = detour as u64;
388                }
389                #[cfg(target_arch = "x86")]
390                {
391                    context.eip = detour as u32;
392                }
393
394                // for hardware breakpoints, we need to set RF flag to prevent re-triggering
395                if exception_code == EXCEPTION_SINGLE_STEP {
396                    #[cfg(target_arch = "x86_64")]
397                    {
398                        context.eflags |= 0x10000; // RF flag
399                    }
400                    #[cfg(target_arch = "x86")]
401                    {
402                        context.eflags |= 0x10000; // RF flag
403                    }
404                }
405
406                return EXCEPTION_CONTINUE_EXECUTION;
407            }
408        }
409    }
410
411    EXCEPTION_CONTINUE_SEARCH
412}
413
414/// set a hardware breakpoint
415fn set_hardware_breakpoint(
416    dr: DebugRegister,
417    address: usize,
418    condition: BreakCondition,
419    length: BreakLength,
420) -> Result<()> {
421    let mut context = unsafe { core::mem::zeroed::<Context>() };
422
423    #[cfg(target_arch = "x86_64")]
424    {
425        context.context_flags = CONTEXT_DEBUG_REGISTERS;
426    }
427    #[cfg(target_arch = "x86")]
428    {
429        context.context_flags = CONTEXT_DEBUG_REGISTERS;
430    }
431
432    let thread = unsafe { GetCurrentThread() };
433
434    if unsafe { GetThreadContext(thread, &mut context) } == 0 {
435        return Err(WraithError::from_last_error("GetThreadContext"));
436    }
437
438    // set the debug register address
439    match dr {
440        DebugRegister::Dr0 => context.dr0 = address as u64,
441        DebugRegister::Dr1 => context.dr1 = address as u64,
442        DebugRegister::Dr2 => context.dr2 = address as u64,
443        DebugRegister::Dr3 => context.dr3 = address as u64,
444    }
445
446    // configure DR7
447    let dr_index = dr as u8;
448    let enable_bit = 1u64 << (dr_index * 2); // local enable
449    let condition_bits = (condition as u64) << (16 + dr_index * 4);
450    let length_bits = (length as u64) << (18 + dr_index * 4);
451
452    // clear old settings for this DR
453    let clear_mask = !(0b11u64 << (dr_index * 2) | 0b1111u64 << (16 + dr_index * 4));
454    context.dr7 &= clear_mask;
455
456    // set new settings
457    context.dr7 |= enable_bit | condition_bits | length_bits;
458
459    if unsafe { SetThreadContext(thread, &context) } == 0 {
460        return Err(WraithError::from_last_error("SetThreadContext"));
461    }
462
463    Ok(())
464}
465
466/// clear a hardware breakpoint
467fn clear_hardware_breakpoint(dr: DebugRegister) -> Result<()> {
468    let mut context = unsafe { core::mem::zeroed::<Context>() };
469
470    #[cfg(target_arch = "x86_64")]
471    {
472        context.context_flags = CONTEXT_DEBUG_REGISTERS;
473    }
474    #[cfg(target_arch = "x86")]
475    {
476        context.context_flags = CONTEXT_DEBUG_REGISTERS;
477    }
478
479    let thread = unsafe { GetCurrentThread() };
480
481    if unsafe { GetThreadContext(thread, &mut context) } == 0 {
482        return Err(WraithError::from_last_error("GetThreadContext"));
483    }
484
485    // clear the debug register
486    match dr {
487        DebugRegister::Dr0 => context.dr0 = 0,
488        DebugRegister::Dr1 => context.dr1 = 0,
489        DebugRegister::Dr2 => context.dr2 = 0,
490        DebugRegister::Dr3 => context.dr3 = 0,
491    }
492
493    // disable in DR7
494    let dr_index = dr as u8;
495    let disable_mask = !(0b11u64 << (dr_index * 2) | 0b1111u64 << (16 + dr_index * 4));
496    context.dr7 &= disable_mask;
497
498    if unsafe { SetThreadContext(thread, &context) } == 0 {
499        return Err(WraithError::from_last_error("SetThreadContext"));
500    }
501
502    Ok(())
503}
504
505/// get available debug register
506pub fn get_available_debug_register() -> Result<DebugRegister> {
507    let mut context = unsafe { core::mem::zeroed::<Context>() };
508
509    #[cfg(target_arch = "x86_64")]
510    {
511        context.context_flags = CONTEXT_DEBUG_REGISTERS;
512    }
513    #[cfg(target_arch = "x86")]
514    {
515        context.context_flags = CONTEXT_DEBUG_REGISTERS;
516    }
517
518    let thread = unsafe { GetCurrentThread() };
519
520    if unsafe { GetThreadContext(thread, &mut context) } == 0 {
521        return Err(WraithError::from_last_error("GetThreadContext"));
522    }
523
524    // check which debug registers are free
525    for i in 0..4u8 {
526        let is_enabled = (context.dr7 & (1u64 << (i * 2))) != 0;
527        if !is_enabled {
528            return Ok(match i {
529                0 => DebugRegister::Dr0,
530                1 => DebugRegister::Dr1,
531                2 => DebugRegister::Dr2,
532                _ => DebugRegister::Dr3,
533            });
534        }
535    }
536
537    Err(WraithError::GadgetNotFound {
538        gadget_type: "available debug register",
539    })
540}
541
/// `ContextFlags` value selecting only the debug registers:
/// the x64 architecture bit (0x0010_0000) plus the debug-register bit (0x10)
#[cfg(target_arch = "x86_64")]
const CONTEXT_DEBUG_REGISTERS: u32 = 0x0010_0000 | 0x0000_0010;

/// `ContextFlags` value selecting only the debug registers:
/// the x86 architecture bit (0x0001_0000) plus the debug-register bit (0x10)
#[cfg(target_arch = "x86")]
const CONTEXT_DEBUG_REGISTERS: u32 = 0x0001_0000 | 0x0000_0010;
547
/// capacity of `ExceptionRecord::exception_information`
const EXCEPTION_MAXIMUM_PARAMETERS: usize = 15;

/// mirror of the Win32 `EXCEPTION_RECORD` structure (field order is the ABI)
#[repr(C)]
struct ExceptionRecord {
    /// exception code, compared against `EXCEPTION_BREAKPOINT` / `EXCEPTION_SINGLE_STEP`
    exception_code: u32,
    exception_flags: u32,
    /// linked record (may be null)
    exception_record: *mut ExceptionRecord,
    exception_address: *mut core::ffi::c_void,
    /// count of meaningful entries in `exception_information`
    number_parameters: u32,
    exception_information: [usize; EXCEPTION_MAXIMUM_PARAMETERS],
}
558
559/// EXCEPTION_POINTERS structure
560#[repr(C)]
561struct ExceptionPointers {
562    exception_record: *mut ExceptionRecord,
563    context_record: *mut Context,
564}
565
/// mirror of the Win32 x64 `CONTEXT` structure
///
/// It must be the full 1232 bytes and 16-byte aligned. `GetThreadContext` fills a
/// caller-supplied buffer of this type, and on x64 it rejects misaligned buffers.
/// The previous declaration (768 bytes, 8-byte alignment) was too small and could
/// be misaligned. Only the named register fields are used by this module.
#[cfg(target_arch = "x86_64")]
#[repr(C, align(16))]
struct Context {
    p1_home: u64,
    p2_home: u64,
    p3_home: u64,
    p4_home: u64,
    p5_home: u64,
    p6_home: u64,
    context_flags: u32,
    mx_csr: u32,
    seg_cs: u16,
    seg_ds: u16,
    seg_es: u16,
    seg_fs: u16,
    seg_gs: u16,
    seg_ss: u16,
    eflags: u32,
    dr0: u64,
    dr1: u64,
    dr2: u64,
    dr3: u64,
    dr6: u64,
    dr7: u64,
    rax: u64,
    rcx: u64,
    rdx: u64,
    rbx: u64,
    rsp: u64,
    rbp: u64,
    rsi: u64,
    rdi: u64,
    r8: u64,
    r9: u64,
    r10: u64,
    r11: u64,
    r12: u64,
    r13: u64,
    r14: u64,
    r15: u64,
    rip: u64,
    /// legacy FP / XMM save area (XMM_SAVE_AREA32), not interpreted here
    flt_save: [u8; 512],
    /// 26 x 128-bit vector registers, not interpreted here
    vector_register: [u8; 416],
    vector_control: u64,
    debug_control: u64,
    last_branch_to_rip: u64,
    last_branch_from_rip: u64,
    last_exception_to_rip: u64,
    last_exception_from_rip: u64,
}

/// mirror of the Win32 x86 `CONTEXT` structure (debug registers are 32-bit here)
#[cfg(target_arch = "x86")]
#[repr(C)]
struct Context {
    context_flags: u32,
    dr0: u32,
    dr1: u32,
    dr2: u32,
    dr3: u32,
    dr6: u32,
    dr7: u32,
    float_save: [u8; 112],
    seg_gs: u32,
    seg_fs: u32,
    seg_es: u32,
    seg_ds: u32,
    edi: u32,
    esi: u32,
    ebx: u32,
    edx: u32,
    ecx: u32,
    eax: u32,
    ebp: u32,
    eip: u32,
    seg_cs: u32,
    eflags: u32,
    esp: u32,
    seg_ss: u32,
    extended_registers: [u8; 512],
}
650
651type VectoredHandler = Option<extern "system" fn(*mut ExceptionPointers) -> i32>;
652
653#[link(name = "kernel32")]
654extern "system" {
655    fn AddVectoredExceptionHandler(first: u32, handler: VectoredHandler) -> *mut core::ffi::c_void;
656    fn RemoveVectoredExceptionHandler(handle: *mut core::ffi::c_void) -> u32;
657    fn GetCurrentThread() -> *mut core::ffi::c_void;
658    fn GetThreadContext(thread: *mut core::ffi::c_void, context: *mut Context) -> i32;
659    fn SetThreadContext(thread: *mut core::ffi::c_void, context: *const Context) -> i32;
660}
661
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_ensure_veh_handler() {
        ensure_veh_handler().expect("should install VEH handler");
        assert!(VEH_INSTALLED.load(Ordering::Relaxed));
    }

    #[test]
    fn test_ensure_veh_handler_is_idempotent() {
        ensure_veh_handler().expect("first call should install the handler");
        ensure_veh_handler().expect("second call should be a no-op");
        assert!(VEH_INSTALLED.load(Ordering::Acquire));
        assert_ne!(
            VEH_HANDLER.load(Ordering::Acquire),
            0,
            "installed handler handle should be recorded"
        );
    }

    #[test]
    fn test_get_available_dr() {
        // may legitimately fail when every DR is in use (e.g. under a debugger)
        if let Ok(dr) = get_available_debug_register() {
            assert!((dr as u8) < 4, "debug register index out of range: {dr:?}");
        }
    }
}