safe_hook/lib.rs
1//! Safe-Hook is an inline hook library for Rust.
2//! It provides a simple and safe way to create hooks in your Rust applications,
3//! allowing you to modify the behavior of functions at runtime.
4//!
5//! The design principle of Safe-Hook is safety and simplicity.
6//!
7//! ## Features
8//! - **Inline Hooking**: Safe-Hook allows you to hook into functions at runtime,
9//! enabling you to modify their behavior.
//! - **Safe and Simple**: The library is designed to be safe and easy to use;
//!   it checks the types of parameters and return values at runtime to ensure safety.
//! - **Fully Dynamic**: Safe-Hook is fully dynamic,
//!   allowing you to add and remove hooks at runtime without any restrictions.
//! - **Cross-Platform**: Safe-Hook is designed to work on multiple platforms;
//!   in theory it supports every platform that Rust supports.
16//!
17//! ## Limitations
//! - **Intrusive**: Target functions must be annotated manually,
//!   which means it is not suitable for hooking third-party libraries.
20//!
21//!
22//! ## Usage
23//! More Examples:
24//! - [Hook a function with reference parameters](#hook-a-function-with-reference-parameters)
25//!
26//! Simple Usage:
27//! ```rust
28//! use std::sync::Arc;
29//! use safe_hook::{lookup_hookable, Hook};
30//! use safe_hook_macros::hookable;
31//!
32//! #[hookable("add")]
33//! fn add(left: i64, right: i64) -> i64 {
34//! left + right
35//! }
36//!
37//! #[derive(Debug)]
38//! struct HookAdd {
39//! x: i64,
40//! }
41//!
42//! impl Hook for HookAdd {
43//! type Args<'a> = (i64, i64);
44//! type Result = i64;
45//! fn call(&self, args: (i64, i64), next: &dyn Fn((i64, i64)) -> i64) -> i64 {
46//! next(args) + self.x
47//! }
48//! }
49//!
50//! fn main() {
51//! let hook = Arc::new(HookAdd {
52//! x: 1,
53//! });
54//! assert_eq!(add(1, 2), 3);
55//! lookup_hookable("add").unwrap().add_hook(hook).unwrap();
56//! assert_eq!(add(1, 2), 4);
57//! }
58//! ```
59//!
60//! ## Performance
61//! Extra overhead:
62//! - No Hook Added: One atomic load and one branch jump,
63//! which should be very lightweight in most cases.
//! - Hooks Added: a read/write lock is acquired (usually just a few atomic operations),
//!   plus some additional indirect calls through function pointers
//!   and some copy operations to pack the parameters into a tuple.
67//!
//! A rough benchmark (on an Intel Core i7-12700H) shows that the extra overhead is
//! about 0.5 ns when no hooks are added
//! (for comparison, an `add(a, b)` function takes about 0.5 ns),
//! about 14 ns when hooks are added,
//! and that each additional hook adds about 2 ns of overhead.
73
74use std::any::TypeId;
75use std::cell::Cell;
76use std::sync::atomic::AtomicBool;
77use std::sync::{Arc, LazyLock, RwLock};
78
79#[doc(hidden)]
80pub use inventory;
81
82pub use safe_hook_macros::hookable;
/// A trait for hooks.
///
/// Implement this trait to create a hook, then attach it (wrapped in an `Arc`) to a
/// hookable function with `add_hook` / `add_hook_with_priority`.
///
/// Attaching checks at runtime that [`Hook::Args`] and [`Hook::Result`] match the target
/// function. The comparison is done by `TypeId`, with `Args` instantiated at the
/// `'static` lifetime, so lifetimes inside the argument tuple do not affect the check.
pub trait Hook: Send + Sync + 'static {
    /// The argument type of the hook. Must be a tuple.
    /// Must be the same as the arguments of the target hookable function you want to hook.
    /// The lifetime parameter allows the tuple to contain borrowed arguments.
    type Args<'b>;

    /// The result type of the hook.
    /// Must be the same as the return type of the target hookable function.
    type Result;

    /// The hook function.
    /// This is called whenever the target function is called.
    /// # Parameters:
    /// - `args`: The arguments of the target function, packed into a tuple.
    /// - `next`: The next function in the chain: either the next (lower-priority) hook
    ///   or the original target function. Call it to continue the chain; a hook that
    ///   never calls `next` prevents the original function from running.
    fn call<'a>(
        &'a self,
        args: Self::Args<'a>,
        next: &dyn for<'c> Fn(Self::Args<'c>) -> Self::Result,
    ) -> Self::Result;
}
105
/// A trait for dynamic dispatch of hooks.
///
/// It is implemented for every [`Hook`] by a blanket impl. It lets hooks of different
/// concrete types be stored together as `Arc<dyn HookDyn>`.
///
/// # Safety
/// **THIS TRAIT SHOULD NEVER BE IMPLEMENTED BY USER CODE.**
/// Callers transmute the pointer returned by [`HookDyn::get_call_fn`] into a function
/// pointer whose argument and result types are trusted purely because
/// [`HookDyn::type_info`] matched the target function. An impl that reports a
/// mismatching pair leads to undefined behavior.
#[doc(hidden)]
pub unsafe trait HookDyn: Send + Sync {
    /// Returns the address of the type-erased call trampoline for this hook.
    fn get_call_fn(&self) -> *const ();
    /// Returns `(TypeId of the result type, TypeId of the argument tuple)`.
    fn type_info(&self) -> (TypeId, TypeId);
}
114
115/// A wrapper layer to avoid the calling convention difference between &T and *const ().
116unsafe fn hook_call_wrapper<'a, T: Hook + 'static>(
117 self_ptr: *const (),
118 args: <T as Hook>::Args<'a>,
119 next: &dyn for<'b> Fn(<T as Hook>::Args<'b>) -> <T as Hook>::Result,
120) -> <T as Hook>::Result {
121 let self_ref = unsafe { &*(self_ptr as *const T) };
122 self_ref.call(args, next)
123}
124
125unsafe impl<T: Hook + 'static> HookDyn for T {
126 fn get_call_fn(&self) -> *const () {
127 hook_call_wrapper::<T> as *const ()
128 }
129 fn type_info(&self) -> (TypeId, TypeId) {
130 let res = TypeId::of::<<T as Hook>::Result>();
131 let args = TypeId::of::<<T as Hook>::Args<'static>>();
132 (res, args)
133 }
134}
135
136/// A registry entry for hookable functions.
137#[doc(hidden)]
138pub struct HookableFuncRegistry {
139 metadata: &'static LazyLock<HookableFuncMetadata>,
140}
141impl HookableFuncRegistry {
142 pub const fn new(metadata: &'static LazyLock<HookableFuncMetadata>) -> Self {
143 Self { metadata }
144 }
145}
146
147inventory::collect!(HookableFuncRegistry);
148
149/// Lookup a hookable function by name.
150pub fn lookup_hookable(name: &str) -> Option<&'static HookableFuncMetadata> {
151 // struct MyHashBuilder;
152 // impl BuildHasher for MyHashBuilder {
153 // type Hasher = DefaultHasher;
154 // fn build_hasher(&self) -> Self::Hasher {
155 // DefaultHasher::new()
156 // }
157 // }
158 // static CACHE: Mutex<HashMap<String, &'static LazyLock<HookableFuncMetadata>, MyHashBuilder>> = Mutex::new(HashMap::with_hasher(MyHashBuilder{}));
159
160 for item in inventory::iter::<HookableFuncRegistry> {
161 if item.metadata.name == name {
162 return Some(item.metadata);
163 }
164 }
165 None
166}
167
/// Type-erased address of the original (unhooked) function.
///
/// Raw pointers are neither `Send` nor `Sync`. This newtype presumably exists so that
/// `HookableFuncMetadata`, which stores one, can still be shared between threads.
struct HookableFuncPtr(*const ());
// SAFETY: the pointer is only stored and handed back by address (`func_ptr`); this file
// never dereferences it. It is expected to point at a function, i.e. immutable code that
// is valid on every thread.
unsafe impl Send for HookableFuncPtr {}
// SAFETY: shared access only ever copies the address out (see the `Send` justification).
unsafe impl Sync for HookableFuncPtr {}
171
172/// Metadata for a hookable function.
173#[doc(hidden)]
174pub struct HookableFuncMetadata {
175 name: String,
176 func: HookableFuncPtr,
177 type_info: (TypeId, TypeId),
178 fast_path_flag: &'static AtomicBool,
179 hooks: RwLock<Vec<(Arc<dyn HookDyn>, i32)>>,
180}
181impl HookableFuncMetadata {
182 /// Create a new [`HookableFuncMetadata`].
183 /// # Safety
184 /// This function is unsafe because it takes a raw pointer to a function without type checking.
185 /// It is used inside the macro [`hookable`] to create a new [`HookableFuncMetadata`] instance.
186 /// **THIS FUNCTION SHOULD NOT BE CALLED DIRECTLY.**
187 #[doc(hidden)]
188 pub unsafe fn new(
189 name: String,
190 func: *const (),
191 type_info: (TypeId, TypeId),
192 fast_path_flag: &'static AtomicBool,
193 ) -> Self {
194 Self {
195 name,
196 func: HookableFuncPtr(func),
197 type_info,
198 fast_path_flag,
199 hooks: RwLock::new(Vec::new()),
200 }
201 }
202
203 /// Get the name of the hookable function.
204 pub fn name(&self) -> &str {
205 &self.name
206 }
207
208 /// Get the pointer to the hookable function.
209 pub fn func_ptr(&self) -> *const () {
210 self.func.0
211 }
212
213 /// Add a hook to the hookable function.
214 /// The greatest priority will be called first.
215 pub fn add_hook_with_priority(
216 &self,
217 hook: Arc<dyn HookDyn>,
218 priority: i32,
219 ) -> Result<(), String> {
220 if hook.type_info() != self.type_info {
221 return Err(format!(
222 "Hook type mismatch: expected {:?}, got {:?}",
223 self.type_info,
224 hook.type_info()
225 ));
226 }
227 let mut hooks = self.hooks.write().unwrap();
228 let pos = hooks
229 .iter()
230 .position(|h| h.1 <= priority)
231 .unwrap_or(hooks.len());
232 hooks.insert(pos, (hook, priority));
233 self.fast_path_flag
234 .store(true, std::sync::atomic::Ordering::Release);
235 Ok(())
236 }
237
238 /// Add a hook to the hookable function with default (0) priority.
239 pub fn add_hook(&self, hook: Arc<dyn HookDyn>) -> Result<(), String> {
240 self.add_hook_with_priority(hook, 0)
241 }
242
243 /// Remove a hook from the hookable function.
244 pub fn remove_hook(&self, hook: &dyn HookDyn) -> bool {
245 let mut hooks = self.hooks.write().unwrap();
246 if let Some(pos) = hooks
247 .iter()
248 .position(|h| std::ptr::addr_eq(h.0.as_ref(), hook))
249 {
250 hooks.remove(pos);
251 if hooks.is_empty() {
252 self.fast_path_flag
253 .store(false, std::sync::atomic::Ordering::Relaxed);
254 }
255 true
256 } else {
257 false
258 }
259 }
260
261 /// Clear all hooks from the hookable function.
262 pub fn clear_hooks(&self) {
263 let mut hooks = self.hooks.write().unwrap();
264 hooks.clear();
265 self.fast_path_flag
266 .store(false, std::sync::atomic::Ordering::Relaxed);
267 }
268}
269
270/// Call a hookable function with hooks.
271#[doc(hidden)]
272pub fn call_with_hook<R, A>(func: fn(A) -> R, meta: &'static HookableFuncMetadata, args: A) -> R {
273 let hooks = meta.hooks.read().unwrap();
274 let pos = Cell::new(0);
275 #[allow(clippy::type_complexity)]
276 let next_fn_ref: Cell<Option<&dyn Fn(A) -> R>> = Cell::new(None);
277 type HookFn<A, R> = fn(*const (), args: A, next: &dyn Fn(A) -> R) -> R;
278 let next_fn = |args: A| {
279 if pos.get() < hooks.len() {
280 let hook = hooks[pos.get()].0.as_ref();
281 // SAFETY: get_call_fn should return a function pointer to hook_call_wrapper<A>
282 let f: HookFn<A, R> = unsafe { std::mem::transmute(hook.get_call_fn()) };
283 pos.set(pos.get() + 1);
284 let res = f(
285 hook as *const dyn HookDyn as *const (),
286 args,
287 // SAFETY: next_fn_ref must be set before calling next_fn
288 unsafe { next_fn_ref.get().unwrap_unchecked() },
289 );
290 pos.set(pos.get() - 1);
291 res
292 } else {
293 func(args)
294 }
295 };
296 next_fn_ref.set(Some(&next_fn));
297 next_fn(args)
298}