skip_if_macros/
lib.rs

1extern crate proc_macro;
2use std::collections::HashSet;
3use std::hash::{Hash, Hasher};
4
5use darling::FromMeta;
6use proc_macro::*;
7use quote::{quote, ToTokens};
8
9#[derive(Debug, FromMeta)]
10struct Args {
11    strategy: syn::Expr,
12    output: syn::Expr,
13    #[darling(default)]
14    args_skip: String,
15}
16
17#[proc_macro_attribute]
18pub fn skip_if(args: TokenStream, input: TokenStream) -> TokenStream {
19    let args = darling::ast::NestedMeta::parse_meta_list(args.into()).unwrap();
20    let args = Args::from_list(&args).unwrap();
21    let strategy = &args.strategy;
22    let output = &args.output;
23    let mut input: syn::ItemFn = syn::parse(input).unwrap();
24
25    // Output type
26    let syn::ReturnType::Type(_, output_type) = &input.sig.output else {
27        panic!()
28    };
29    let syn::Type::Path(output_type) = output_type.as_ref() else {
30        panic!();
31    };
32
33    let stms = &input.block.stmts;
34
35    let mut hasher = std::collections::hash_map::DefaultHasher::new();
36    (*stms).hash(&mut hasher);
37    let code_hash = hasher.finish();
38
39    let mut args_hash = vec![quote! {
40        let mut hasher = std::collections::hash_map::DefaultHasher::new();
41    }];
42    let args_skip: HashSet<_> = args.args_skip.split(',').collect();
43    for input in &input.sig.inputs {
44        args_hash.push(match input {
45            syn::FnArg::Receiver(_) => {
46                quote! {self.hash(&mut hasher); }
47            }
48            syn::FnArg::Typed(tp) => {
49                let syn::Pat::Ident(id) = &*tp.pat else {
50                    panic!();
51                };
52                let pat = &tp.pat;
53                if args_skip.contains(id.ident.to_string().as_str()) {
54                    continue;
55                }
56                quote! { #pat.hash(&mut hasher); }
57            }
58        });
59    }
60
61    let res = if input.sig.asyncness.is_some() {
62        quote! { (|| async {#(#stms)*})().await }
63    } else {
64        quote! { (|| {#(#stms)*})() }
65    };
66
67    input.block = Box::new(
68        syn::parse2(quote! {{
69            use skip_if::Strategy;
70            use std::hash::{Hasher, Hash};
71
72            let _args_hash = {
73                #(#args_hash)*
74                hasher.finish()
75            };
76
77            let skip_if_output = &#output;
78            // Hack for type inference until `impl Trait` works for local variables
79            // (see https://github.com/rust-lang/rust/issues/63066)
80            fn get_strategy() -> impl Strategy<#output_type> {
81               #strategy
82            }
83            let _strategy = get_strategy();
84            // Call the strategy
85            if _strategy.skip(skip_if_output, _args_hash, #code_hash) {
86                tracing::warn!(?skip_if_output, "Skipping due to strategy");
87                return Ok(());
88            }
89            // Use a closure to avoid early returns
90            let res: #output_type = #res;
91            // Callback
92            if let Err(e) = _strategy.callback(&res, skip_if_output, _args_hash ,#code_hash) {
93                tracing::warn!(?e, "Strategy callback failed");
94            }
95            res
96        }})
97        .unwrap(),
98    );
99
100    input.into_token_stream().into()
101}