// zrl_proc_macros/lib.rs

use proc_macro::TokenStream;
use quote::{format_ident, quote};
use syn::ext::IdentExt;
use syn::{
    parse_macro_input, AttrStyle, Error, Expr, ExprLit, FnArg, Ident, ItemFn, Lit, Meta, Pat,
    ReturnType, Signature, Type,
};

5#[proc_macro_attribute]
6pub fn lru(attr: TokenStream, item: TokenStream) -> TokenStream {
7    let input = parse_macro_input!(item as ItemFn);
8
9    let capacity = if attr.is_empty() {
10        100
11    } else {
12        let meta = parse_macro_input!(attr as Meta);
13        match meta {
14            Meta::NameValue(nv) if nv.path.is_ident("capacity") => {
15                if let syn::Expr::Lit(expr_lit) = &nv.value {
16                    if let Lit::Int(lit_int) = &expr_lit.lit {
17                        lit_int.base10_parse::<usize>().unwrap_or(100)
18                    } else {
19                        100
20                    }
21                } else {
22                    100
23                }
24            }
25            _ => 100,
26        }
27    };
28
29    let fn_name = &input.sig.ident;
30    let fn_vis = &input.vis;
31    let fn_attrs = &input.attrs;
32    let fn_inputs = &input.sig.inputs;
33    let fn_output = &input.sig.output;
34    let fn_block = &input.block;
35
36    let mut param_names = vec![];
37    let mut param_types = vec![];
38
39    for input in fn_inputs {
40        if let syn::FnArg::Typed(pat_type) = input {
41            if let syn::Pat::Ident(pat_ident) = &*pat_type.pat {
42                param_names.push(&pat_ident.ident);
43                param_types.push(&pat_type.ty);
44            }
45        }
46    }
47
48    let cache_key_type = if param_types.len() == 1 {
49        quote! { #(#param_types)* }
50    } else {
51        quote! { (#(#param_types),*) }
52    };
53
54    let cache_key_expr = if param_names.len() == 1 {
55        quote! { #(#param_names)*.clone() }
56    } else {
57        quote! { (#(#param_names.clone()),*) }
58    };
59
60    let return_type = match fn_output {
61        syn::ReturnType::Default => quote! { () },
62        syn::ReturnType::Type(_, ty) => quote! { #ty },
63    };
64
65    let cache_name = quote::format_ident!("{}_CACHE", fn_name.to_string().to_uppercase());
66
67    let expanded = quote! {
68        thread_local! {
69            static #cache_name: std::cell::RefCell<LRUCache<#cache_key_type, #return_type>> =
70                std::cell::RefCell::new(LRUCache::new(#capacity));
71        }
72
73        #(#fn_attrs)*
74        #fn_vis fn #fn_name(#fn_inputs) #fn_output {
75            let key = #cache_key_expr;
76
77            let cached = #cache_name.with(|cache| {
78                cache.borrow_mut().get(&key)
79            });
80
81            if let Some(result) = cached {
82                return result;
83            }
84
85            let result: #return_type = {
86                #fn_block
87            };
88
89            #cache_name.with(|cache| {
90                cache.borrow_mut().put(key, result.clone());
91            });
92
93            result
94        }
95    };
96
97    TokenStream::from(expanded)
98}