rustbench/
lib.rs

1use proc_macro::TokenStream;
2use quote::quote;
3use syn::punctuated::Punctuated;
4use syn::{parse::Parse, parse_macro_input, Expr, ItemFn, Lit};
5use syn::{ExprLit, Token};
6
/// Arguments parsed from the `#[benchmark(...)]` attribute.
#[derive(Debug, Default)]
struct MacroParamInput {
    /// Requested iteration count; `None` when no integer count was parsed from the attribute.
    pub times: Option<i64>,
}
11
12impl Parse for MacroParamInput {
13    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
14        let mut output_value = MacroParamInput::default();
15        let values: Punctuated<Expr, Token![,]> = Punctuated::parse_terminated(input)?;
16
17        let mut iter = values.iter();
18
19        if let Some(expr) = iter.next() {
20            if let Expr::Lit(ExprLit {
21                lit: Lit::Int(lit_int),
22                ..
23            }) = expr
24            {
25                output_value.times = Some(lit_int.base10_parse::<i64>().unwrap());
26            }
27        }
28
29        Ok(output_value)
30    }
31}
32
33#[proc_macro_attribute]
34pub fn benchmark(attr: TokenStream, item: TokenStream) -> TokenStream {
35    let input_fn = parse_macro_input!(item as ItemFn);
36    let fn_name = &input_fn.sig.ident;
37    let fn_body = &input_fn.block;
38
39    let parsed_attr = parse_macro_input!(attr as MacroParamInput);
40    let times = parsed_attr.times.unwrap_or(1);
41
42    let expanded = if times > 1 {
43        quote! {
44            fn #fn_name() {
45                use std::time::Instant;
46                let mut total_time = std::time::Duration::new(0, 0);
47
48                for _ in 0..#times {
49                    let start = Instant::now();
50                    #fn_body
51                    let duration = start.elapsed();
52                    total_time += duration;
53                    println!("Iteration took: {:?}", duration);
54                }
55
56                let avg_time = total_time.as_nanos() / (#times as u128);
57                println!("Function '{}' executed {} times. Avg time: {:?} ns",stringify!(#fn_name), #times, avg_time);
58            }
59        }
60    } else {
61        quote! {
62            fn #fn_name() {
63                use std::time::Instant;
64                let start = Instant::now();
65                #fn_body
66                let duration = start.elapsed().as_nanos();
67                println!("Function '{}' executed in {:?} ns", stringify!(#fn_name), duration);
68            }
69        }
70    };
71
72    expanded.into()
73}