1use proc_macro::TokenStream;
2use quote::quote;
3use syn::punctuated::Punctuated;
4use syn::{parse::Parse, parse_macro_input, Expr, ItemFn, Lit};
5use syn::{ExprLit, Token};
6
/// Arguments accepted by the `#[benchmark(...)]` attribute.
#[derive(Debug, Default)]
struct MacroParamInput {
    /// Requested number of runs of the annotated function's body;
    /// `None` when no integer count was parsed from the attribute.
    pub times: Option<i64>,
}
11
12impl Parse for MacroParamInput {
13 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
14 let mut output_value = MacroParamInput::default();
15 let values: Punctuated<Expr, Token![,]> = Punctuated::parse_terminated(input)?;
16
17 let mut iter = values.iter();
18
19 if let Some(expr) = iter.next() {
20 if let Expr::Lit(ExprLit {
21 lit: Lit::Int(lit_int),
22 ..
23 }) = expr
24 {
25 output_value.times = Some(lit_int.base10_parse::<i64>().unwrap());
26 }
27 }
28
29 Ok(output_value)
30 }
31}
32
33#[proc_macro_attribute]
34pub fn benchmark(attr: TokenStream, item: TokenStream) -> TokenStream {
35 let input_fn = parse_macro_input!(item as ItemFn);
36 let fn_name = &input_fn.sig.ident;
37 let fn_body = &input_fn.block;
38
39 let parsed_attr = parse_macro_input!(attr as MacroParamInput);
40 let times = parsed_attr.times.unwrap_or(1);
41
42 let expanded = if times > 1 {
43 quote! {
44 fn #fn_name() {
45 use std::time::Instant;
46 let mut total_time = std::time::Duration::new(0, 0);
47
48 for _ in 0..#times {
49 let start = Instant::now();
50 #fn_body
51 let duration = start.elapsed();
52 total_time += duration;
53 println!("Iteration took: {:?}", duration);
54 }
55
56 let avg_time = total_time.as_nanos() / (#times as u128);
57 println!("Function '{}' executed {} times. Avg time: {:?} ns",stringify!(#fn_name), #times, avg_time);
58 }
59 }
60 } else {
61 quote! {
62 fn #fn_name() {
63 use std::time::Instant;
64 let start = Instant::now();
65 #fn_body
66 let duration = start.elapsed().as_nanos();
67 println!("Function '{}' executed in {:?} ns", stringify!(#fn_name), duration);
68 }
69 }
70 };
71
72 expanded.into()
73}