1#[cfg(all(not(feature = "std"), feature = "alloc"))]
20use alloc::string::String;
21
22#[cfg(feature = "std")]
23use std::string::String;
24
25use crate::error::{Result, WraithError};
26use core::cell::UnsafeCell;
27use core::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
28
// Return values understood by the vectored exception dispatcher.

/// Resume the faulting thread using the (possibly modified) context record.
const EXCEPTION_CONTINUE_EXECUTION: i32 = -1;
/// Decline the exception and let the next handler in the chain see it.
const EXCEPTION_CONTINUE_SEARCH: i32 = 0;

// Exception codes this module reacts to.

/// `STATUS_BREAKPOINT`: raised by an `int3` (0xCC) instruction.
const EXCEPTION_BREAKPOINT: u32 = 0x8000_0003;
/// `STATUS_SINGLE_STEP`: raised by the trap flag or a debug-register breakpoint.
const EXCEPTION_SINGLE_STEP: u32 = 0x8000_0004;
36
/// One of the four hardware breakpoint address registers, DR0 through DR3.
///
/// The discriminant is the register's index; it selects which enable bits and
/// which condition/length field of DR7 belong to this register.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum DebugRegister {
    /// Debug address register 0.
    Dr0 = 0,
    /// Debug address register 1.
    Dr1,
    /// Debug address register 2.
    Dr2,
    /// Debug address register 3.
    Dr3,
}
46
/// Kind of access that triggers a hardware breakpoint.
///
/// The discriminant is the two-bit R/W value written into DR7 for the chosen
/// debug register.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum BreakCondition {
    /// Instruction fetch/execution at the address.
    Execute = 0,
    /// Data writes only.
    Write = 1,
    /// I/O port reads or writes (per the Intel SDM, only defined when CR4.DE is set).
    Io = 2,
    /// Data reads or writes (not instruction fetches).
    ReadWrite = 3,
}
60
/// Width of the region watched by a hardware breakpoint.
///
/// The discriminant is the two-bit LEN value written into DR7. Note the
/// encoding is not ordered by size: `2` means eight bytes and `3` means four.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum BreakLength {
    /// 1 byte (the only valid length for execute breakpoints).
    Byte = 0,
    /// 2 bytes.
    Word = 1,
    /// 8 bytes (per the Intel SDM; undefined on older processors).
    Qword = 2,
    /// 4 bytes.
    Dword = 3,
}
70
71#[derive(Debug, Clone, Copy, PartialEq, Eq)]
73pub enum VehHookType {
74 Hardware(DebugRegister),
76 Int3,
78}
79
80struct VehHookTarget {
82 target: usize,
84 detour: usize,
86 original_byte: u8,
88 hook_type: VehHookType,
90 active: bool,
92}
93
94const MAX_VEH_HOOKS: usize = 64;
96
97static VEH_HOOKS: VehHookRegistry = VehHookRegistry::new();
99
100static VEH_HANDLER: AtomicUsize = AtomicUsize::new(0);
102
103static VEH_INSTALLED: AtomicBool = AtomicBool::new(false);
105
106struct VehHookRegistry {
108 hooks: UnsafeCell<[Option<VehHookTarget>; MAX_VEH_HOOKS]>,
109}
110
111impl VehHookRegistry {
112 const fn new() -> Self {
113 const INIT: Option<VehHookTarget> = None;
115 Self {
116 hooks: UnsafeCell::new([INIT; MAX_VEH_HOOKS]),
117 }
118 }
119
120 fn find_hook(&self, address: usize) -> Option<(usize, usize)> {
121 let hooks = unsafe { &*self.hooks.get() };
123 for hook in hooks.iter().flatten() {
124 if hook.active && hook.target == address {
125 return Some((hook.target, hook.detour));
126 }
127 }
128 None
129 }
130
131 fn register(&self, target: VehHookTarget) -> Result<usize> {
132 let hooks = unsafe { &mut *self.hooks.get() };
134 for (i, slot) in hooks.iter_mut().enumerate() {
135 if slot.is_none() {
136 *slot = Some(target);
137 return Ok(i);
138 }
139 }
140 Err(WraithError::HookInstallFailed {
141 target: 0,
142 reason: "VEH hook registry full".into(),
143 })
144 }
145
146 fn unregister(&self, index: usize) -> Option<VehHookTarget> {
147 let hooks = unsafe { &mut *self.hooks.get() };
149 if index < MAX_VEH_HOOKS {
150 hooks[index].take()
151 } else {
152 None
153 }
154 }
155
156 fn get(&self, index: usize) -> Option<&VehHookTarget> {
157 let hooks = unsafe { &*self.hooks.get() };
159 hooks.get(index).and_then(|h| h.as_ref())
160 }
161}
162
163unsafe impl Sync for VehHookRegistry {}
165
166pub struct VehHook {
168 index: usize,
170 hook_type: VehHookType,
172 auto_restore: bool,
174}
175
176impl VehHook {
177 pub fn hardware(target: usize, detour: usize, dr: DebugRegister) -> Result<Self> {
189 ensure_veh_handler()?;
190
191 set_hardware_breakpoint(dr, target, BreakCondition::Execute, BreakLength::Byte)?;
193
194 let hook = VehHookTarget {
195 target,
196 detour,
197 original_byte: 0,
198 hook_type: VehHookType::Hardware(dr),
199 active: true,
200 };
201
202 let index = VEH_HOOKS.register(hook)?;
203
204 Ok(Self {
205 index,
206 hook_type: VehHookType::Hardware(dr),
207 auto_restore: true,
208 })
209 }
210
211 pub fn int3(target: usize, detour: usize) -> Result<Self> {
222 ensure_veh_handler()?;
223
224 let original_byte = unsafe { *(target as *const u8) };
227
228 let _guard = crate::util::memory::ProtectionGuard::new(target, 1, 0x40)?;
230 unsafe {
231 *(target as *mut u8) = 0xCC;
232 }
233
234 let hook = VehHookTarget {
235 target,
236 detour,
237 original_byte,
238 hook_type: VehHookType::Int3,
239 active: true,
240 };
241
242 let index = VEH_HOOKS.register(hook)?;
243
244 Ok(Self {
245 index,
246 hook_type: VehHookType::Int3,
247 auto_restore: true,
248 })
249 }
250
251 pub fn is_active(&self) -> bool {
253 VEH_HOOKS.get(self.index).map_or(false, |h| h.active)
254 }
255
256 pub fn target(&self) -> Option<usize> {
258 VEH_HOOKS.get(self.index).map(|h| h.target)
259 }
260
261 pub fn detour(&self) -> Option<usize> {
263 VEH_HOOKS.get(self.index).map(|h| h.detour)
264 }
265
266 pub fn hook_type(&self) -> VehHookType {
268 self.hook_type
269 }
270
271 pub fn set_auto_restore(&mut self, restore: bool) {
273 self.auto_restore = restore;
274 }
275
276 pub fn leak(mut self) {
278 self.auto_restore = false;
279 core::mem::forget(self);
280 }
281
282 pub fn restore(self) -> Result<()> {
284 if let Some(hook) = VEH_HOOKS.unregister(self.index) {
285 match hook.hook_type {
286 VehHookType::Hardware(dr) => {
287 clear_hardware_breakpoint(dr)?;
288 }
289 VehHookType::Int3 => {
290 let _guard = crate::util::memory::ProtectionGuard::new(
292 hook.target, 1, 0x40,
293 )?;
294 unsafe {
295 *(hook.target as *mut u8) = hook.original_byte;
296 }
297 }
298 }
299 }
300
301 core::mem::forget(self);
303 Ok(())
304 }
305}
306
307impl Drop for VehHook {
308 fn drop(&mut self) {
309 if self.auto_restore {
310 if let Some(hook) = VEH_HOOKS.unregister(self.index) {
311 match hook.hook_type {
312 VehHookType::Hardware(dr) => {
313 let _ = clear_hardware_breakpoint(dr);
314 }
315 VehHookType::Int3 => {
316 if let Ok(_guard) = crate::util::memory::ProtectionGuard::new(
317 hook.target, 1, 0x40,
318 ) {
319 unsafe {
320 *(hook.target as *mut u8) = hook.original_byte;
321 }
322 }
323 }
324 }
325 }
326 }
327 }
328}
329
330unsafe impl Send for VehHook {}
332unsafe impl Sync for VehHook {}
333
334fn ensure_veh_handler() -> Result<()> {
336 if VEH_INSTALLED.load(Ordering::Acquire) {
337 return Ok(());
338 }
339
340 let handler = unsafe {
341 AddVectoredExceptionHandler(1, Some(veh_handler))
342 };
343
344 if handler.is_null() {
345 return Err(WraithError::from_last_error("AddVectoredExceptionHandler"));
346 }
347
348 VEH_HANDLER.store(handler as usize, Ordering::Release);
349 VEH_INSTALLED.store(true, Ordering::Release);
350
351 Ok(())
352}
353
354extern "system" fn veh_handler(exception_info: *mut ExceptionPointers) -> i32 {
356 if exception_info.is_null() {
357 return EXCEPTION_CONTINUE_SEARCH;
358 }
359
360 let info = unsafe { &*exception_info };
362 let record = unsafe { &*info.exception_record };
363 let context = unsafe { &mut *info.context_record };
364
365 let exception_code = record.exception_code;
366
367 if exception_code == EXCEPTION_BREAKPOINT || exception_code == EXCEPTION_SINGLE_STEP {
369 #[cfg(target_arch = "x86_64")]
370 let exception_address = context.rip as usize;
371 #[cfg(target_arch = "x86")]
372 let exception_address = context.eip as usize;
373
374 if let Some((target, detour)) = VEH_HOOKS.find_hook(exception_address) {
376 let adjusted_addr = if exception_code == EXCEPTION_BREAKPOINT {
378 exception_address.saturating_sub(1)
379 } else {
380 exception_address
381 };
382
383 if adjusted_addr == target || exception_address == target {
384 #[cfg(target_arch = "x86_64")]
386 {
387 context.rip = detour as u64;
388 }
389 #[cfg(target_arch = "x86")]
390 {
391 context.eip = detour as u32;
392 }
393
394 if exception_code == EXCEPTION_SINGLE_STEP {
396 #[cfg(target_arch = "x86_64")]
397 {
398 context.eflags |= 0x10000; }
400 #[cfg(target_arch = "x86")]
401 {
402 context.eflags |= 0x10000; }
404 }
405
406 return EXCEPTION_CONTINUE_EXECUTION;
407 }
408 }
409 }
410
411 EXCEPTION_CONTINUE_SEARCH
412}
413
414fn set_hardware_breakpoint(
416 dr: DebugRegister,
417 address: usize,
418 condition: BreakCondition,
419 length: BreakLength,
420) -> Result<()> {
421 let mut context = unsafe { core::mem::zeroed::<Context>() };
422
423 #[cfg(target_arch = "x86_64")]
424 {
425 context.context_flags = CONTEXT_DEBUG_REGISTERS;
426 }
427 #[cfg(target_arch = "x86")]
428 {
429 context.context_flags = CONTEXT_DEBUG_REGISTERS;
430 }
431
432 let thread = unsafe { GetCurrentThread() };
433
434 if unsafe { GetThreadContext(thread, &mut context) } == 0 {
435 return Err(WraithError::from_last_error("GetThreadContext"));
436 }
437
438 match dr {
440 DebugRegister::Dr0 => context.dr0 = address as u64,
441 DebugRegister::Dr1 => context.dr1 = address as u64,
442 DebugRegister::Dr2 => context.dr2 = address as u64,
443 DebugRegister::Dr3 => context.dr3 = address as u64,
444 }
445
446 let dr_index = dr as u8;
448 let enable_bit = 1u64 << (dr_index * 2); let condition_bits = (condition as u64) << (16 + dr_index * 4);
450 let length_bits = (length as u64) << (18 + dr_index * 4);
451
452 let clear_mask = !(0b11u64 << (dr_index * 2) | 0b1111u64 << (16 + dr_index * 4));
454 context.dr7 &= clear_mask;
455
456 context.dr7 |= enable_bit | condition_bits | length_bits;
458
459 if unsafe { SetThreadContext(thread, &context) } == 0 {
460 return Err(WraithError::from_last_error("SetThreadContext"));
461 }
462
463 Ok(())
464}
465
466fn clear_hardware_breakpoint(dr: DebugRegister) -> Result<()> {
468 let mut context = unsafe { core::mem::zeroed::<Context>() };
469
470 #[cfg(target_arch = "x86_64")]
471 {
472 context.context_flags = CONTEXT_DEBUG_REGISTERS;
473 }
474 #[cfg(target_arch = "x86")]
475 {
476 context.context_flags = CONTEXT_DEBUG_REGISTERS;
477 }
478
479 let thread = unsafe { GetCurrentThread() };
480
481 if unsafe { GetThreadContext(thread, &mut context) } == 0 {
482 return Err(WraithError::from_last_error("GetThreadContext"));
483 }
484
485 match dr {
487 DebugRegister::Dr0 => context.dr0 = 0,
488 DebugRegister::Dr1 => context.dr1 = 0,
489 DebugRegister::Dr2 => context.dr2 = 0,
490 DebugRegister::Dr3 => context.dr3 = 0,
491 }
492
493 let dr_index = dr as u8;
495 let disable_mask = !(0b11u64 << (dr_index * 2) | 0b1111u64 << (16 + dr_index * 4));
496 context.dr7 &= disable_mask;
497
498 if unsafe { SetThreadContext(thread, &context) } == 0 {
499 return Err(WraithError::from_last_error("SetThreadContext"));
500 }
501
502 Ok(())
503}
504
505pub fn get_available_debug_register() -> Result<DebugRegister> {
507 let mut context = unsafe { core::mem::zeroed::<Context>() };
508
509 #[cfg(target_arch = "x86_64")]
510 {
511 context.context_flags = CONTEXT_DEBUG_REGISTERS;
512 }
513 #[cfg(target_arch = "x86")]
514 {
515 context.context_flags = CONTEXT_DEBUG_REGISTERS;
516 }
517
518 let thread = unsafe { GetCurrentThread() };
519
520 if unsafe { GetThreadContext(thread, &mut context) } == 0 {
521 return Err(WraithError::from_last_error("GetThreadContext"));
522 }
523
524 for i in 0..4u8 {
526 let is_enabled = (context.dr7 & (1u64 << (i * 2))) != 0;
527 if !is_enabled {
528 return Ok(match i {
529 0 => DebugRegister::Dr0,
530 1 => DebugRegister::Dr1,
531 2 => DebugRegister::Dr2,
532 _ => DebugRegister::Dr3,
533 });
534 }
535 }
536
537 Err(WraithError::GadgetNotFound {
538 gadget_type: "available debug register",
539 })
540}
541
/// `CONTEXT_DEBUG_REGISTERS` on AMD64: the `CONTEXT_AMD64` architecture bit plus 0x10.
#[cfg(target_arch = "x86_64")]
const CONTEXT_DEBUG_REGISTERS: u32 = 0x0010_0000 | 0x0000_0010;

