spec_fn/
lib.rs

1mod case_block;
2use std::fmt::Display;
3
4use case_block::CaseBlock;
5
6use proc_macro::TokenStream;
7use proc_macro2::{Span, TokenStream as TokenStream2};
8use quote::ToTokens;
9use syn::{
10    parse_macro_input, parse_quote_spanned, punctuated::Punctuated, spanned::Spanned, token::Comma,
11    Arm, Block, Expr, ExprMatch, ExprPath, ExprTuple, FnArg, ItemFn, Pat, PatIdent, PatTuple, Path,
12    Signature,
13};
14
15/// Create spanned compile error.
16fn error<T: Display>(span: &Span, e: T) -> TokenStream {
17    syn::Error::new(*span, e).to_compile_error().into()
18}
19
20/// Parses an argument to an Expresion.
21fn parse_arg(pat: &Pat) -> Option<Expr> {
22    match pat {
23        // Parse ident as expr.
24        Pat::Ident(PatIdent { ident, .. }) => Some(Expr::from(ExprPath {
25            attrs: Vec::new(),
26            qself: None,
27            path: Path::from(ident.clone()),
28        })),
29        // Skip wildcards.
30        Pat::Wild(_) => None,
31        // Recursively parse tuples to tuple expressions.
32        Pat::Tuple(PatTuple { elems, .. }) => Some(Expr::from(ExprTuple {
33            attrs: Vec::new(),
34            paren_token: Default::default(),
35            elems: elems.iter().filter_map(parse_arg).collect(),
36        })),
37        _ => unreachable!("Invalid function parameter..."),
38    }
39}
40
41/// Parse the argument names of the specialized function and make them into a comma separated list.
42fn param_to_tuple(sig: &Signature) -> Punctuated<Expr, Comma> {
43    sig.inputs
44        .iter()
45        .filter_map(|arg| match arg {
46            FnArg::Typed(arg) => parse_arg(&*arg.pat),
47            _ => None,
48        })
49        .collect()
50}
51
52#[proc_macro]
53pub fn spec(item: TokenStream) -> TokenStream {
54    // Brace input so it can be parsed as block.
55    let item = TokenStream2::from(item);
56    let input: Block = parse_quote_spanned!( item.span() => { #item } );
57    let mut stmts = input.stmts.iter();
58
59    // Parse function specification.
60    let mut spec_func = if let Some(func) = stmts.next() {
61        let func: TokenStream = func.into_token_stream().into();
62        parse_macro_input!(func as ItemFn)
63    } else {
64        return error(&input.span(), "spec!{{}} block must not be empty");
65    };
66
67    // Parse parameters to expression for the match statement.
68    let spec_func_parameters = param_to_tuple(&spec_func.sig);
69
70    // Create the match statement.
71    let mut expr_match: ExprMatch = parse_quote_spanned!( spec_func.span() =>
72        #[allow(unused_parens)]
73        match (#spec_func_parameters) {}
74    );
75
76    // Parse the pattern and function body of each case and assemble the match statement.
77    for stmt in stmts {
78        let stmt: TokenStream = stmt.to_token_stream().into();
79        let CaseBlock { case, when, block } = parse_macro_input!(stmt as CaseBlock);
80
81        let arm: Arm = match when {
82            Some(when) => parse_quote_spanned!( spec_func.span() => (#case) if #when => #block ),
83            None => parse_quote_spanned!( spec_func.span() => (#case) => #block ),
84        };
85
86        expr_match.arms.push(arm);
87    }
88
89    // Set the specialized function body to the match statement.
90    let body: Block = parse_quote_spanned!( spec_func.span() => { #expr_match } );
91    spec_func.block = Box::new(body);
92
93    spec_func.into_token_stream().into()
94}