sigma_types_macros/
lib.rs

1//! Macros that expand the power and usefulness of the `sigma-types` crate.
2
3#![expect(
4    clippy::panic,
5    reason = "TODO: how to return an error from a `TokenStream` function"
6)]
7
8use {
9    core::iter::{empty, once},
10    proc_macro::TokenStream,
11    quote::{format_ident, quote},
12    syn::{
13        AttrStyle, Attribute, Block, Expr, ExprAssign, ExprCall, ExprPath, FnArg, GenericParam,
14        Generics, Item, ItemFn, Pat, PatIdent, PatType, Path, PathArguments, PathSegment,
15        ReturnType, Signature, Stmt, Type, TypePath, TypeReference, Visibility, parse_macro_input,
16        punctuated::Punctuated, token::Comma,
17    },
18};
19
20#[cfg(test)]
21use {quickcheck as _, quickcheck_macros as _, sigma_types as _};
22
23/// Add a test that ensures that
24/// this type signature holds for all possible inputs.
25/// # Panics
26/// If not applied to a standalone (i.e. no `self`),
27/// standard (i.e. non-`async`, non-`unsafe`, etc.) function.
28#[proc_macro_attribute]
29#[expect(
30    clippy::too_many_lines,
31    reason = "fuck off (or take issue with `syn`)--the logic is actually fairly simple"
32)]
33pub fn forall(_args: TokenStream, input: TokenStream) -> TokenStream {
34    let input2: proc_macro2::TokenStream = input.clone().into(); // TODO: must be a better way
35    let item = parse_macro_input!(input as Item);
36    let Item::Fn(fn_item) = item else {
37        panic!("`{input2}` is not a function");
38    };
39
40    let signature = fn_item.sig;
41    let quickcheck_inputs: Punctuated<FnArg, Comma> = signature
42        .inputs
43        .into_iter()
44        .map(|arg_or_self| {
45            let FnArg::Typed(arg_with_type) = arg_or_self else {
46                panic!("Functions with `self` won't work");
47            };
48            let PatType {
49                attrs,
50                pat,
51                colon_token,
52                ty,
53            } = arg_with_type;
54            let (quickcheck_pattern, quickcheck_type) =
55                if let Type::Reference(TypeReference { elem, .. }) = *ty {
56                    (
57                        Box::new(Pat::Ident(PatIdent {
58                            attrs: attrs.clone(),
59                            by_ref: Some(Default::default()),
60                            mutability: None,
61                            ident: {
62                                #[expect(
63                                    clippy::wildcard_enum_match_arm,
64                                    reason = "huge number of cases"
65                                )]
66                                match *pat {
67                                    Pat::Ident(PatIdent { ident, .. }) => ident,
68                                    #[expect(clippy::todo, reason = "edge case, not world-ending")]
69                                    _ => todo!("make up an ident"),
70                                }
71                            },
72                            subpat: None,
73                        })),
74                        elem,
75                    )
76                } else {
77                    (pat, ty)
78                };
79            FnArg::Typed(PatType {
80                attrs,
81                pat: quickcheck_pattern,
82                colon_token,
83                ty: quickcheck_type,
84            })
85        })
86        .collect();
87    let quickcheck_args: Punctuated<Expr, Comma> = quickcheck_inputs
88        .iter()
89        .map(|fn_arg| {
90            let FnArg::Typed(PatType {
91                ref attrs,
92                ref pat,
93                colon_token: _,
94                ty: _,
95            }) = *fn_arg
96            else {
97                panic!("INTERNAL ERROR")
98            };
99            Expr::Path(ExprPath {
100                attrs: attrs.clone(),
101                qself: None,
102                path: pat_to_path(pat),
103            })
104        })
105        .collect();
106    let quickcheck_signature = Signature {
107        constness: None,
108        abi: None,
109        asyncness: None,
110        fn_token: signature.fn_token,
111        generics: Generics {
112            lt_token: None,
113            params: empty::<GenericParam>().collect(),
114            gt_token: None,
115            where_clause: None,
116        },
117        ident: format_ident!("forall_{}", signature.ident),
118        inputs: quickcheck_inputs,
119        output: ReturnType::Type(
120            Default::default(),
121            Box::new(Type::Path(TypePath {
122                qself: None,
123                path: Path {
124                    leading_colon: Some(Default::default()),
125                    segments: [
126                        PathSegment {
127                            arguments: PathArguments::None,
128                            ident: format_ident!("quickcheck"),
129                        },
130                        PathSegment {
131                            arguments: PathArguments::None,
132                            ident: format_ident!("TestResult"),
133                        },
134                    ]
135                    .into_iter()
136                    .collect(),
137                },
138            })),
139        ),
140        paren_token: Default::default(),
141        unsafety: None,
142        variadic: None,
143    };
144
145    let quickcheck_fn = ItemFn {
146        attrs: once(Attribute {
147            pound_token: Default::default(),
148            style: AttrStyle::Outer,
149            bracket_token: Default::default(),
150            path: Path {
151                leading_colon: None,
152                segments: once(PathSegment {
153                    arguments: PathArguments::None,
154                    ident: format_ident!("cfg"),
155                })
156                .collect(),
157            },
158            tokens: quote! { (test) },
159        })
160        .chain(
161            fn_item
162                .attrs
163                .into_iter()
164                .filter(|attr| matches!(attr.style, AttrStyle::Outer)),
165        )
166        .chain(once(Attribute {
167            pound_token: Default::default(),
168            style: AttrStyle::Outer,
169            bracket_token: Default::default(),
170            path: Path {
171                leading_colon: Some(Default::default()),
172                segments: [
173                    PathSegment {
174                        arguments: PathArguments::None,
175                        ident: format_ident!("quickcheck_macros"),
176                    },
177                    PathSegment {
178                        arguments: PathArguments::None,
179                        ident: format_ident!("quickcheck"),
180                    },
181                ]
182                .into_iter()
183                .collect(),
184            },
185            tokens: quote! {},
186        }))
187        .collect(),
188        vis: Visibility::Inherited,
189        sig: quickcheck_signature,
190        block: Box::new(Block {
191            brace_token: Default::default(),
192            stmts: vec![
193                Stmt::Semi(
194                    Expr::Assign(ExprAssign {
195                        attrs: vec![],
196                        left: Box::new(Expr::Path(ExprPath {
197                            attrs: vec![],
198                            qself: None,
199                            path: Path {
200                                leading_colon: None,
201                                segments: once(PathSegment {
202                                    arguments: PathArguments::None,
203                                    ident: format_ident!("_"),
204                                })
205                                .collect(),
206                            },
207                        })),
208                        eq_token: Default::default(),
209                        right: Box::new(Expr::Call(ExprCall {
210                            attrs: vec![],
211                            func: Box::new(Expr::Path(ExprPath {
212                                attrs: vec![],
213                                qself: None,
214                                path: Path {
215                                    leading_colon: None,
216                                    segments: once(PathSegment {
217                                        arguments: PathArguments::None,
218                                        ident: signature.ident,
219                                    })
220                                    .collect(),
221                                },
222                            })),
223                            paren_token: Default::default(),
224                            args: quickcheck_args,
225                        })),
226                    }),
227                    Default::default(),
228                ),
229                Stmt::Expr(
230                    Expr::Call(ExprCall {
231                        attrs: vec![],
232                        func: Box::new(Expr::Path(ExprPath {
233                            attrs: vec![],
234                            qself: None,
235                            path: Path {
236                                leading_colon: Some(Default::default()),
237                                segments: [
238                                    PathSegment {
239                                        arguments: PathArguments::None,
240                                        ident: format_ident!("quickcheck"),
241                                    },
242                                    PathSegment {
243                                        arguments: PathArguments::None,
244                                        ident: format_ident!("TestResult"),
245                                    },
246                                    PathSegment {
247                                        arguments: PathArguments::None,
248                                        ident: format_ident!("passed"),
249                                    },
250                                ]
251                                .into_iter()
252                                .collect(),
253                            },
254                        })),
255                        paren_token: Default::default(),
256                        args: empty::<Expr>().collect(),
257                    }),
258                    // None,
259                ),
260            ],
261        }),
262    };
263
264    quote! {
265        #input2 // retain the original input
266
267        #quickcheck_fn
268    }
269    .into()
270}
271
272/// Extract an ident from a compatible pattern.
273#[inline]
274#[expect(clippy::single_call_fn, reason = "may be recursive")]
275fn pat_to_path(pat: &Pat) -> Path {
276    #[expect(clippy::wildcard_enum_match_arm, reason = "huge number of cases")]
277    match *pat {
278        Pat::Ident(PatIdent { ref ident, .. }) => Path {
279            leading_colon: None,
280            segments: once(PathSegment {
281                ident: ident.clone(),
282                arguments: PathArguments::None,
283            })
284            .collect(),
285        },
286        _ => panic!("INTERNAL ERROR"),
287    }
288}