wraith/manipulation/inline_hook/
guard.rs

1//! RAII hook guard for automatic cleanup
2//!
3//! Provides automatic restoration of hooked functions when the guard is dropped,
4//! similar to UnlinkGuard in the unlink module.
5
6#[cfg(all(not(feature = "std"), feature = "alloc"))]
7use alloc::vec::Vec;
8
9#[cfg(feature = "std")]
10use std::vec::Vec;
11
12use crate::error::{Result, WraithError};
13use crate::util::memory::ProtectionGuard;
14use super::arch::Architecture;
15use super::trampoline::ExecutableMemory;
16use core::marker::PhantomData;
17
/// Win32 page-protection flag granting read, write and execute access
/// (`PAGE_EXECUTE_READWRITE`); requested while code bytes are patched in place.
const PAGE_EXECUTE_READWRITE: u32 = 0x0000_0040;
19
20/// RAII guard for an installed inline hook
21///
22/// when dropped, automatically restores the original function bytes
23/// unless `leak()` was called.
24pub struct HookGuard<A: Architecture> {
25    /// address of the hooked function
26    target: usize,
27    /// address of the detour function
28    detour: usize,
29    /// original bytes that were overwritten
30    original_bytes: Vec<u8>,
31    /// trampoline memory (if allocated)
32    trampoline: Option<ExecutableMemory>,
33    /// whether to restore on drop
34    auto_restore: bool,
35    /// architecture marker
36    _arch: PhantomData<A>,
37}
38
39impl<A: Architecture> HookGuard<A> {
40    /// create a new hook guard
41    pub(crate) fn new(
42        target: usize,
43        detour: usize,
44        original_bytes: Vec<u8>,
45        trampoline: Option<ExecutableMemory>,
46    ) -> Self {
47        Self {
48            target,
49            detour,
50            original_bytes,
51            trampoline,
52            auto_restore: true,
53            _arch: PhantomData,
54        }
55    }
56
57    /// get the target (hooked) function address
58    pub fn target(&self) -> usize {
59        self.target
60    }
61
62    /// get the detour function address
63    pub fn detour(&self) -> usize {
64        self.detour
65    }
66
67    /// get the trampoline address (call this to invoke the original function)
68    ///
69    /// returns None if no trampoline was allocated
70    pub fn trampoline(&self) -> Option<usize> {
71        self.trampoline.as_ref().map(|t| t.base())
72    }
73
74    /// get the original bytes that were overwritten
75    pub fn original_bytes(&self) -> &[u8] {
76        &self.original_bytes
77    }
78
79    /// check if auto-restore is enabled
80    pub fn will_auto_restore(&self) -> bool {
81        self.auto_restore
82    }
83
84    /// set whether to auto-restore on drop
85    pub fn set_auto_restore(&mut self, restore: bool) {
86        self.auto_restore = restore;
87    }
88
89    /// disable auto-restore and keep the hook active permanently
90    ///
91    /// consumes the guard without restoring the original function.
92    /// the trampoline memory is leaked and remains valid.
93    pub fn leak(mut self) {
94        self.auto_restore = false;
95        if let Some(trampoline) = self.trampoline.take() {
96            trampoline.leak();
97        }
98        core::mem::forget(self);
99    }
100
101    /// manually restore the original function
102    ///
103    /// consumes the guard and restores the hooked function to its original state.
104    pub fn restore(self) -> Result<()> {
105        self.restore_internal()?;
106        // prevent double-restore in Drop
107        core::mem::forget(self);
108        Ok(())
109    }
110
111    /// temporarily disable the hook
112    ///
113    /// restores original bytes but keeps the guard and trampoline alive
114    /// for re-enabling later.
115    pub fn disable(&mut self) -> Result<()> {
116        let _guard = ProtectionGuard::new(
117            self.target,
118            self.original_bytes.len(),
119            PAGE_EXECUTE_READWRITE,
120        )?;
121
122        // SAFETY: protection changed to RWX, original bytes length matches what we overwrote
123        unsafe {
124            core::ptr::copy_nonoverlapping(
125                self.original_bytes.as_ptr(),
126                self.target as *mut u8,
127                self.original_bytes.len(),
128            );
129        }
130
131        flush_icache(self.target, self.original_bytes.len())?;
132
133        Ok(())
134    }
135
136    /// re-enable a previously disabled hook
137    ///
138    /// writes the hook stub back to the target function.
139    pub fn enable(&mut self, hook_bytes: &[u8]) -> Result<()> {
140        if hook_bytes.len() != self.original_bytes.len() {
141            return Err(WraithError::WriteFailed {
142                address: self.target as u64,
143                size: hook_bytes.len(),
144            });
145        }
146
147        let _guard = ProtectionGuard::new(
148            self.target,
149            hook_bytes.len(),
150            PAGE_EXECUTE_READWRITE,
151        )?;
152
153        // SAFETY: protection changed, size matches
154        unsafe {
155            core::ptr::copy_nonoverlapping(
156                hook_bytes.as_ptr(),
157                self.target as *mut u8,
158                hook_bytes.len(),
159            );
160        }
161
162        flush_icache(self.target, hook_bytes.len())?;
163
164        Ok(())
165    }
166
167    /// internal restore implementation
168    fn restore_internal(&self) -> Result<()> {
169        let _guard = ProtectionGuard::new(
170            self.target,
171            self.original_bytes.len(),
172            PAGE_EXECUTE_READWRITE,
173        )?;
174
175        // SAFETY: protection changed, original_bytes verified at hook install time
176        unsafe {
177            core::ptr::copy_nonoverlapping(
178                self.original_bytes.as_ptr(),
179                self.target as *mut u8,
180                self.original_bytes.len(),
181            );
182        }
183
184        flush_icache(self.target, self.original_bytes.len())?;
185
186        Ok(())
187    }
188}
189
190impl<A: Architecture> Drop for HookGuard<A> {
191    fn drop(&mut self) {
192        if self.auto_restore {
193            // ignore errors during drop
194            let _ = self.restore_internal();
195        }
196    }
197}
198
199// SAFETY: HookGuard contains process-wide memory addresses
200// the hook state is shared across threads anyway
201unsafe impl<A: Architecture> Send for HookGuard<A> {}
202unsafe impl<A: Architecture> Sync for HookGuard<A> {}
203
/// whether a hook's stub is currently written over its target
/// (tracked by `StatefulHookGuard`)
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum HookState {
    /// the hook stub is in place; calls to the target reach the detour
    Enabled,
    /// the original bytes are back in place; the hook can be re-enabled
    Disabled,
}
212
213/// stateful hook guard that tracks enable/disable state
214pub struct StatefulHookGuard<A: Architecture> {
215    guard: HookGuard<A>,
216    hook_bytes: Vec<u8>,
217    state: HookState,
218}
219
220impl<A: Architecture> StatefulHookGuard<A> {
221    /// create from guard and hook bytes
222    pub(crate) fn new(guard: HookGuard<A>, hook_bytes: Vec<u8>) -> Self {
223        Self {
224            guard,
225            hook_bytes,
226            state: HookState::Enabled,
227        }
228    }
229
230    /// get current state
231    pub fn state(&self) -> HookState {
232        self.state
233    }
234
235    /// check if enabled
236    pub fn is_enabled(&self) -> bool {
237        self.state == HookState::Enabled
238    }
239
240    /// disable the hook
241    pub fn disable(&mut self) -> Result<()> {
242        if self.state == HookState::Enabled {
243            self.guard.disable()?;
244            self.state = HookState::Disabled;
245        }
246        Ok(())
247    }
248
249    /// enable the hook
250    pub fn enable(&mut self) -> Result<()> {
251        if self.state == HookState::Disabled {
252            self.guard.enable(&self.hook_bytes)?;
253            self.state = HookState::Enabled;
254        }
255        Ok(())
256    }
257
258    /// toggle hook state
259    pub fn toggle(&mut self) -> Result<()> {
260        match self.state {
261            HookState::Enabled => self.disable(),
262            HookState::Disabled => self.enable(),
263        }
264    }
265
266    /// get target address
267    pub fn target(&self) -> usize {
268        self.guard.target()
269    }
270
271    /// get detour address
272    pub fn detour(&self) -> usize {
273        self.guard.detour()
274    }
275
276    /// get trampoline address
277    pub fn trampoline(&self) -> Option<usize> {
278        self.guard.trampoline()
279    }
280
281    /// leak the hook (keep it active permanently)
282    pub fn leak(self) {
283        self.guard.leak();
284    }
285
286    /// restore and consume
287    pub fn restore(self) -> Result<()> {
288        self.guard.restore()
289    }
290}
291
292/// flush instruction cache
293fn flush_icache(address: usize, size: usize) -> Result<()> {
294    let result = unsafe {
295        FlushInstructionCache(
296            GetCurrentProcess(),
297            address as *const _,
298            size,
299        )
300    };
301
302    if result == 0 {
303        Err(WraithError::from_last_error("FlushInstructionCache"))
304    } else {
305        Ok(())
306    }
307}
308
#[link(name = "kernel32")]
extern "system" {
    /// handle for the calling process, passed to `FlushInstructionCache`
    fn GetCurrentProcess() -> *mut core::ffi::c_void;

    /// flush `size` bytes of instruction cache at `base_address` in `process`;
    /// returns nonzero on success and zero on failure
    fn FlushInstructionCache(
        process: *mut core::ffi::c_void,
        base_address: *const core::ffi::c_void,
        size: usize,
    ) -> i32;
}