/// `CONTEXT_DEBUG_REGISTERS` on i386: the `CONTEXT_i386` architecture bit plus 0x10.
#[cfg(target_arch = "x86")]
const CONTEXT_DEBUG_REGISTERS: u32 = 0x0001_0000 | 0x0000_0010;
547
/// `EXCEPTION_MAXIMUM_PARAMETERS` from `winnt.h`.
const EXCEPTION_MAXIMUM_PARAMETERS: usize = 15;

/// Win32 `EXCEPTION_RECORD` describing a raised exception.
#[repr(C)]
struct ExceptionRecord {
    /// Status code, e.g. `EXCEPTION_BREAKPOINT`.
    exception_code: u32,
    /// Exception flags (continuable / noncontinuable, unwinding state).
    exception_flags: u32,
    /// Associated record for nested exceptions; may be null.
    exception_record: *mut ExceptionRecord,
    /// Address at which the exception occurred.
    exception_address: *mut core::ffi::c_void,
    /// Count of meaningful entries in `exception_information`.
    number_parameters: u32,
    /// Exception-specific arguments.
    exception_information: [usize; EXCEPTION_MAXIMUM_PARAMETERS],
}
558
559#[repr(C)]
561struct ExceptionPointers {
562 exception_record: *mut ExceptionRecord,
563 context_record: *mut Context,
564}
565
/// Win64 `CONTEXT` record (`winnt.h`, AMD64).
///
/// Windows requires this structure to be 16-byte aligned for
/// `GetThreadContext` / `SetThreadContext`, and it is 0x4D0 bytes long. Both are
/// enforced here, so the OS never rejects a misaligned stack value or writes past
/// the end of an undersized one.
#[repr(C, align(16))]
#[cfg(target_arch = "x86_64")]
struct Context {
    p1_home: u64,
    p2_home: u64,
    p3_home: u64,
    p4_home: u64,
    p5_home: u64,
    p6_home: u64,
    context_flags: u32,
    mx_csr: u32,
    seg_cs: u16,
    seg_ds: u16,
    seg_es: u16,
    seg_fs: u16,
    seg_gs: u16,
    seg_ss: u16,
    eflags: u32,
    dr0: u64,
    dr1: u64,
    dr2: u64,
    dr3: u64,
    dr6: u64,
    dr7: u64,
    rax: u64,
    rcx: u64,
    rdx: u64,
    rbx: u64,
    rsp: u64,
    rbp: u64,
    rsi: u64,
    rdi: u64,
    r8: u64,
    r9: u64,
    r10: u64,
    r11: u64,
    r12: u64,
    r13: u64,
    r14: u64,
    r15: u64,
    rip: u64,
    /// Legacy FXSAVE area (`XMM_SAVE_AREA32`).
    flt_save: [u8; 512],
    /// `VectorRegister[26]`, one 16-byte `M128A` each.
    vector_register: [u8; 416],
    vector_control: u64,
    debug_control: u64,
    last_branch_to_rip: u64,
    last_branch_from_rip: u64,
    last_exception_to_rip: u64,
    last_exception_from_rip: u64,
}

