// safe_hook_macros/lib.rs

//! This crate provides a convenient macro, [macro@hookable], for marking functions so that they can be hooked.
//!
//! See the `safe-hook` crate for more details.

use proc_macro::TokenStream;
use quote::{ToTokens, format_ident, quote};
use syn::parse::Parse;
use syn::{ItemFn, LitStr, parse_macro_input};
11struct HookableProcArgs {
12    name: LitStr,
13    // args: Punctuated<MetaNameValue, Token![,]>,
14}
15
16impl Parse for HookableProcArgs {
17    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
18        let name = input.parse::<LitStr>()?;
19        // let args = Punctuated::<MetaNameValue, Token![,]>::parse_terminated(input)?;
20        Ok(HookableProcArgs {
21            name,
22            // args
23        })
24    }
25}
26
27fn gen_args_name_list(f: &ItemFn) -> proc_macro2::TokenStream {
28    // fn xxx(a:ta,b:tb,c:tc) -> td;  ==> a,b,c
29    let mut args = Vec::new();
30    for arg in f.sig.inputs.iter() {
31        if let syn::FnArg::Typed(pat_type) = arg {
32            if let syn::Pat::Ident(pat_ident) = &*pat_type.pat {
33                args.push(pat_ident.ident.clone());
34            } else {
35                panic!("Argument pattern is not supported");
36            }
37        }
38    }
39    quote! {
40        #(#args),*
41    }
42}
43
44fn get_hookable_lifetime(f: &ItemFn) -> Option<proc_macro2::TokenStream> {
45    if f.sig.generics.where_clause.is_some() {
46        panic!("Where clause is not supported");
47    }
48    match f.sig.generics.params.iter().count() {
49        0 => None,
50        1 => {
51            let g = f.sig.generics.params.iter().next().unwrap();
52            if let syn::GenericParam::Lifetime(lifetime) = g {
53                Some(quote! { #lifetime })
54            } else {
55                panic!(
56                    "Hookable cannot be used with generic '{}'",
57                    g.to_token_stream()
58                );
59            }
60        }
61        _ => panic!(
62            "Hookable cannot be used with more than one generics <{}>",
63            f.sig.generics.params.to_token_stream()
64        ),
65    }
66}
67/// This macro is used to mark a function as hookable, without changing the signature.
68/// It generates some extra codes to support hooks, and registers the function to the inventory.
69/// 
70/// Not Supported:
71/// - functions with generic types
72/// - functions with `self` receiver
73/// - functions returns references
74/// 
75/// # Examples:
76/// ```
77/// #[hookable("add")]
78/// fn add(left: i64, right: i64) -> i64 {
79///    left + right
80/// }
81/// ```
82#[proc_macro_attribute]
83pub fn hookable(args: TokenStream, input: TokenStream) -> TokenStream {
84    let args = parse_macro_input!(args as HookableProcArgs);
85    let input_fn = parse_macro_input!(input as ItemFn);
86
87    let input_fn_ident = input_fn.sig.ident.clone();
88
89    let _ = get_hookable_lifetime(&input_fn);
90    let generics = input_fn.sig.generics.clone();
91
92    let input_type = input_fn
93        .sig
94        .inputs
95        .iter()
96        .map(|arg| match arg {
97            syn::FnArg::Typed(pat_type) => {
98                if !matches!(&*pat_type.pat, syn::Pat::Ident(_)) {
99                    panic!("Argument pattern is not supported");
100                }
101                pat_type.ty.clone()
102            }
103            syn::FnArg::Receiver(_) => panic!("Method receiver (self) is not supported"),
104        })
105        .collect::<Vec<_>>();
106    let input_type_with_static_lifetime = input_type
107        .iter()
108        .map(|ty| {
109            if let syn::Type::Reference(ref_ty) = &**ty {
110                let mut ref_ty = ref_ty.clone();
111                ref_ty.lifetime = Some(syn::Lifetime::new("'static", proc_macro2::Span::call_site()));
112                quote! { #ref_ty }
113            } else {
114                quote! { #ty }
115            }
116        })
117        .collect::<Vec<_>>();
118
119    let ret_type = match &input_fn.sig.output {
120        syn::ReturnType::Default => quote! { () },
121        syn::ReturnType::Type(_, ty) => quote! { #ty },
122    };
123
124    let func_type = quote! {
125        fn(#(#input_type),*) -> #ret_type
126    };
127
128    let hookable_name = args.name;
129
130    let args_name_list = gen_args_name_list(&input_fn);
131
132    let mut inner_fn = input_fn.clone();
133    inner_fn.sig.ident = format_ident!("__hookable_inner");
134    let fn_sig = &input_fn.sig;
135
136    let unpack_list: proc_macro2::TokenStream = (0..input_fn.sig.inputs.len())
137        .map(|i| {
138            let idx = syn::Index::from(i);
139            quote! { args.#idx, }
140        })
141        .collect();
142
143    // 原样返回函数代码
144    let generated = quote! {
145        #fn_sig {
146            #inner_fn
147
148            use ::safe_hook::HookableFuncMetadata;
149            use ::core::sync::atomic::AtomicBool;
150            use ::std::sync::LazyLock;
151            use ::std::sync::atomic::Ordering;
152
153            type SelfFunc #generics = #func_type;
154
155            static FLAG: AtomicBool = AtomicBool::new(false);
156            static META: LazyLock<HookableFuncMetadata> = LazyLock::new(|| {
157                let metadata = unsafe {
158                    HookableFuncMetadata::new(
159                        #hookable_name.to_string(),
160                        #input_fn_ident as *const (),
161                        (
162                            std::any::TypeId::of::<#ret_type>(),
163                            std::any::TypeId::of::<(#(#input_type_with_static_lifetime),*)>(),
164                        ),
165                        &FLAG,
166                    )
167                };
168                metadata
169            });
170            ::safe_hook::inventory::submit! {
171                ::safe_hook::HookableFuncRegistry::new(&META)
172            }
173            if !FLAG.load(Ordering::Acquire) {
174                return __hookable_inner(#args_name_list);
175            }
176            ::safe_hook::call_with_hook::<#ret_type, (#(#input_type),*)>(|args| __hookable_inner(#unpack_list), &META, (#args_name_list))
177        }
178    };
179    generated.into()
180}