rustic_macros/
lib.rs

1#![doc = include_str!("../README.md")]
2
3extern crate proc_macro;
4use proc_macro::TokenStream;
5use quote::quote;
6//use syn::punctuated::Punctuated;
7use syn::{parse_macro_input, AttributeArgs, ItemFn, NestedMeta};
8
9/// Macro used to add guards to functions. If the guard function returns `Ok` then execution would continue.
10/// If the guard function returns `Err` the execution would stop and state changes would be reverted by invoking `ic_cdk::api::call::reject`.
11///
12/// If multiple guard functions exist they're executed in the order declared.
13///
14/// Usually the `modifiers` macro should be the last macro declared for a function, below everything else such as `update` and `query`.
15/// Those are the only macros tested to work together with this `modifiers` macro.
16/// If you would like to use it with other macros or use a different order, remember Rust macro ordering rules apply, and always check the expanded result.
17///
18/// The signatures of guard functions must be of
19/// ```
20/// fn guard_func<T>(param: T) -> Result<(), String>
21/// # { Ok(()) }
22/// ```
23/// Then the guard functions can be added as modifiers using the follow syntax:
24/// ```ignore
25/// # use rustic_macros::modifiers;
26/// #[modifiers("guard_func@param", ...)]
27/// fn my_func() {}
28/// ```
29/// The `param` can be a commma-separated list(without spaces in between) of the arguments.
30/// # Examples
31/// ```
32/// # fn guard_func0() -> Result<(), String>
33/// # { Ok(()) }
34/// # fn guard_func1() -> Result<(), String>
35/// # { Ok(()) }
36/// # fn guard_func(a: u8, b: SomeEnum) -> Result<(), String>
37/// # { Ok(()) }
38/// # enum SomeEnum {A,B}
39/// # use rustic_macros::modifiers;
40/// #[modifiers("guard_func0")]
41/// fn my_func0() {}
42/// #[modifiers("guard_func@42,SomeEnum::A")]
43/// fn my_func1() {}
44/// #[modifiers("guard_func0", "guard_func1")]
45/// fn my_func2() {}
46/// ```
47#[proc_macro_attribute]
48pub fn modifiers(args: TokenStream, input: TokenStream) -> TokenStream {
49    let args = parse_macro_input!(args as AttributeArgs);
50    let func = parse_macro_input!(input as ItemFn);
51
52    // Capture asyncness and visibility
53    let asyncness = &func.sig.asyncness;
54    let vis = &func.vis;
55    let ret_type = &func.sig.output;
56    let generics = &func.sig.generics;
57    let where_clause = &func.sig.generics.where_clause;
58    let attrs = &func.attrs;
59    let _unsafety = &func.sig.unsafety;
60    let _abi = &func.sig.abi;
61
62    // Transform attribute arguments to strings
63    let mut modifiers: Vec<(String, Vec<String>)> = Vec::new();
64    for arg in args {
65        if let NestedMeta::Lit(syn::Lit::Str(lit)) = arg {
66            let val = lit.value();
67            let parts: Vec<_> = val.split('@').collect();
68            let func_name = parts[0].to_string();
69            let params = if parts.len() > 1 {
70                parts[1..]
71                    .join("@")
72                    .split(',')
73                    .map(|s| s.trim().to_string())
74                    .collect()
75            } else {
76                Vec::new()
77            };
78            modifiers.push((func_name, params));
79        }
80    }
81
82    // The function name and parameters
83    let fn_name = &func.sig.ident;
84    let fn_params = &func.sig.inputs;
85
86    // Extract statements from the original block
87    let fn_stmts = &func.block.stmts;
88
89    // Generate each modifier check and function call
90    let modifier_checks: Vec<proc_macro2::TokenStream> = modifiers
91        .iter()
92        .map(|(modi, params)| {
93            let modi_ident: syn::Ident = syn::Ident::new(modi, proc_macro2::Span::call_site());
94            let params_tokens: Vec<proc_macro2::TokenStream> = params
95                .iter()
96                .map(|p| {
97                    let param_token: proc_macro2::TokenStream = p.parse().unwrap();
98                    param_token
99                })
100                .collect();
101            if modi == "non_reentrant" {
102                quote! {
103                    let __guard = ::rustic::reentrancy_guard::ReentrancyGuard::new();
104                }
105            } else {
106                quote! {
107                    let r: Result<(), String> = #modi_ident(#(#params_tokens),*);
108                    if let Err(e) = r {
109                        ic_cdk::api::call::reject(&e);
110                        panic!("{} failed: {}", stringify!(#modi_ident), e);
111                    }
112                }
113            }
114        })
115        .collect();
116
117    let expanded = quote! {
118        #(#attrs)*
119        #vis #asyncness fn #fn_name #generics (#fn_params) #ret_type #where_clause {
120            #(#modifier_checks)*
121            #(#fn_stmts)*
122        }
123    };
124
125    TokenStream::from(expanded)
126}