safe_hook/
lib.rs

//! Safe-Hook is an inline hook library for Rust.
//! It provides a simple and safe way to create hooks in your Rust applications,
//! allowing you to modify the behavior of functions at runtime.
//!
//! Safe-Hook is designed around two principles: safety and simplicity.
6//! 
//! ## Features
//! - **Inline Hooking**: Safe-Hook lets you hook into functions at runtime
//!   and modify their behavior.
//! - **Safe and Simple**: The library is designed to be safe and easy to use;
//!   it checks the types of parameters and return values at runtime to ensure safety.
//! - **Fully Dynamic**: Safe-Hook is fully dynamic:
//!   hooks can be added and removed at runtime without any restrictions.
//! - **Cross-Platform**: Safe-Hook is designed to work on multiple platforms;
//!   in theory, it supports every platform that Rust supports.
16//! 
//! ## Limitations
//! - **Intrusive**: Target functions must be annotated manually (with `#[hookable]`),
//!   so Safe-Hook is not suitable for hooking functions in third-party libraries.
20//! 
21//! 
22//! ## Usage
23//! More Examples:
24//! - [Hook a function with reference parameters](#hook-a-function-with-reference-parameters)
25//! 
26//! Simple Usage:
27//! ```rust
28//! use std::sync::Arc;
29//! use safe_hook::{lookup_hookable, Hook};
30//! use safe_hook_macros::hookable;
31//! 
32//! #[hookable("add")]
33//! fn add(left: i64, right: i64) -> i64 {
34//!     left + right
35//! }
36//! 
37//! #[derive(Debug)]
38//! struct HookAdd {
39//!     x: i64,
40//! }
41//! 
42//! impl Hook for HookAdd {
43//!     type Args<'a> = (i64, i64);
44//!     type Result = i64;
45//!     fn call(&self, args: (i64, i64), next: &dyn Fn((i64, i64)) -> i64) -> i64 {
46//!         next(args) + self.x
47//!     }
48//! }
49//! 
50//! fn main() {
51//!     let hook = Arc::new(HookAdd {
52//!         x: 1,
53//!     });
54//!     assert_eq!(add(1, 2), 3);
55//!     lookup_hookable("add").unwrap().add_hook(hook).unwrap();
56//!     assert_eq!(add(1, 2), 4);
57//! }
58//! ```
59//! 
//! ## Performance
//! Extra overhead:
//! - No hooks added: one atomic load and one branch,
//!   which should be very lightweight in most cases.
//! - Hooks added: a read/write lock is acquired (usually just a few atomic operations),
//!   plus some additional function calls through pointers
//!   and some copying to pack the parameters into a tuple.
//!
//! A rough benchmark (on an Intel Core i7-12700H) shows that the extra overhead is
//! about 0.5 ns when no hooks are added
//! (for comparison, an `add(a, b)` function takes about 0.5 ns),
//! about 14 ns when hooks are added,
//! and that each additional hook adds about 2 ns.
73
74use std::any::TypeId;
75use std::cell::Cell;
76use std::sync::atomic::AtomicBool;
77use std::sync::{Arc, LazyLock, RwLock};
78
79#[doc(hidden)]
80pub use inventory;
81
82pub use safe_hook_macros::hookable;
/// The trait every hook implements.
///
/// Implement this trait to create a hook. A hook sits in front of a [`hookable`] function:
/// it receives the arguments of each call together with a `next` continuation and decides
/// what the call returns, typically by invoking `next` (possibly with altered arguments)
/// and adjusting its result.
pub trait Hook: Send + Sync + 'static {
    /// Argument tuple of the hooked function.
    ///
    /// Must be a tuple that matches the parameters of the target [`hookable`] function
    /// exactly. The lifetime parameter allows the tuple to hold borrowed (reference) arguments.
    type Args<'args>;

    /// Return type of the hooked function.
    ///
    /// Must match the return type of the target [`hookable`] function exactly.
    type Result;

    /// Runs this hook for one call of the target function.
    ///
    /// # Parameters
    /// - `args`: the arguments the target function was called with.
    /// - `next`: continues the call chain. It runs the next hook in line, or the original
    ///   target function when this hook is the last one.
    fn call<'s>(
        &'s self,
        args: Self::Args<'s>,
        next: &dyn for<'x> Fn(Self::Args<'x>) -> Self::Result,
    ) -> Self::Result;
}
105
106/// A trait for dynamic dispatch of hooks.
107/// # Safety
108/// **THIS TRAIT SHOULD NEVER BE IMPLEMENTED BY USER CODE.**
109#[doc(hidden)]
110pub unsafe trait HookDyn: Send + Sync {
111    fn get_call_fn(&self) -> *const ();
112    fn type_info(&self) -> (TypeId, TypeId);
113}
114
115/// A wrapper layer to avoid the calling convention difference between &T and *const ().
116unsafe fn hook_call_wrapper<'a, T: Hook + 'static>(
117    self_ptr: *const (),
118    args: <T as Hook>::Args<'a>,
119    next: &dyn for<'b> Fn(<T as Hook>::Args<'b>) -> <T as Hook>::Result,
120) -> <T as Hook>::Result {
121    let self_ref = unsafe { &*(self_ptr as *const T) };
122    self_ref.call(args, next)
123}
124
125unsafe impl<T: Hook + 'static> HookDyn for T {
126    fn get_call_fn(&self) -> *const () {
127        hook_call_wrapper::<T> as *const ()
128    }
129    fn type_info(&self) -> (TypeId, TypeId) {
130        let res = TypeId::of::<<T as Hook>::Result>();
131        let args = TypeId::of::<<T as Hook>::Args<'static>>();
132        (res, args)
133    }
134}
135
136/// A registry entry for hookable functions.
137#[doc(hidden)]
138pub struct HookableFuncRegistry {
139    metadata: &'static LazyLock<HookableFuncMetadata>,
140}
141impl HookableFuncRegistry {
142    pub const fn new(metadata: &'static LazyLock<HookableFuncMetadata>) -> Self {
143        Self { metadata }
144    }
145}
146
147inventory::collect!(HookableFuncRegistry);
148
149/// Lookup a hookable function by name.
150pub fn lookup_hookable(name: &str) -> Option<&'static HookableFuncMetadata> {
151    // struct MyHashBuilder;
152    // impl BuildHasher for MyHashBuilder {
153    //     type Hasher = DefaultHasher;
154    //     fn build_hasher(&self) -> Self::Hasher {
155    //         DefaultHasher::new()
156    //     }
157    // }
158    // static CACHE: Mutex<HashMap<String, &'static LazyLock<HookableFuncMetadata>, MyHashBuilder>> = Mutex::new(HashMap::with_hasher(MyHashBuilder{}));
159
160    for item in inventory::iter::<HookableFuncRegistry> {
161        if item.metadata.name == name {
162            return Some(item.metadata);
163        }
164    }
165    None
166}
167
/// Type-erased address of a hookable function.
///
/// Raw pointers are neither `Send` nor `Sync`, so this newtype opts back in to both.
struct HookableFuncPtr(*const ());

