wraith/manipulation/hooks/
tracker.rs1use super::detector::{HookInfo, HookType};
7use std::collections::HashMap;
8use std::sync::Mutex;
9
10static GLOBAL_TRACKER: Mutex<Option<HookTracker>> = Mutex::new(None);
12
/// State recorded for a tracked hook.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum HookState {
    /// State given to every hook when it is first tracked.
    Active,
    /// State set by `HookTracker::mark_removed`.
    Removed,
    /// State set by `HookTracker::mark_restored`.
    Restored,
}
23
24#[derive(Debug, Clone)]
26pub struct TrackedHook {
27 pub info: HookInfo,
29 pub state: HookState,
31 pub detected_at: std::time::Instant,
33 pub last_changed: std::time::Instant,
35}
36
37impl TrackedHook {
38 fn new(info: HookInfo) -> Self {
39 let now = std::time::Instant::now();
40 Self {
41 info,
42 state: HookState::Active,
43 detected_at: now,
44 last_changed: now,
45 }
46 }
47}
48
49pub struct HookTracker {
51 hooks: HashMap<usize, TrackedHook>,
53 by_module: HashMap<String, Vec<usize>>,
55}
56
57impl HookTracker {
58 pub fn new() -> Self {
60 Self {
61 hooks: HashMap::new(),
62 by_module: HashMap::new(),
63 }
64 }
65
66 pub fn register(&mut self, info: HookInfo) {
68 let addr = info.function_address;
69 let module = info.module_name.clone();
70
71 self.hooks.insert(addr, TrackedHook::new(info));
72
73 self.by_module
74 .entry(module)
75 .or_insert_with(Vec::new)
76 .push(addr);
77 }
78
79 pub fn register_all(&mut self, hooks: impl IntoIterator<Item = HookInfo>) {
81 for hook in hooks {
82 self.register(hook);
83 }
84 }
85
86 pub fn mark_removed(&mut self, address: usize) {
88 if let Some(tracked) = self.hooks.get_mut(&address) {
89 tracked.state = HookState::Removed;
90 tracked.last_changed = std::time::Instant::now();
91 }
92 }
93
94 pub fn mark_restored(&mut self, address: usize) {
96 if let Some(tracked) = self.hooks.get_mut(&address) {
97 tracked.state = HookState::Restored;
98 tracked.last_changed = std::time::Instant::now();
99 }
100 }
101
102 pub fn get(&self, address: usize) -> Option<&TrackedHook> {
104 self.hooks.get(&address)
105 }
106
107 pub fn get_by_module(&self, module_name: &str) -> Vec<&TrackedHook> {
109 self.by_module
110 .get(module_name)
111 .map(|addrs| {
112 addrs
113 .iter()
114 .filter_map(|&addr| self.hooks.get(&addr))
115 .collect()
116 })
117 .unwrap_or_default()
118 }
119
120 pub fn active_hooks(&self) -> Vec<&TrackedHook> {
122 self.hooks
123 .values()
124 .filter(|h| h.state == HookState::Active)
125 .collect()
126 }
127
128 pub fn removed_hooks(&self) -> Vec<&TrackedHook> {
130 self.hooks
131 .values()
132 .filter(|h| h.state == HookState::Removed)
133 .collect()
134 }
135
136 pub fn get_by_type(&self, hook_type: HookType) -> Vec<&TrackedHook> {
138 self.hooks
139 .values()
140 .filter(|h| h.info.hook_type == hook_type)
141 .collect()
142 }
143
144 pub fn count(&self) -> usize {
146 self.hooks.len()
147 }
148
149 pub fn active_count(&self) -> usize {
151 self.hooks
152 .values()
153 .filter(|h| h.state == HookState::Active)
154 .count()
155 }
156
157 pub fn removed_count(&self) -> usize {
159 self.hooks
160 .values()
161 .filter(|h| h.state == HookState::Removed)
162 .count()
163 }
164
165 pub fn is_tracked(&self, address: usize) -> bool {
167 self.hooks.contains_key(&address)
168 }
169
170 pub fn unregister(&mut self, address: usize) -> Option<TrackedHook> {
172 if let Some(hook) = self.hooks.remove(&address) {
173 if let Some(addrs) = self.by_module.get_mut(&hook.info.module_name) {
175 addrs.retain(|&a| a != address);
176 }
177 Some(hook)
178 } else {
179 None
180 }
181 }
182
183 pub fn clear(&mut self) {
185 self.hooks.clear();
186 self.by_module.clear();
187 }
188
189 pub fn modules(&self) -> Vec<&str> {
191 self.by_module.keys().map(|s| s.as_str()).collect()
192 }
193
194 pub fn stats(&self) -> HookStats {
196 let mut stats = HookStats::default();
197
198 for hook in self.hooks.values() {
199 match hook.state {
200 HookState::Active => stats.active += 1,
201 HookState::Removed => stats.removed += 1,
202 HookState::Restored => stats.restored += 1,
203 }
204
205 match hook.info.hook_type {
206 HookType::JmpRel32 => stats.jmp_rel32 += 1,
207 HookType::JmpIndirect => stats.jmp_indirect += 1,
208 HookType::MovJmpRax => stats.mov_jmp_rax += 1,
209 HookType::PushRet => stats.push_ret += 1,
210 HookType::Breakpoint => stats.breakpoints += 1,
211 HookType::Unknown => stats.unknown += 1,
212 }
213 }
214
215 stats.total = self.hooks.len();
216 stats.modules = self.by_module.len();
217
218 stats
219 }
220}
221
222impl Default for HookTracker {
223 fn default() -> Self {
224 Self::new()
225 }
226}
227
/// Counters describing the contents of a `HookTracker`, as filled in by
/// `HookTracker::stats`. All counters default to zero.
#[derive(Clone, Debug, Default)]
pub struct HookStats {
    /// Number of tracked hooks, in any state.
    pub total: usize,
    /// Hooks in `HookState::Active`.
    pub active: usize,
    /// Hooks in `HookState::Removed`.
    pub removed: usize,
    /// Hooks in `HookState::Restored`.
    pub restored: usize,
    /// Number of entries in the tracker's module index.
    pub modules: usize,
    /// Hooks of type `HookType::JmpRel32`.
    pub jmp_rel32: usize,
    /// Hooks of type `HookType::JmpIndirect`.
    pub jmp_indirect: usize,
    /// Hooks of type `HookType::MovJmpRax`.
    pub mov_jmp_rax: usize,
    /// Hooks of type `HookType::PushRet`.
    pub push_ret: usize,
    /// Hooks of type `HookType::Breakpoint`.
    pub breakpoints: usize,
    /// Hooks of type `HookType::Unknown`.
    pub unknown: usize,
}
243
244impl std::fmt::Display for HookStats {
245 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
246 writeln!(f, "Hook Statistics:")?;
247 writeln!(f, " Total: {}", self.total)?;
248 writeln!(f, " Active: {}", self.active)?;
249 writeln!(f, " Removed: {}", self.removed)?;
250 writeln!(f, " Restored: {}", self.restored)?;
251 writeln!(f, " Modules: {}", self.modules)?;
252 writeln!(f, " By type:")?;
253 writeln!(f, " jmp rel32: {}", self.jmp_rel32)?;
254 writeln!(f, " jmp indirect: {}", self.jmp_indirect)?;
255 writeln!(f, " mov rax; jmp rax: {}", self.mov_jmp_rax)?;
256 writeln!(f, " push; ret: {}", self.push_ret)?;
257 writeln!(f, " breakpoints: {}", self.breakpoints)?;
258 writeln!(f, " unknown: {}", self.unknown)
259 }
260}
261
262pub fn init_global_tracker() -> bool {
268 match GLOBAL_TRACKER.lock() {
269 Ok(mut guard) => {
270 if guard.is_none() {
271 *guard = Some(HookTracker::new());
272 }
273 true
274 }
275 Err(poisoned) => {
276 let mut guard = poisoned.into_inner();
278 if guard.is_none() {
279 *guard = Some(HookTracker::new());
280 }
281 true
282 }
283 }
284}
285
286pub fn global_tracker() -> Option<std::sync::MutexGuard<'static, Option<HookTracker>>> {
290 match GLOBAL_TRACKER.lock() {
291 Ok(guard) => Some(guard),
292 Err(poisoned) => {
293 Some(poisoned.into_inner())
295 }
296 }
297}
298
299pub fn with_global_tracker<F, R>(f: F) -> Option<R>
303where
304 F: FnOnce(&mut HookTracker) -> R,
305{
306 let mut guard = global_tracker()?;
307 guard.as_mut().map(f)
308}
309
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `JmpRel32` hook record named `name` at `addr`, always in
    /// module `test.dll`.
    fn dummy_hook(name: &str, addr: usize) -> HookInfo {
        let original_bytes = vec![0x90; 5];
        let hooked_bytes = vec![0xE9, 0x00, 0x00, 0x00, 0x00];

        HookInfo {
            function_name: name.to_string(),
            function_address: addr,
            hook_type: HookType::JmpRel32,
            hook_destination: Some(0xDEADBEEF),
            original_bytes,
            hooked_bytes,
            module_name: "test.dll".to_string(),
        }
    }

    /// Registering two hooks and removing one updates the counters.
    #[test]
    fn test_tracker_basic() {
        let mut registry = HookTracker::new();

        registry.register(dummy_hook("NtReadVirtualMemory", 0x1000));
        registry.register(dummy_hook("NtWriteVirtualMemory", 0x2000));

        // Both hooks start out active.
        assert_eq!(registry.count(), 2);
        assert_eq!(registry.active_count(), 2);

        // Removing one moves it from the active to the removed tally.
        registry.mark_removed(0x1000);
        assert_eq!(registry.active_count(), 1);
        assert_eq!(registry.removed_count(), 1);
    }

    /// `stats` reports totals, per-state and per-type counts.
    #[test]
    fn test_stats() {
        let mut registry = HookTracker::new();

        registry.register(dummy_hook("Func1", 0x1000));
        registry.register(dummy_hook("Func2", 0x2000));
        registry.mark_removed(0x1000);

        let snapshot = registry.stats();
        assert_eq!(snapshot.total, 2);
        assert_eq!(snapshot.active, 1);
        assert_eq!(snapshot.removed, 1);
        assert_eq!(snapshot.jmp_rel32, 2);
    }
}