wraith/manipulation/inline_hook/
guard.rs

//! RAII hook guard for automatic cleanup.
//!
//! Provides automatic restoration of hooked functions when the guard is dropped,
//! similar to `UnlinkGuard` in the `unlink` module.
5
6use crate::error::{Result, WraithError};
7use crate::util::memory::ProtectionGuard;
8use super::arch::Architecture;
9use super::trampoline::ExecutableMemory;
10use core::marker::PhantomData;
11
/// page-protection value requesting read + write + execute access; handed to
/// `ProtectionGuard::new` while target code is being patched
const PAGE_EXECUTE_READWRITE: u32 = 0x0000_0040;
13
14/// RAII guard for an installed inline hook
15///
16/// when dropped, automatically restores the original function bytes
17/// unless `leak()` was called.
18pub struct HookGuard<A: Architecture> {
19    /// address of the hooked function
20    target: usize,
21    /// address of the detour function
22    detour: usize,
23    /// original bytes that were overwritten
24    original_bytes: Vec<u8>,
25    /// trampoline memory (if allocated)
26    trampoline: Option<ExecutableMemory>,
27    /// whether to restore on drop
28    auto_restore: bool,
29    /// architecture marker
30    _arch: PhantomData<A>,
31}
32
33impl<A: Architecture> HookGuard<A> {
34    /// create a new hook guard
35    pub(crate) fn new(
36        target: usize,
37        detour: usize,
38        original_bytes: Vec<u8>,
39        trampoline: Option<ExecutableMemory>,
40    ) -> Self {
41        Self {
42            target,
43            detour,
44            original_bytes,
45            trampoline,
46            auto_restore: true,
47            _arch: PhantomData,
48        }
49    }
50
51    /// get the target (hooked) function address
52    pub fn target(&self) -> usize {
53        self.target
54    }
55
56    /// get the detour function address
57    pub fn detour(&self) -> usize {
58        self.detour
59    }
60
61    /// get the trampoline address (call this to invoke the original function)
62    ///
63    /// returns None if no trampoline was allocated
64    pub fn trampoline(&self) -> Option<usize> {
65        self.trampoline.as_ref().map(|t| t.base())
66    }
67
68    /// get the original bytes that were overwritten
69    pub fn original_bytes(&self) -> &[u8] {
70        &self.original_bytes
71    }
72
73    /// check if auto-restore is enabled
74    pub fn will_auto_restore(&self) -> bool {
75        self.auto_restore
76    }
77
78    /// set whether to auto-restore on drop
79    pub fn set_auto_restore(&mut self, restore: bool) {
80        self.auto_restore = restore;
81    }
82
83    /// disable auto-restore and keep the hook active permanently
84    ///
85    /// consumes the guard without restoring the original function.
86    /// the trampoline memory is leaked and remains valid.
87    pub fn leak(mut self) {
88        self.auto_restore = false;
89        if let Some(trampoline) = self.trampoline.take() {
90            trampoline.leak();
91        }
92        core::mem::forget(self);
93    }
94
95    /// manually restore the original function
96    ///
97    /// consumes the guard and restores the hooked function to its original state.
98    pub fn restore(self) -> Result<()> {
99        self.restore_internal()?;
100        // prevent double-restore in Drop
101        core::mem::forget(self);
102        Ok(())
103    }
104
105    /// temporarily disable the hook
106    ///
107    /// restores original bytes but keeps the guard and trampoline alive
108    /// for re-enabling later.
109    pub fn disable(&mut self) -> Result<()> {
110        let _guard = ProtectionGuard::new(
111            self.target,
112            self.original_bytes.len(),
113            PAGE_EXECUTE_READWRITE,
114        )?;
115
116        // SAFETY: protection changed to RWX, original bytes length matches what we overwrote
117        unsafe {
118            core::ptr::copy_nonoverlapping(
119                self.original_bytes.as_ptr(),
120                self.target as *mut u8,
121                self.original_bytes.len(),
122            );
123        }
124
125        flush_icache(self.target, self.original_bytes.len())?;
126
127        Ok(())
128    }
129
130    /// re-enable a previously disabled hook
131    ///
132    /// writes the hook stub back to the target function.
133    pub fn enable(&mut self, hook_bytes: &[u8]) -> Result<()> {
134        if hook_bytes.len() != self.original_bytes.len() {
135            return Err(WraithError::WriteFailed {
136                address: self.target as u64,
137                size: hook_bytes.len(),
138            });
139        }
140
141        let _guard = ProtectionGuard::new(
142            self.target,
143            hook_bytes.len(),
144            PAGE_EXECUTE_READWRITE,
145        )?;
146
147        // SAFETY: protection changed, size matches
148        unsafe {
149            core::ptr::copy_nonoverlapping(
150                hook_bytes.as_ptr(),
151                self.target as *mut u8,
152                hook_bytes.len(),
153            );
154        }
155
156        flush_icache(self.target, hook_bytes.len())?;
157
158        Ok(())
159    }
160
161    /// internal restore implementation
162    fn restore_internal(&self) -> Result<()> {
163        let _guard = ProtectionGuard::new(
164            self.target,
165            self.original_bytes.len(),
166            PAGE_EXECUTE_READWRITE,
167        )?;
168
169        // SAFETY: protection changed, original_bytes verified at hook install time
170        unsafe {
171            core::ptr::copy_nonoverlapping(
172                self.original_bytes.as_ptr(),
173                self.target as *mut u8,
174                self.original_bytes.len(),
175            );
176        }
177
178        flush_icache(self.target, self.original_bytes.len())?;
179
180        Ok(())
181    }
182}
183
184impl<A: Architecture> Drop for HookGuard<A> {
185    fn drop(&mut self) {
186        if self.auto_restore {
187            // ignore errors during drop
188            let _ = self.restore_internal();
189        }
190    }
191}
192
// SAFETY: HookGuard holds plain addresses (`target`, `detour`), a Vec<u8>, an
// optional trampoline and a PhantomData marker; the patched code itself is
// process-wide, so it is reachable from every thread regardless of which thread
// owns the guard.
// Sync: every public `&self` method is a read-only getter. Writes to the target
// happen only through `&mut self` (`disable`/`enable`), by value (`restore`/`leak`),
// or in `Drop`.
// NOTE(review): `Send` also asserts that `ExecutableMemory` may be dropped/freed on
// a thread other than the one that allocated it — confirm against its definition.
unsafe impl<A: Architecture> Send for HookGuard<A> {}
unsafe impl<A: Architecture> Sync for HookGuard<A> {}
197
/// whether a tracked hook's stub is currently written over its target
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum HookState {
    /// the hook stub is in place (hook active)
    Enabled,
    /// the original bytes are back in place; the hook can be re-enabled later
    Disabled,
}
206
207/// stateful hook guard that tracks enable/disable state
208pub struct StatefulHookGuard<A: Architecture> {
209    guard: HookGuard<A>,
210    hook_bytes: Vec<u8>,
211    state: HookState,
212}
213
214impl<A: Architecture> StatefulHookGuard<A> {
215    /// create from guard and hook bytes
216    pub(crate) fn new(guard: HookGuard<A>, hook_bytes: Vec<u8>) -> Self {
217        Self {
218            guard,
219            hook_bytes,
220            state: HookState::Enabled,
221        }
222    }
223
224    /// get current state
225    pub fn state(&self) -> HookState {
226        self.state
227    }
228
229    /// check if enabled
230    pub fn is_enabled(&self) -> bool {
231        self.state == HookState::Enabled
232    }
233
234    /// disable the hook
235    pub fn disable(&mut self) -> Result<()> {
236        if self.state == HookState::Enabled {
237            self.guard.disable()?;
238            self.state = HookState::Disabled;
239        }
240        Ok(())
241    }
242
243    /// enable the hook
244    pub fn enable(&mut self) -> Result<()> {
245        if self.state == HookState::Disabled {
246            self.guard.enable(&self.hook_bytes)?;
247            self.state = HookState::Enabled;
248        }
249        Ok(())
250    }
251
252    /// toggle hook state
253    pub fn toggle(&mut self) -> Result<()> {
254        match self.state {
255            HookState::Enabled => self.disable(),
256            HookState::Disabled => self.enable(),
257        }
258    }
259
260    /// get target address
261    pub fn target(&self) -> usize {
262        self.guard.target()
263    }
264
265    /// get detour address
266    pub fn detour(&self) -> usize {
267        self.guard.detour()
268    }
269
270    /// get trampoline address
271    pub fn trampoline(&self) -> Option<usize> {
272        self.guard.trampoline()
273    }
274
275    /// leak the hook (keep it active permanently)
276    pub fn leak(self) {
277        self.guard.leak();
278    }
279
280    /// restore and consume
281    pub fn restore(self) -> Result<()> {
282        self.guard.restore()
283    }
284}
285
286/// flush instruction cache
287fn flush_icache(address: usize, size: usize) -> Result<()> {
288    let result = unsafe {
289        FlushInstructionCache(
290            GetCurrentProcess(),
291            address as *const _,
292            size,
293        )
294    };
295
296    if result == 0 {
297        Err(WraithError::from_last_error("FlushInstructionCache"))
298    } else {
299        Ok(())
300    }
301}
302
// kernel32 imports used by `flush_icache`
#[link(name = "kernel32")]
extern "system" {
    /// handle for the current process, passed straight to `FlushInstructionCache`
    fn GetCurrentProcess() -> *mut core::ffi::c_void;

    /// flush `size` bytes of instruction cache at `base_address` in `process`;
    /// the result is a Win32 BOOL (`flush_icache` treats 0 as failure)
    fn FlushInstructionCache(
        process: *mut core::ffi::c_void,
        base_address: *const core::ffi::c_void,
        size: usize,
    ) -> i32;
}