wraith/manipulation/inline_hook/
guard.rs1#[cfg(all(not(feature = "std"), feature = "alloc"))]
7use alloc::vec::Vec;
8
9#[cfg(feature = "std")]
10use std::vec::Vec;
11
12use crate::error::{Result, WraithError};
13use crate::util::memory::ProtectionGuard;
14use super::arch::Architecture;
15use super::trampoline::ExecutableMemory;
16use core::marker::PhantomData;
17
/// Win32 page-protection value `PAGE_EXECUTE_READWRITE` (0x40).
///
/// Passed to `ProtectionGuard::new` while patching so the target bytes are
/// writable without losing execute permission.
const PAGE_EXECUTE_READWRITE: u32 = 0x0000_0040;
19
20pub struct HookGuard<A: Architecture> {
25 target: usize,
27 detour: usize,
29 original_bytes: Vec<u8>,
31 trampoline: Option<ExecutableMemory>,
33 auto_restore: bool,
35 _arch: PhantomData<A>,
37}
38
39impl<A: Architecture> HookGuard<A> {
40 pub(crate) fn new(
42 target: usize,
43 detour: usize,
44 original_bytes: Vec<u8>,
45 trampoline: Option<ExecutableMemory>,
46 ) -> Self {
47 Self {
48 target,
49 detour,
50 original_bytes,
51 trampoline,
52 auto_restore: true,
53 _arch: PhantomData,
54 }
55 }
56
57 pub fn target(&self) -> usize {
59 self.target
60 }
61
62 pub fn detour(&self) -> usize {
64 self.detour
65 }
66
67 pub fn trampoline(&self) -> Option<usize> {
71 self.trampoline.as_ref().map(|t| t.base())
72 }
73
74 pub fn original_bytes(&self) -> &[u8] {
76 &self.original_bytes
77 }
78
79 pub fn will_auto_restore(&self) -> bool {
81 self.auto_restore
82 }
83
84 pub fn set_auto_restore(&mut self, restore: bool) {
86 self.auto_restore = restore;
87 }
88
89 pub fn leak(mut self) {
94 self.auto_restore = false;
95 if let Some(trampoline) = self.trampoline.take() {
96 trampoline.leak();
97 }
98 core::mem::forget(self);
99 }
100
101 pub fn restore(self) -> Result<()> {
105 self.restore_internal()?;
106 core::mem::forget(self);
108 Ok(())
109 }
110
111 pub fn disable(&mut self) -> Result<()> {
116 let _guard = ProtectionGuard::new(
117 self.target,
118 self.original_bytes.len(),
119 PAGE_EXECUTE_READWRITE,
120 )?;
121
122 unsafe {
124 core::ptr::copy_nonoverlapping(
125 self.original_bytes.as_ptr(),
126 self.target as *mut u8,
127 self.original_bytes.len(),
128 );
129 }
130
131 flush_icache(self.target, self.original_bytes.len())?;
132
133 Ok(())
134 }
135
136 pub fn enable(&mut self, hook_bytes: &[u8]) -> Result<()> {
140 if hook_bytes.len() != self.original_bytes.len() {
141 return Err(WraithError::WriteFailed {
142 address: self.target as u64,
143 size: hook_bytes.len(),
144 });
145 }
146
147 let _guard = ProtectionGuard::new(
148 self.target,
149 hook_bytes.len(),
150 PAGE_EXECUTE_READWRITE,
151 )?;
152
153 unsafe {
155 core::ptr::copy_nonoverlapping(
156 hook_bytes.as_ptr(),
157 self.target as *mut u8,
158 hook_bytes.len(),
159 );
160 }
161
162 flush_icache(self.target, hook_bytes.len())?;
163
164 Ok(())
165 }
166
167 fn restore_internal(&self) -> Result<()> {
169 let _guard = ProtectionGuard::new(
170 self.target,
171 self.original_bytes.len(),
172 PAGE_EXECUTE_READWRITE,
173 )?;
174
175 unsafe {
177 core::ptr::copy_nonoverlapping(
178 self.original_bytes.as_ptr(),
179 self.target as *mut u8,
180 self.original_bytes.len(),
181 );
182 }
183
184 flush_icache(self.target, self.original_bytes.len())?;
185
186 Ok(())
187 }
188}
189
190impl<A: Architecture> Drop for HookGuard<A> {
191 fn drop(&mut self) {
192 if self.auto_restore {
193 let _ = self.restore_internal();
195 }
196 }
197}
198
199unsafe impl<A: Architecture> Send for HookGuard<A> {}
202unsafe impl<A: Architecture> Sync for HookGuard<A> {}
203
/// Tracked state of a `StatefulHookGuard`.
///
/// `Hash` is derived in addition to equality so the state can be used as a
/// key in hash-based collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum HookState {
    /// The hook bytes are written at the target.
    Enabled,
    /// The original bytes have been written back.
    Disabled,
}
212
213pub struct StatefulHookGuard<A: Architecture> {
215 guard: HookGuard<A>,
216 hook_bytes: Vec<u8>,
217 state: HookState,
218}
219
220impl<A: Architecture> StatefulHookGuard<A> {
221 pub(crate) fn new(guard: HookGuard<A>, hook_bytes: Vec<u8>) -> Self {
223 Self {
224 guard,
225 hook_bytes,
226 state: HookState::Enabled,
227 }
228 }
229
230 pub fn state(&self) -> HookState {
232 self.state
233 }
234
235 pub fn is_enabled(&self) -> bool {
237 self.state == HookState::Enabled
238 }
239
240 pub fn disable(&mut self) -> Result<()> {
242 if self.state == HookState::Enabled {
243 self.guard.disable()?;
244 self.state = HookState::Disabled;
245 }
246 Ok(())
247 }
248
249 pub fn enable(&mut self) -> Result<()> {
251 if self.state == HookState::Disabled {
252 self.guard.enable(&self.hook_bytes)?;
253 self.state = HookState::Enabled;
254 }
255 Ok(())
256 }
257
258 pub fn toggle(&mut self) -> Result<()> {
260 match self.state {
261 HookState::Enabled => self.disable(),
262 HookState::Disabled => self.enable(),
263 }
264 }
265
266 pub fn target(&self) -> usize {
268 self.guard.target()
269 }
270
271 pub fn detour(&self) -> usize {
273 self.guard.detour()
274 }
275
276 pub fn trampoline(&self) -> Option<usize> {
278 self.guard.trampoline()
279 }
280
281 pub fn leak(self) {
283 self.guard.leak();
284 }
285
286 pub fn restore(self) -> Result<()> {
288 self.guard.restore()
289 }
290}
291
292fn flush_icache(address: usize, size: usize) -> Result<()> {
294 let result = unsafe {
295 FlushInstructionCache(
296 GetCurrentProcess(),
297 address as *const _,
298 size,
299 )
300 };
301
302 if result == 0 {
303 Err(WraithError::from_last_error("FlushInstructionCache"))
304 } else {
305 Ok(())
306 }
307}
308
// kernel32 imports used by `flush_icache`.
#[link(name = "kernel32")]
extern "system" {
    /// Handle for the calling process, passed straight to
    /// `FlushInstructionCache`.
    fn GetCurrentProcess() -> *mut core::ffi::c_void;

    /// Flushes the instruction cache for `size` bytes at `base_address` in
    /// `process`. A zero return means failure (see `flush_icache`).
    fn FlushInstructionCache(
        process: *mut core::ffi::c_void,
        base_address: *const core::ffi::c_void,
        size: usize,
    ) -> i32;
}