Skip to main content

vrs_core_macros/
lib.rs

1use proc_macro::TokenStream;
2use quote::{format_ident, quote};
3use syn::{
4    parse_macro_input, parse_quote, visit_mut::VisitMut, Attribute, FnArg, Ident, ItemFn, ItemMod,
5    PatType, ReturnType, Type,
6};
7
8#[derive(Clone)]
9struct ApiEntry {
10    name: String,
11    method: String,
12    param_types: Vec<Box<Type>>,
13    return_type: Box<Type>,
14}
15
16struct ApiVisitor {
17    entries: Vec<ApiEntry>,
18}
19
20fn find_entry(attrs: &[Attribute]) -> Option<&str> {
21    for attr in attrs {
22        if attr.path().is_ident("get") {
23            return Some("get");
24        } else if attr.path().is_ident("post") {
25            return Some("post");
26        }
27    }
28    None
29}
30
31impl VisitMut for ApiVisitor {
32    fn visit_item_fn_mut(&mut self, item: &mut syn::ItemFn) {
33        if let Some(method) = find_entry(&item.attrs) {
34            let name = item.sig.ident.to_string();
35            let param_types: Vec<_> = item
36                .sig
37                .inputs
38                .iter()
39                .filter_map(|arg| match arg {
40                    FnArg::Typed(PatType { ty, .. }) => Some((*ty).clone()),
41                    _ => None,
42                })
43                .collect();
44            let return_type = match &item.sig.output {
45                ReturnType::Default => parse_quote!(()),
46                ReturnType::Type(_, ty) => (*ty).clone(),
47            };
48            self.entries.push(ApiEntry {
49                name,
50                method: method.to_string(),
51                param_types,
52                return_type,
53            });
54        }
55        syn::visit_mut::visit_item_fn_mut(self, item);
56    }
57}
58
59#[proc_macro_attribute]
60pub fn nucleus(_attr: TokenStream, item: TokenStream) -> TokenStream {
61    let mut input_mod = parse_macro_input!(item as ItemMod);
62    let mut visitor = ApiVisitor {
63        entries: Vec::new(),
64    };
65    visitor.visit_item_mod_mut(&mut input_mod);
66    let entries: Vec<_> = visitor
67        .entries
68        .iter()
69        .map(|entry| {
70            let name = &entry.name;
71            let method = &entry.method;
72            let param_types = entry
73                .param_types
74                .iter()
75                .map(|ty| ty.clone())
76                .collect::<Vec<_>>();
77            let return_type = &entry.return_type;
78            quote! {
79                registry.register_api(
80                    #name.to_string(),
81                    #method.to_string(),
82                    vec![#(::vrs_core_sdk::scale_info::meta_type::<#param_types>(),)*],
83                    ::vrs_core_sdk::scale_info::meta_type::<#return_type>(),
84                );
85            }
86        })
87        .collect::<Vec<_>>();
88    if let Some((_, ref mut items)) = input_mod.content {
89        items.push(parse_quote! {
90            vrs_core_sdk::lazy_static::lazy_static! {
91                static ref TYPES: ::vrs_core_sdk::abi::ApiRegistry = {
92                    let mut registry = ::vrs_core_sdk::abi::ApiRegistry::new();
93                    #(#entries)*
94                    registry
95                };
96            }
97        });
98        items.push(parse_quote! {
99            #[no_mangle]
100            pub fn __nucleus_abi() -> *const u8 {
101                let abi = TYPES.dump_abi();
102                let encoded = <::vrs_core_sdk::abi::JsonAbi as ::vrs_core_sdk::codec::Encode>::encode(&abi);
103                let dummy_encoded = Some(encoded);
104                let encoded = <Option<Vec<u8>> as ::vrs_core_sdk::codec::Encode>::encode(&dummy_encoded);
105                let len = encoded.len() as u32;
106                let mut output = Vec::with_capacity(4 + len as usize);
107                output.extend_from_slice(&len.to_ne_bytes());
108                output.extend_from_slice(&encoded);
109                let ptr = output.as_ptr();
110                std::mem::forget(output);
111                ptr
112            }
113        });
114    }
115    quote! {
116        #input_mod
117    }
118    .into()
119}
120
121#[proc_macro_attribute]
122pub fn post(_attr: TokenStream, item: TokenStream) -> TokenStream {
123    let func = parse_macro_input!(item as ItemFn);
124    let func_name = format_ident!("__nucleus_{}_{}", "post", &func.sig.ident);
125    expand(func, func_name)
126}
127
128#[proc_macro_attribute]
129pub fn get(_attr: TokenStream, item: TokenStream) -> TokenStream {
130    let func = parse_macro_input!(item as ItemFn);
131    let func_name = format_ident!("__nucleus_{}_{}", "get", &func.sig.ident);
132    expand(func, func_name)
133}
134
135#[proc_macro_attribute]
136pub fn init(_attr: TokenStream, item: TokenStream) -> TokenStream {
137    let func = parse_macro_input!(item as ItemFn);
138    let func_name = format_ident!("__nucleus_init");
139    expand(func, func_name)
140}
141
142#[proc_macro_attribute]
143pub fn timer(_attr: TokenStream, item: TokenStream) -> TokenStream {
144    let func = parse_macro_input!(item as ItemFn);
145    let func_name = format_ident!("__nucleus_{}_{}", "timer", &func.sig.ident);
146    expand(func, func_name)
147}
148
149#[proc_macro_attribute]
150pub fn callback(_attr: TokenStream, item: TokenStream) -> TokenStream {
151    let func = parse_macro_input!(item as ItemFn);
152    let func_name = format_ident!("__nucleus_http_callback");
153    expand(func, func_name)
154}
155
156fn expand(func: ItemFn, entry_name: Ident) -> TokenStream {
157    let func_block = &func.block;
158    let func_decl = &func.sig;
159    let origin_name = &func_decl.ident;
160    let func_generics = &func_decl.generics;
161    let func_inputs = &func_decl.inputs;
162    let func_output = &func_decl.output;
163    if !func_generics.params.is_empty() {
164        panic!("function should not have generics");
165    }
166    let tys: Vec<_> = func_inputs
167        .iter()
168        .map(|i| match i {
169            FnArg::Typed(ref val) => val.ty.clone(),
170            _ => unreachable!(),
171        })
172        .collect();
173    let arg_names: Vec<_> = func_inputs
174        .iter()
175        .map(|i| match i {
176            FnArg::Typed(ref val) => val.pat.clone(),
177            _ => unreachable!(),
178        })
179        .collect();
180    let out_ty = match func_output {
181        ReturnType::Default => quote! { () },
182        ReturnType::Type(_, ty) => quote! { #ty },
183    };
184    let expanded = quote! {
185        // declare the wrapper function: `fn __nucleus_XX(__ptr: *const u8, __len: usize)`
186        #[no_mangle]
187        pub fn #entry_name(__ptr: *const u8, __len: usize) -> *const u8 {
188            // rewrite the original function `fn(x: X, y: Y)` to `fn((x, y, z): (X, Y, Z))`
189            fn #origin_name((#(#arg_names,)*): (#(#tys,)*)) #func_output #func_block
190            // the VM has passed the raw parameters, now decode them within VM
191            let mut v = unsafe { std::slice::from_raw_parts(__ptr, __len) };
192            let decoding_result = <(#(#tys,)*) as ::vrs_core_sdk::codec::Decode>::decode(&mut v);
193            let result: Option<Vec<u8>> = match decoding_result {
194                Ok(decoded) => {
195                    let ret = #origin_name(decoded);
196                    let encoded = <#out_ty as ::vrs_core_sdk::codec::Encode>::encode(&ret);
197                    Option::<Vec<u8>>::Some(encoded)
198                }
199                Err(_) => None::<Vec<u8>>,
200            };
201            let encoded = <Option<Vec<u8>> as ::vrs_core_sdk::codec::Encode>::encode(&result);
202            let len = encoded.len() as u32;
203            let mut output = Vec::with_capacity(4 + len as usize);
204            output.extend_from_slice(&len.to_ne_bytes());
205            output.extend_from_slice(&encoded);
206            let ptr = output.as_ptr();
207            std::mem::forget(output);
208            ptr
209        }
210    };
211    expanded.into()
212}