rustic_macros/lib.rs
1#![doc = include_str!("../README.md")]
2
3extern crate proc_macro;
4use proc_macro::TokenStream;
5use quote::quote;
6//use syn::punctuated::Punctuated;
7use syn::{parse_macro_input, AttributeArgs, ItemFn, NestedMeta};
8
9/// Macro used to add guards to functions. If the guard function returns `Ok` then execution would continue.
10/// If the guard function returns `Err` the execution would stop and state changes would be reverted by invoking `ic_cdk::api::call::reject`.
11///
12/// If multiple guard functions exist they're executed in the order declared.
13///
14/// Usually the `modifiers` macro should be the last macro declared for a function, below everything else such as `update` and `query`.
15/// Those are the only macros tested to work together with this `modifiers` macro.
16/// If you would like to use it with other macros or use a different order, remember Rust macro ordering rules apply, and always check the expanded result.
17///
18/// The signatures of guard functions must be of
19/// ```
20/// fn guard_func<T>(param: T) -> Result<(), String>
21/// # { Ok(()) }
22/// ```
23/// Then the guard functions can be added as modifiers using the follow syntax:
24/// ```ignore
25/// # use rustic_macros::modifiers;
26/// #[modifiers("guard_func@param", ...)]
27/// fn my_func() {}
28/// ```
29/// The `param` can be a commma-separated list(without spaces in between) of the arguments.
30/// # Examples
31/// ```
32/// # fn guard_func0() -> Result<(), String>
33/// # { Ok(()) }
34/// # fn guard_func1() -> Result<(), String>
35/// # { Ok(()) }
36/// # fn guard_func(a: u8, b: SomeEnum) -> Result<(), String>
37/// # { Ok(()) }
38/// # enum SomeEnum {A,B}
39/// # use rustic_macros::modifiers;
40/// #[modifiers("guard_func0")]
41/// fn my_func0() {}
42/// #[modifiers("guard_func@42,SomeEnum::A")]
43/// fn my_func1() {}
44/// #[modifiers("guard_func0", "guard_func1")]
45/// fn my_func2() {}
46/// ```
47#[proc_macro_attribute]
48pub fn modifiers(args: TokenStream, input: TokenStream) -> TokenStream {
49 let args = parse_macro_input!(args as AttributeArgs);
50 let func = parse_macro_input!(input as ItemFn);
51
52 // Capture asyncness and visibility
53 let asyncness = &func.sig.asyncness;
54 let vis = &func.vis;
55 let ret_type = &func.sig.output;
56 let generics = &func.sig.generics;
57 let where_clause = &func.sig.generics.where_clause;
58 let attrs = &func.attrs;
59 let _unsafety = &func.sig.unsafety;
60 let _abi = &func.sig.abi;
61
62 // Transform attribute arguments to strings
63 let mut modifiers: Vec<(String, Vec<String>)> = Vec::new();
64 for arg in args {
65 if let NestedMeta::Lit(syn::Lit::Str(lit)) = arg {
66 let val = lit.value();
67 let parts: Vec<_> = val.split('@').collect();
68 let func_name = parts[0].to_string();
69 let params = if parts.len() > 1 {
70 parts[1..]
71 .join("@")
72 .split(',')
73 .map(|s| s.trim().to_string())
74 .collect()
75 } else {
76 Vec::new()
77 };
78 modifiers.push((func_name, params));
79 }
80 }
81
82 // The function name and parameters
83 let fn_name = &func.sig.ident;
84 let fn_params = &func.sig.inputs;
85
86 // Extract statements from the original block
87 let fn_stmts = &func.block.stmts;
88
89 // Generate each modifier check and function call
90 let modifier_checks: Vec<proc_macro2::TokenStream> = modifiers
91 .iter()
92 .map(|(modi, params)| {
93 let modi_ident: syn::Ident = syn::Ident::new(modi, proc_macro2::Span::call_site());
94 let params_tokens: Vec<proc_macro2::TokenStream> = params
95 .iter()
96 .map(|p| {
97 let param_token: proc_macro2::TokenStream = p.parse().unwrap();
98 param_token
99 })
100 .collect();
101 if modi == "non_reentrant" {
102 quote! {
103 let __guard = ::rustic::reentrancy_guard::ReentrancyGuard::new();
104 }
105 } else {
106 quote! {
107 let r: Result<(), String> = #modi_ident(#(#params_tokens),*);
108 if let Err(e) = r {
109 ic_cdk::api::call::reject(&e);
110 panic!("{} failed: {}", stringify!(#modi_ident), e);
111 }
112 }
113 }
114 })
115 .collect();
116
117 let expanded = quote! {
118 #(#attrs)*
119 #vis #asyncness fn #fn_name #generics (#fn_params) #ret_type #where_clause {
120 #(#modifier_checks)*
121 #(#fn_stmts)*
122 }
123 };
124
125 TokenStream::from(expanded)
126}