#[cfg(target_arch = "x86_64")]
const _: () = assert!(core::mem::size_of::<Context>() == 0x4D0);

/// Win32 `CONTEXT` record (`winnt.h`, i386), 0x2CC bytes.
#[repr(C)]
#[cfg(target_arch = "x86")]
struct Context {
    context_flags: u32,
    dr0: u32,
    dr1: u32,
    dr2: u32,
    dr3: u32,
    dr6: u32,
    dr7: u32,
    /// `FLOATING_SAVE_AREA`.
    float_save: [u8; 112],
    seg_gs: u32,
    seg_fs: u32,
    seg_es: u32,
    seg_ds: u32,
    edi: u32,
    esi: u32,
    ebx: u32,
    edx: u32,
    ecx: u32,
    eax: u32,
    ebp: u32,
    eip: u32,
    seg_cs: u32,
    eflags: u32,
    esp: u32,
    seg_ss: u32,
    extended_registers: [u8; 512],
}

#[cfg(target_arch = "x86")]
const _: () = assert!(core::mem::size_of::<Context>() == 0x2CC);
650
651type VectoredHandler = Option<extern "system" fn(*mut ExceptionPointers) -> i32>;
652
653#[link(name = "kernel32")]
654extern "system" {
655 fn AddVectoredExceptionHandler(first: u32, handler: VectoredHandler) -> *mut core::ffi::c_void;
656 fn RemoveVectoredExceptionHandler(handle: *mut core::ffi::c_void) -> u32;
657 fn GetCurrentThread() -> *mut core::ffi::c_void;
658 fn GetThreadContext(thread: *mut core::ffi::c_void, context: *mut Context) -> i32;
659 fn SetThreadContext(thread: *mut core::ffi::c_void, context: *const Context) -> i32;
660}
661
#[cfg(test)]
mod tests {
    use super::*;

    /// Installing the handler succeeds, records a handle, and is idempotent.
    #[test]
    fn test_ensure_veh_handler() {
        ensure_veh_handler().expect("should install VEH handler");
        assert!(VEH_INSTALLED.load(Ordering::Acquire));
        assert_ne!(VEH_HANDLER.load(Ordering::Acquire), 0);

        // A repeat call must be a no-op that still succeeds.
        ensure_veh_handler().expect("repeat install should be a no-op");
    }

    /// When a free debug register is reported, it is one of DR0–DR3.
    #[test]
    fn test_get_available_dr() {
        // The lookup may legitimately fail (e.g. every register already enabled),
        // so only the successful case is checked.
        if let Ok(dr) = get_available_debug_register() {
            assert!(matches!(
                dr,
                DebugRegister::Dr0 | DebugRegister::Dr1 | DebugRegister::Dr2 | DebugRegister::Dr3
            ));
        }
    }
}