1use proc_macro::TokenStream;
2use proc_macro2::Span;
3use quote::quote;
4use sha2::{Digest, Sha256};
5use syn::{parse_macro_input, FnArg, Ident, ItemFn, Pat, ReturnType, Type};
6
7fn hash_token_stream(tokens: &proc_macro2::TokenStream) -> [u8; 32] {
65 let token_string = tokens.to_string();
67
68 let mut hasher = Sha256::new();
70
71 hasher.update(token_string.as_bytes());
73
74 hasher.finalize().into()
76}
77
78fn check_for_mutable_refs(
79 fn_inputs: &syn::punctuated::Punctuated<FnArg, syn::token::Comma>,
80) -> Result<(), syn::Error> {
81 for arg in fn_inputs {
82 let FnArg::Typed(pat_type) = arg else {
83 continue;
84 };
85
86 let Type::Reference(type_ref) = &*pat_type.ty else {
87 continue;
88 };
89
90 let Some(mutability) = &type_ref.mutability else {
91 continue;
92 };
93
94 return Err(syn::Error::new_spanned(
95 mutability,
96 "cached functions must be pure - mutable references are not allowed",
97 ));
98 }
99 Ok(())
100}
101
102fn get_param_type(ty: &Type) -> &Type {
103 if let Type::Reference(type_ref) = ty {
104 &type_ref.elem
105 } else {
106 ty
107 }
108}
109
110#[proc_macro_attribute]
111pub fn cached(_attr: TokenStream, item: TokenStream) -> TokenStream {
112 let input_fn = parse_macro_input!(item as ItemFn);
113
114 if let Err(err) = check_for_mutable_refs(&input_fn.sig.inputs) {
116 let compiler_err = err.to_compile_error();
117
118 return quote! {
119 #input_fn
120
121 #compiler_err
122 }
123 .into();
124 }
125
126 let mut input_fn = input_fn;
127
128 let mut fn_with_name_inner = input_fn.clone();
129 fn_with_name_inner.sig.ident = Ident::new("inner", Span::call_site());
130
131 let fn_with_name_inner_tokens = quote! {
132 #fn_with_name_inner
133 };
134
135 let inner_fn_hash = hash_token_stream(&fn_with_name_inner_tokens);
136
137 let inner_fn_hash_literal = quote! {
139 [
140 #(#inner_fn_hash,)*
141 ]
142 };
143
144 let fn_inputs = &input_fn.sig.inputs;
145 let fn_output = match &input_fn.sig.output {
146 ReturnType::Default => quote!(()),
147 ReturnType::Type(_, ty) => quote!(#ty),
148 };
149
150 let param_names: Vec<_> = fn_inputs
151 .iter()
152 .filter_map(|arg| match arg {
153 FnArg::Typed(pat_type) => {
154 if let Pat::Ident(pat_ident) = &*pat_type.pat {
155 Some(&pat_ident.ident)
156 } else {
157 None
158 }
159 }
160 _ => None,
161 })
162 .collect();
163
164 let param_types: Vec<_> = fn_inputs
165 .iter()
166 .filter_map(|arg| match arg {
167 FnArg::Typed(pat_type) => Some(get_param_type(&pat_type.ty)),
168 _ => None,
169 })
170 .collect();
171
172 let new_block = quote! {{
173 #fn_with_name_inner
174
175 use rkyv::{with::InlineAsBox, Archive, Deserialize, Serialize};
176
177 #[derive(Archive, Serialize, Deserialize, Debug)]
178 struct CacheKey<'a> {
179 #(
180 #[rkyv(with = InlineAsBox)]
181 #param_names: &'a #param_types,
182 )*
183 _function_hash: [u8; 32],
184 }
185
186 let key = CacheKey {
187 #(#param_names: &#param_names,)*
188 _function_hash: #inner_fn_hash_literal,
189 };
190 println!("{key:?}");
191 let key_bytes = rkyv::to_bytes::<rkyv::rancor::Error>(&key).unwrap();
192
193 if let Some(cached_result) = smart_cache::get_cached(&*key_bytes) {
194 let cached_result = &*cached_result;
195 let cached_result: &rkyv::Archived<#fn_output> = rkyv::access::<_, rkyv::rancor::Error>(cached_result).unwrap();
196 let cached_result: #fn_output = rkyv::deserialize::<#fn_output, rkyv::rancor::Error>(cached_result).unwrap();
197 return cached_result;
198 }
199
200 let result = inner(#(#param_names,)*);
201
202 let value_bytes = rkyv::to_bytes::<rkyv::rancor::Error>(&result).unwrap();
203 let _ = smart_cache::set_cached(&key_bytes, &value_bytes);
204
205 result
206 }};
207
208 input_fn.block = syn::parse2(new_block).unwrap();
209
210 TokenStream::from(quote! {
211 #input_fn
212 })
213}