// SAFETY: the wrapped value is a function address (per the contract of
// `HookableFuncMetadata::new`). It is never written through, and the wrapper itself is
// immutable, so moving it to another thread cannot cause a data race.
unsafe impl Send for HookableFuncPtr {}
// SAFETY: shared access only ever copies the address out; see the `Send` impl.
unsafe impl Sync for HookableFuncPtr {}
171
172/// Metadata for a hookable function.
173#[doc(hidden)]
174pub struct HookableFuncMetadata {
175    name: String,
176    func: HookableFuncPtr,
177    type_info: (TypeId, TypeId),
178    fast_path_flag: &'static AtomicBool,
179    hooks: RwLock<Vec<(Arc<dyn HookDyn>, i32)>>,
180}
181impl HookableFuncMetadata {
182    /// Create a new [`HookableFuncMetadata`].
183    /// # Safety
184    /// This function is unsafe because it takes a raw pointer to a function without type checking.
185    /// It is used inside the macro [`hookable`] to create a new [`HookableFuncMetadata`] instance.
186    /// **THIS FUNCTION SHOULD NOT BE CALLED DIRECTLY.**
187    #[doc(hidden)]
188    pub unsafe fn new(
189        name: String,
190        func: *const (),
191        type_info: (TypeId, TypeId),
192        fast_path_flag: &'static AtomicBool,
193    ) -> Self {
194        Self {
195            name,
196            func: HookableFuncPtr(func),
197            type_info,
198            fast_path_flag,
199            hooks: RwLock::new(Vec::new()),
200        }
201    }
202
203    /// Get the name of the hookable function.
204    pub fn name(&self) -> &str {
205        &self.name
206    }
207
208    /// Get the pointer to the hookable function.
209    pub fn func_ptr(&self) -> *const () {
210        self.func.0
211    }
212
213    /// Add a hook to the hookable function.
214    /// The greatest priority will be called first.
215    pub fn add_hook_with_priority(
216        &self,
217        hook: Arc<dyn HookDyn>,
218        priority: i32,
219    ) -> Result<(), String> {
220        if hook.type_info() != self.type_info {
221            return Err(format!(
222                "Hook type mismatch: expected {:?}, got {:?}",
223                self.type_info,
224                hook.type_info()
225            ));
226        }
227        let mut hooks = self.hooks.write().unwrap();
228        let pos = hooks
229            .iter()
230            .position(|h| h.1 <= priority)
231            .unwrap_or(hooks.len());
232        hooks.insert(pos, (hook, priority));
233        self.fast_path_flag
234            .store(true, std::sync::atomic::Ordering::Release);
235        Ok(())
236    }
237
238    /// Add a hook to the hookable function with default (0) priority.
239    pub fn add_hook(&self, hook: Arc<dyn HookDyn>) -> Result<(), String> {
240        self.add_hook_with_priority(hook, 0)
241    }
242
243    /// Remove a hook from the hookable function.
244    pub fn remove_hook(&self, hook: &dyn HookDyn) -> bool {
245        let mut hooks = self.hooks.write().unwrap();
246        if let Some(pos) = hooks
247            .iter()
248            .position(|h| std::ptr::addr_eq(h.0.as_ref(), hook))
249        {
250            hooks.remove(pos);
251            if hooks.is_empty() {
252                self.fast_path_flag
253                    .store(false, std::sync::atomic::Ordering::Relaxed);
254            }
255            true
256        } else {
257            false
258        }
259    }
260
261    /// Clear all hooks from the hookable function.
262    pub fn clear_hooks(&self) {
263        let mut hooks = self.hooks.write().unwrap();
264        hooks.clear();
265        self.fast_path_flag
266            .store(false, std::sync::atomic::Ordering::Relaxed);
267    }
268}
269
270/// Call a hookable function with hooks.
271#[doc(hidden)]
272pub fn call_with_hook<R, A>(func: fn(A) -> R, meta: &'static HookableFuncMetadata, args: A) -> R {
273    let hooks = meta.hooks.read().unwrap();
274    let pos = Cell::new(0);
275    #[allow(clippy::type_complexity)]
276    let next_fn_ref: Cell<Option<&dyn Fn(A) -> R>> = Cell::new(None);
277    type HookFn<A, R> = fn(*const (), args: A, next: &dyn Fn(A) -> R) -> R;
278    let next_fn = |args: A| {
279        if pos.get() < hooks.len() {
280            let hook = hooks[pos.get()].0.as_ref();
281            // SAFETY: get_call_fn should return a function pointer to hook_call_wrapper<A>
282            let f: HookFn<A, R> = unsafe { std::mem::transmute(hook.get_call_fn()) };
283            pos.set(pos.get() + 1);
284            let res = f(
285                hook as *const dyn HookDyn as *const (),
286                args,
287                // SAFETY: next_fn_ref must be set before calling next_fn
288                unsafe { next_fn_ref.get().unwrap_unchecked() },
289            );
290            pos.set(pos.get() - 1);
291            res
292        } else {
293            func(args)
294        }
295    };
296    next_fn_ref.set(Some(&next_fn));
297    next_fn(args)
298}