Skip to main content

test_better_macros/
lib.rs

1//! `test-better-macros`: procedural macros.
2//!
3//! Home of `matches_struct!`, `matches_tuple!`, `matches_variant!`, the
4//! `#[test_case]` attribute, and the `#[fixture]` / `#[test_with_fixtures]`
5//! attribute pair, with the inline snapshot macros still to come.
6//!
7//! The structural matchers parse a *pattern* of inner matcher expressions and
8//! emit a `Matcher` impl. The matcher holds a projection (a closure that pulls
9//! the fields out of the value) plus one inner matcher per field; the
10//! projection's type ties the matcher's type parameters to the real field
11//! types, so the field types never have to be named in the macro. The
12//! projection is threaded through a generated constructor function whose
13//! signature carries the `Fn` bound, which is what makes the closure infer as
14//! higher-ranked over the borrow.
15//!
16//! `#[test_case]` is an attribute macro: stacked `#[test_case(..)]` lines on one
17//! function each become a generated `#[test]`, all gathered into a module named
18//! for the function so the cases share a namespace.
19//!
20//! `#[fixture]` turns a `fn() -> TestResult<T>` into a fixture: its failures are
21//! re-categorized as `ErrorKind::Setup` so a setup problem never reads as an
22//! assertion miss. `#[test_with_fixtures]` is the seam that consumes them: each
23//! parameter `name: T` is filled by calling the same-named fixture `fn name()`.
24//!
25//! The generated code refers to the testing library through the `::test_better`
26//! facade crate, so these macros are meant to be used via `test-better`, not by
27//! depending on `test-better-macros` directly.
28
29use std::collections::HashSet;
30
31use proc_macro::TokenStream;
32use proc_macro2::TokenStream as TokenStream2;
33use quote::{format_ident, quote};
34use syn::parse::{Parse, ParseStream};
35use syn::punctuated::Punctuated;
36use syn::spanned::Spanned;
37use syn::{Expr, FnArg, Ident, Index, ItemFn, LitStr, Pat, Path, Token, braced, parenthesized};
38
39/// A named-field pattern: `Path { field: matcher, ..., .. }`.
40struct StructPattern {
41    path: Path,
42    fields: Vec<(Ident, Expr)>,
43    rest: bool,
44}
45
46impl Parse for StructPattern {
47    fn parse(input: ParseStream) -> syn::Result<Self> {
48        let path: Path = input.parse()?;
49        let content;
50        braced!(content in input);
51        let (fields, rest) = parse_named_fields(&content)?;
52        Ok(Self { path, fields, rest })
53    }
54}
55
56/// A positional pattern: `Path(matcher, ..., ..)`.
57struct TuplePattern {
58    path: Path,
59    elems: Vec<Expr>,
60    rest: bool,
61}
62
63impl Parse for TuplePattern {
64    fn parse(input: ParseStream) -> syn::Result<Self> {
65        let path: Path = input.parse()?;
66        let content;
67        parenthesized!(content in input);
68        let (elems, rest) = parse_positional_fields(&content)?;
69        Ok(Self { path, elems, rest })
70    }
71}
72
73/// The body of a variant pattern: struct-like, tuple-like, or unit.
74enum VariantBody {
75    Struct {
76        fields: Vec<(Ident, Expr)>,
77        rest: bool,
78    },
79    Tuple {
80        elems: Vec<Expr>,
81        rest: bool,
82    },
83    Unit,
84}
85
86/// A variant pattern: `Enum::Variant { .. }`, `Enum::Variant( .. )`, or
87/// `Enum::Variant`.
88struct VariantPattern {
89    path: Path,
90    body: VariantBody,
91}
92
93impl Parse for VariantPattern {
94    fn parse(input: ParseStream) -> syn::Result<Self> {
95        let path: Path = input.parse()?;
96        let body = if input.peek(syn::token::Brace) {
97            let content;
98            braced!(content in input);
99            let (fields, rest) = parse_named_fields(&content)?;
100            VariantBody::Struct { fields, rest }
101        } else if input.peek(syn::token::Paren) {
102            let content;
103            parenthesized!(content in input);
104            let (elems, rest) = parse_positional_fields(&content)?;
105            VariantBody::Tuple { elems, rest }
106        } else {
107            VariantBody::Unit
108        };
109        Ok(Self { path, body })
110    }
111}
112
113/// Parses `field: expr` entries, optionally ending with `..`. The `..`, when
114/// present, must be the final element.
115fn parse_named_fields(content: ParseStream) -> syn::Result<(Vec<(Ident, Expr)>, bool)> {
116    let mut fields = Vec::new();
117    let mut rest = false;
118    while !content.is_empty() {
119        if content.peek(Token![..]) {
120            content.parse::<Token![..]>()?;
121            rest = true;
122            break;
123        }
124        let name: Ident = content.parse()?;
125        content.parse::<Token![:]>()?;
126        let expr: Expr = content.parse()?;
127        fields.push((name, expr));
128        if content.is_empty() {
129            break;
130        }
131        content.parse::<Token![,]>()?;
132    }
133    if !content.is_empty() {
134        return Err(content.error("`..` must be the final element of the pattern"));
135    }
136    Ok((fields, rest))
137}
138
139/// Parses positional `expr` entries, optionally ending with `..`. The `..`,
140/// when present, must be the final element.
141fn parse_positional_fields(content: ParseStream) -> syn::Result<(Vec<Expr>, bool)> {
142    let mut elems = Vec::new();
143    let mut rest = false;
144    while !content.is_empty() {
145        if content.peek(Token![..]) {
146            content.parse::<Token![..]>()?;
147            rest = true;
148            break;
149        }
150        elems.push(content.parse()?);
151        if content.is_empty() {
152            break;
153        }
154        content.parse::<Token![,]>()?;
155    }
156    if !content.is_empty() {
157        return Err(content.error("`..` must be the final element of the pattern"));
158    }
159    Ok((elems, rest))
160}
161
162/// Splits `Enum::Variant` into the enum path (`Enum`) and the variant ident
163/// (`Variant`).
164fn split_variant_path(path: &Path) -> syn::Result<(Path, Ident)> {
165    if path.segments.len() < 2 {
166        return Err(syn::Error::new_spanned(
167            path,
168            "expected an enum variant path like `MyEnum::Variant`",
169        ));
170    }
171    let kept = path.segments.len() - 1;
172    let segments: Punctuated<syn::PathSegment, Token![::]> =
173        path.segments.iter().take(kept).cloned().collect();
174    let enum_path = Path {
175        leading_colon: path.leading_colon,
176        segments,
177    };
178    let variant_ident = match path.segments.last() {
179        Some(seg) => seg.ident.clone(),
180        None => return Err(syn::Error::new_spanned(path, "missing variant name")),
181    };
182    Ok((enum_path, variant_ident))
183}
184
185/// The per-field generated idents: the matcher type parameter, the field-type
186/// parameter, the struct field holding the matcher, and the binding the
187/// projection's output is destructured into.
188struct FieldIdents {
189    matcher_ty: Vec<Ident>,
190    field_ty: Vec<Ident>,
191    matcher_field: Vec<Ident>,
192    binding: Vec<Ident>,
193}
194
195fn field_idents(n: usize) -> FieldIdents {
196    FieldIdents {
197        matcher_ty: (0..n).map(|i| format_ident!("__TbM{}", i)).collect(),
198        field_ty: (0..n).map(|i| format_ident!("__TbF{}", i)).collect(),
199        matcher_field: (0..n).map(|i| format_ident!("__tb_m{}", i)).collect(),
200        binding: (0..n).map(|i| format_ident!("__tb_f{}", i)).collect(),
201    }
202}
203
204/// The body of each field's check: run the inner matcher on the projected
205/// field, and on failure return a mismatch whose expectation is labeled with
206/// the field name.
207fn field_check_blocks(
208    matcher_field: &[Ident],
209    binding: &[Ident],
210    labels: &[String],
211) -> Vec<TokenStream2> {
212    matcher_field
213        .iter()
214        .zip(binding)
215        .zip(labels)
216        .map(|((field, bind), label)| {
217            let label = label.as_str();
218            quote! {
219                {
220                    let __tb_result = ::test_better::Matcher::check(&self.#field, #bind);
221                    if !__tb_result.matched {
222                        let __tb_inner = match __tb_result.failure {
223                            ::core::option::Option::Some(__tb_mismatch) => __tb_mismatch,
224                            ::core::option::Option::None => ::test_better::Mismatch::new(
225                                ::test_better::Matcher::description(&self.#field),
226                                "the field matcher reported failure without detail",
227                            ),
228                        };
229                        return ::test_better::MatchResult::fail(::test_better::Mismatch {
230                            expected: ::test_better::Description::labeled(
231                                #label,
232                                __tb_inner.expected,
233                            ),
234                            actual: __tb_inner.actual,
235                            diff: __tb_inner.diff,
236                        });
237                    }
238                }
239            }
240        })
241        .collect()
242}
243
244/// Folds each field's labeled description together under conjunction.
245fn description_fold(matcher_field: &[Ident], labels: &[String]) -> TokenStream2 {
246    let mut parts = matcher_field.iter().zip(labels).map(|(field, label)| {
247        let label = label.as_str();
248        quote! {
249            ::test_better::Description::labeled(
250                #label,
251                ::test_better::Matcher::description(&self.#field),
252            )
253        }
254    });
255    match parts.next() {
256        Some(first) => {
257            let mut acc = first;
258            for part in parts {
259                acc = quote! { #acc.and(#part) };
260            }
261            acc
262        }
263        None => quote! { ::test_better::Description::text("a matching value") },
264    }
265}
266
267/// Wraps an exhaustiveness-checking statement in a never-called function, so a
268/// missing or unknown field is a hard error from rustc's own pattern checking.
269fn exhaustiveness_fn(target: &TokenStream2, stmt: Option<TokenStream2>) -> TokenStream2 {
270    match stmt {
271        Some(stmt) => quote! {
272            #[allow(dead_code, unused_variables, irrefutable_let_patterns, clippy::all)]
273            fn __tb_assert_exhaustive(__tb_value: &#target) {
274                #stmt
275            }
276        },
277        None => quote! {},
278    }
279}
280
281/// Assembles a plain (struct or tuple struct) structural matcher.
282///
283/// `projection` is a closure `Fn(&Self) -> (&F0, &F1, ...)`. It is passed to
284/// the generated `__tb_make`, whose signature carries the `Fn` bound; that is
285/// what lets the closure infer as higher-ranked over the borrow. `__tb_make`'s
286/// where-clause then pins the matcher's `Self` and field types.
287fn gen_plain(
288    target: &TokenStream2,
289    labels: &[String],
290    field_exprs: &[&Expr],
291    projection: TokenStream2,
292    exhaustiveness: Option<TokenStream2>,
293) -> TokenStream2 {
294    let idents = field_idents(labels.len());
295    let FieldIdents {
296        matcher_ty,
297        field_ty,
298        matcher_field,
299        binding,
300    } = &idents;
301    let n = labels.len();
302    let assertion = exhaustiveness_fn(target, exhaustiveness);
303
304    if n == 0 {
305        return quote! {
306            {
307                #[allow(non_camel_case_types, dead_code, clippy::all)]
308                struct __TbStructuralMatcher<__TbP> {
309                    __tb_project: __TbP,
310                }
311
312                #[allow(clippy::all)]
313                impl<__TbS, __TbP> ::test_better::Matcher<__TbS>
314                    for __TbStructuralMatcher<__TbP>
315                where
316                    __TbP: ::core::ops::Fn(&__TbS) -> (),
317                {
318                    fn check(&self, __tb_actual: &__TbS) -> ::test_better::MatchResult {
319                        let () = (self.__tb_project)(__tb_actual);
320                        ::test_better::MatchResult::pass()
321                    }
322
323                    fn description(&self) -> ::test_better::Description {
324                        ::test_better::Description::text("a matching value")
325                    }
326                }
327
328                #[allow(clippy::all)]
329                fn __tb_make<__TbS, __TbP>(
330                    __tb_project: __TbP,
331                ) -> impl ::test_better::Matcher<__TbS>
332                where
333                    __TbP: ::core::ops::Fn(&__TbS) -> (),
334                {
335                    __TbStructuralMatcher { __tb_project }
336                }
337
338                #assertion
339
340                __tb_make(#projection)
341            }
342        };
343    }
344
345    let checks = field_check_blocks(matcher_field, binding, labels);
346    let desc = description_fold(matcher_field, labels);
347
348    quote! {
349        {
350            #[allow(non_camel_case_types, dead_code, clippy::all)]
351            struct __TbStructuralMatcher<__TbP, #( #matcher_ty, )*> {
352                __tb_project: __TbP,
353                #( #matcher_field: #matcher_ty, )*
354            }
355
356            #[allow(clippy::all)]
357            impl<__TbS, #( #field_ty, )* __TbP, #( #matcher_ty, )*>
358                ::test_better::Matcher<__TbS>
359                for __TbStructuralMatcher<__TbP, #( #matcher_ty, )*>
360            where
361                __TbP: ::core::ops::Fn(&__TbS) -> ( #( &#field_ty, )* ),
362                #( #matcher_ty: ::test_better::Matcher<#field_ty>, )*
363            {
364                fn check(&self, __tb_actual: &__TbS) -> ::test_better::MatchResult {
365                    let ( #( #binding, )* ) = (self.__tb_project)(__tb_actual);
366                    #( #checks )*
367                    ::test_better::MatchResult::pass()
368                }
369
370                fn description(&self) -> ::test_better::Description {
371                    #desc
372                }
373            }
374
375            #[allow(clippy::all)]
376            fn __tb_make<__TbS, #( #field_ty, )* __TbP, #( #matcher_ty, )*>(
377                __tb_project: __TbP,
378                #( #matcher_field: #matcher_ty, )*
379            ) -> impl ::test_better::Matcher<__TbS>
380            where
381                __TbP: ::core::ops::Fn(&__TbS) -> ( #( &#field_ty, )* ),
382                #( #matcher_ty: ::test_better::Matcher<#field_ty>, )*
383            {
384                __TbStructuralMatcher {
385                    __tb_project,
386                    #( #matcher_field, )*
387                }
388            }
389
390            #assertion
391
392            __tb_make(#projection, #( #field_exprs, )*)
393        }
394    }
395}
396
397fn gen_struct(path: &Path, fields: &[(Ident, Expr)], rest: bool) -> TokenStream2 {
398    let target = quote! { #path };
399    let labels: Vec<String> = fields.iter().map(|(name, _)| name.to_string()).collect();
400    let field_exprs: Vec<&Expr> = fields.iter().map(|(_, expr)| expr).collect();
401    let field_names: Vec<&Ident> = fields.iter().map(|(name, _)| name).collect();
402
403    let projection = if fields.is_empty() {
404        quote! { |_: &#path| () }
405    } else {
406        quote! { |__tb_subject: &#path| ( #( &__tb_subject.#field_names, )* ) }
407    };
408
409    let exhaustiveness = if rest {
410        None
411    } else {
412        Some(quote! { let #path { #( #field_names: _, )* } = __tb_value; })
413    };
414
415    gen_plain(&target, &labels, &field_exprs, projection, exhaustiveness)
416}
417
418fn gen_tuple(path: &Path, elems: &[Expr], rest: bool) -> TokenStream2 {
419    let target = quote! { #path };
420    let labels: Vec<String> = (0..elems.len()).map(|i| i.to_string()).collect();
421    let field_exprs: Vec<&Expr> = elems.iter().collect();
422    let indices: Vec<Index> = (0..elems.len()).map(Index::from).collect();
423
424    let projection = if elems.is_empty() {
425        quote! { |_: &#path| () }
426    } else {
427        quote! { |__tb_subject: &#path| ( #( &__tb_subject.#indices, )* ) }
428    };
429
430    let exhaustiveness = if rest {
431        None
432    } else {
433        let holes = elems.iter().map(|_| quote!(_));
434        Some(quote! { let #path( #( #holes, )* ) = __tb_value; })
435    };
436
437    gen_plain(&target, &labels, &field_exprs, projection, exhaustiveness)
438}
439
440fn gen_variant(pattern: &VariantPattern) -> syn::Result<TokenStream2> {
441    let (enum_path, variant_ident) = split_variant_path(&pattern.path)?;
442    let path = &pattern.path;
443    let target = quote! { #enum_path };
444    let variant_name = variant_ident.to_string();
445    let variant_label = format!("the {variant_name} variant");
446
447    // The labels, the inner matcher expressions, the projection closure, and the
448    // exhaustiveness assertion all differ by variant shape.
449    let (labels, field_exprs, projection, exhaustiveness): (
450        Vec<String>,
451        Vec<&Expr>,
452        TokenStream2,
453        Option<TokenStream2>,
454    ) = match &pattern.body {
455        VariantBody::Struct { fields, rest } => {
456            let labels: Vec<String> = fields.iter().map(|(name, _)| name.to_string()).collect();
457            let field_exprs: Vec<&Expr> = fields.iter().map(|(_, expr)| expr).collect();
458            let field_names: Vec<&Ident> = fields.iter().map(|(name, _)| name).collect();
459            let bindings: Vec<Ident> = (0..fields.len())
460                .map(|i| format_ident!("__tb_p{}", i))
461                .collect();
462            let projection = quote! {
463                |__tb_subject: &#enum_path| match __tb_subject {
464                    #path { #( #field_names: #bindings, )* .. } =>
465                        ::core::option::Option::Some(( #( #bindings, )* )),
466                    _ => ::core::option::Option::None,
467                }
468            };
469            let exhaustiveness = if *rest {
470                None
471            } else {
472                Some(quote! { if let #path { #( #field_names: _, )* } = __tb_value {} })
473            };
474            (labels, field_exprs, projection, exhaustiveness)
475        }
476        VariantBody::Tuple { elems, rest } => {
477            let labels: Vec<String> = (0..elems.len()).map(|i| i.to_string()).collect();
478            let field_exprs: Vec<&Expr> = elems.iter().collect();
479            let bindings: Vec<Ident> = (0..elems.len())
480                .map(|i| format_ident!("__tb_p{}", i))
481                .collect();
482            let projection = quote! {
483                |__tb_subject: &#enum_path| match __tb_subject {
484                    #path( #( #bindings, )* .. ) =>
485                        ::core::option::Option::Some(( #( #bindings, )* )),
486                    _ => ::core::option::Option::None,
487                }
488            };
489            let exhaustiveness = if *rest {
490                None
491            } else {
492                let holes = elems.iter().map(|_| quote!(_));
493                Some(quote! { if let #path( #( #holes, )* ) = __tb_value {} })
494            };
495            (labels, field_exprs, projection, exhaustiveness)
496        }
497        VariantBody::Unit => {
498            let projection = quote! {
499                |__tb_subject: &#enum_path| match __tb_subject {
500                    #path => ::core::option::Option::Some(()),
501                    _ => ::core::option::Option::None,
502                }
503            };
504            (Vec::new(), Vec::new(), projection, None)
505        }
506    };
507
508    let idents = field_idents(labels.len());
509    let FieldIdents {
510        matcher_ty,
511        field_ty,
512        matcher_field,
513        binding,
514    } = &idents;
515    let n = labels.len();
516    let assertion = exhaustiveness_fn(&target, exhaustiveness);
517
518    let wrong_variant = quote! {
519        ::test_better::MatchResult::fail(::test_better::Mismatch::new(
520            ::test_better::Description::text(#variant_label),
521            ::std::format!("{:?}", __tb_actual),
522        ))
523    };
524
525    if n == 0 {
526        return Ok(quote! {
527            {
528                #[allow(non_camel_case_types, dead_code, clippy::all)]
529                struct __TbVariantMatcher<__TbP> {
530                    __tb_project: __TbP,
531                }
532
533                #[allow(clippy::all)]
534                impl<__TbS, __TbP> ::test_better::Matcher<__TbS>
535                    for __TbVariantMatcher<__TbP>
536                where
537                    __TbP: ::core::ops::Fn(&__TbS) -> ::core::option::Option<()>,
538                    __TbS: ::core::fmt::Debug,
539                {
540                    fn check(&self, __tb_actual: &__TbS) -> ::test_better::MatchResult {
541                        match (self.__tb_project)(__tb_actual) {
542                            ::core::option::Option::Some(()) => {
543                                ::test_better::MatchResult::pass()
544                            }
545                            ::core::option::Option::None => #wrong_variant,
546                        }
547                    }
548
549                    fn description(&self) -> ::test_better::Description {
550                        ::test_better::Description::text(#variant_label)
551                    }
552                }
553
554                #[allow(clippy::all)]
555                fn __tb_make<__TbS, __TbP>(
556                    __tb_project: __TbP,
557                ) -> impl ::test_better::Matcher<__TbS>
558                where
559                    __TbP: ::core::ops::Fn(&__TbS) -> ::core::option::Option<()>,
560                    __TbS: ::core::fmt::Debug,
561                {
562                    __TbVariantMatcher { __tb_project }
563                }
564
565                #assertion
566
567                __tb_make(#projection)
568            }
569        });
570    }
571
572    let checks = field_check_blocks(matcher_field, binding, &labels);
573    let desc_inner = description_fold(matcher_field, &labels);
574    let desc = quote! { ::test_better::Description::labeled(#variant_name, #desc_inner) };
575
576    Ok(quote! {
577        {
578            #[allow(non_camel_case_types, dead_code, clippy::all)]
579            struct __TbVariantMatcher<__TbP, #( #matcher_ty, )*> {
580                __tb_project: __TbP,
581                #( #matcher_field: #matcher_ty, )*
582            }
583
584            #[allow(clippy::all)]
585            impl<__TbS, #( #field_ty, )* __TbP, #( #matcher_ty, )*>
586                ::test_better::Matcher<__TbS>
587                for __TbVariantMatcher<__TbP, #( #matcher_ty, )*>
588            where
589                __TbP: ::core::ops::Fn(&__TbS)
590                    -> ::core::option::Option<( #( &#field_ty, )* )>,
591                #( #matcher_ty: ::test_better::Matcher<#field_ty>, )*
592                __TbS: ::core::fmt::Debug,
593            {
594                fn check(&self, __tb_actual: &__TbS) -> ::test_better::MatchResult {
595                    match (self.__tb_project)(__tb_actual) {
596                        ::core::option::Option::Some(( #( #binding, )* )) => {
597                            #( #checks )*
598                            ::test_better::MatchResult::pass()
599                        }
600                        ::core::option::Option::None => #wrong_variant,
601                    }
602                }
603
604                fn description(&self) -> ::test_better::Description {
605                    #desc
606                }
607            }
608
609            #[allow(clippy::all)]
610            fn __tb_make<__TbS, #( #field_ty, )* __TbP, #( #matcher_ty, )*>(
611                __tb_project: __TbP,
612                #( #matcher_field: #matcher_ty, )*
613            ) -> impl ::test_better::Matcher<__TbS>
614            where
615                __TbP: ::core::ops::Fn(&__TbS)
616                    -> ::core::option::Option<( #( &#field_ty, )* )>,
617                #( #matcher_ty: ::test_better::Matcher<#field_ty>, )*
618                __TbS: ::core::fmt::Debug,
619            {
620                __TbVariantMatcher {
621                    __tb_project,
622                    #( #matcher_field, )*
623                }
624            }
625
626            #assertion
627
628            __tb_make(#projection, #( #field_exprs, )*)
629        }
630    })
631}
632
633/// Matches a struct by applying an inner matcher to each named field.
634///
635/// Without a trailing `..` every field must be listed, exactly as in a struct
636/// pattern; with `..` the unlisted fields are ignored.
637///
638/// ```ignore
639/// use test_better::prelude::*;
640/// use test_better::matches_struct;
641///
642/// #[derive(Debug)]
643/// struct User {
644///     name: String,
645///     age: u32,
646///     email: String,
647/// }
648///
649/// fn check(user: User) -> TestResult {
650///     expect!(user).to(matches_struct!(User {
651///         name: eq(String::from("alice")),
652///         age: gt(0u32),
653///         email: contains_str("@"),
654///         .. // remaining fields ignored
655///     }))?;
656///     Ok(())
657/// }
658/// ```
659#[proc_macro]
660pub fn matches_struct(input: TokenStream) -> TokenStream {
661    match syn::parse::<StructPattern>(input) {
662        Ok(pattern) => gen_struct(&pattern.path, &pattern.fields, pattern.rest).into(),
663        Err(error) => error.to_compile_error().into(),
664    }
665}
666
667/// Matches a tuple struct by applying an inner matcher to each positional
668/// field.
669///
670/// Without a trailing `..` every element must be listed; with `..` the unlisted
671/// trailing elements are ignored.
672///
673/// ```ignore
674/// use test_better::prelude::*;
675/// use test_better::matches_tuple;
676///
677/// #[derive(Debug)]
678/// struct Point(i32, i32);
679///
680/// fn check(point: Point) -> TestResult {
681///     expect!(point).to(matches_tuple!(Point(gt(0), lt(100))))?;
682///     Ok(())
683/// }
684/// ```
685#[proc_macro]
686pub fn matches_tuple(input: TokenStream) -> TokenStream {
687    match syn::parse::<TuplePattern>(input) {
688        Ok(pattern) => gen_tuple(&pattern.path, &pattern.elems, pattern.rest).into(),
689        Err(error) => error.to_compile_error().into(),
690    }
691}
692
693/// Matches an enum value against a specific variant, applying an inner matcher
694/// to each of that variant's fields.
695///
696/// A value of a different variant is a match failure, not a compile error. The
697/// variant may be struct-like (`Enum::Variant { field: m, .. }`), tuple-like
698/// (`Enum::Variant(m, ..)`), or unit (`Enum::Variant`). The enum type must be
699/// `Debug` so a wrong-variant failure can render the value.
700///
701/// ```ignore
702/// use test_better::prelude::*;
703/// use test_better::matches_variant;
704///
705/// #[derive(Debug)]
706/// enum Shape {
707///     Circle { radius: f64 },
708///     Rectangle(f64, f64),
709/// }
710///
711/// fn check(shape: Shape) -> TestResult {
712///     expect!(shape).to(matches_variant!(Shape::Circle { radius: gt(0.0) }))?;
713///     Ok(())
714/// }
715/// ```
716#[proc_macro]
717pub fn matches_variant(input: TokenStream) -> TokenStream {
718    let result = syn::parse::<VariantPattern>(input).and_then(|pattern| gen_variant(&pattern));
719    match result {
720        Ok(tokens) => tokens.into(),
721        Err(error) => error.to_compile_error().into(),
722    }
723}
724
725/// One `#[test_case(..)]` invocation: the argument expressions and an optional
726/// `; "label"`.
727struct TestCase {
728    /// A span pointing at the invocation, used to place an arg-count error on
729    /// the offending attribute rather than on the function.
730    span: proc_macro2::Span,
731    /// The expressions passed positionally to the annotated function.
732    args: Vec<Expr>,
733    /// The case label after `;`, if one was given.
734    label: Option<LitStr>,
735}
736
737impl Parse for TestCase {
738    fn parse(input: ParseStream) -> syn::Result<Self> {
739        let span = input.span();
740        let mut args = Vec::new();
741        let mut label = None;
742        while !input.is_empty() {
743            if input.peek(Token![;]) {
744                input.parse::<Token![;]>()?;
745                label = Some(input.parse::<LitStr>()?);
746                if !input.is_empty() {
747                    return Err(input.error("unexpected tokens after the test-case label"));
748                }
749                break;
750            }
751            args.push(input.parse()?);
752            if input.is_empty() || input.peek(Token![;]) {
753                continue;
754            }
755            input.parse::<Token![,]>()?;
756        }
757        Ok(Self { span, args, label })
758    }
759}
760
761/// Reads one `#[test_case]` attribute. A bare `#[test_case]` (no parentheses) is
762/// a zero-argument, unlabeled case; anything else is parsed as a [`TestCase`].
763fn parse_test_case_attr(attribute: &syn::Attribute) -> syn::Result<TestCase> {
764    match &attribute.meta {
765        syn::Meta::Path(_) => Ok(TestCase {
766            span: attribute.span(),
767            args: Vec::new(),
768            label: None,
769        }),
770        _ => attribute.parse_args::<TestCase>(),
771    }
772}
773
774/// Whether an attribute is `#[test_case]` (matched on the final path segment so
775/// a fully qualified `#[test_better::test_case]` is recognized too).
776fn is_test_case_attr(attribute: &syn::Attribute) -> bool {
777    attribute
778        .path()
779        .segments
780        .last()
781        .is_some_and(|segment| segment.ident == "test_case")
782}
783
/// Turns a case label into a valid, lowercased identifier fragment that can
/// name a generated test function.
///
/// The mapping is applied in this order:
///
/// 1. Every character that is not ASCII alphanumeric becomes `_`, and letters
///    are lowercased.
/// 2. An empty result, or a lone `_` (not a usable function name), falls back
///    to `case`.
/// 3. A leading digit gets an `_` prefix.
/// 4. A result that is a Rust strict or reserved keyword (`fn`, `match`,
///    `self`, ...) gets a trailing `_`.
fn sanitize_ident(label: &str) -> String {
    // Lowercase only: the output is always lowercased before this check.
    const KEYWORDS: &[&str] = &[
        "abstract", "as", "async", "await", "become", "box", "break", "const", "continue",
        "crate", "do", "dyn", "else", "enum", "extern", "false", "final", "fn", "for", "gen",
        "if", "impl", "in", "let", "loop", "macro", "match", "mod", "move", "mut", "override",
        "priv", "pub", "ref", "return", "self", "static", "struct", "super", "trait", "true",
        "try", "type", "typeof", "unsafe", "unsized", "use", "virtual", "where", "while",
        "yield",
    ];

    let mut out: String = label
        .chars()
        .map(|ch| {
            if ch.is_ascii_alphanumeric() {
                ch.to_ascii_lowercase()
            } else {
                '_'
            }
        })
        .collect();
    // `fn _()` is rejected by the parser, so a lone `_` counts as "nothing
    // usable" just like an empty label.
    if out.is_empty() || out == "_" {
        out = String::from("case");
    }
    if out.starts_with(|ch: char| ch.is_ascii_digit()) {
        out.insert(0, '_');
    }
    // A keyword would produce e.g. `fn match()`; a trailing `_` keeps the name
    // readable and valid.
    if KEYWORDS.contains(&out.as_str()) {
        out.push('_');
    }
    out
}
806
807/// Expands the stacked `#[test_case]` attributes on `func` into one `#[test]`
808/// per case, all wrapped in a module named for the function.
809fn test_case_impl(first: TestCase, mut func: ItemFn) -> syn::Result<TokenStream2> {
810    // The topmost `#[test_case]` is handed to us as `attr`; the rest are still
811    // attached to the function. Split the remaining attributes into further
812    // cases and everything else (`#[ignore]`, doc comments, ...), which is
813    // forwarded onto every generated test.
814    let mut cases = vec![first];
815    let mut forwarded = Vec::new();
816    for attribute in std::mem::take(&mut func.attrs) {
817        if is_test_case_attr(&attribute) {
818            cases.push(parse_test_case_attr(&attribute)?);
819        } else {
820            forwarded.push(attribute);
821        }
822    }
823
824    let fn_name = func.sig.ident.clone();
825    let fn_name_str = fn_name.to_string();
826    let ret = func.sig.output.clone();
827    let expected_arity = func.sig.inputs.len();
828    // A `-> ()` (explicit or implicit) test cannot carry failure context; only
829    // a value-returning test (the `-> TestResult` shape) is wrapped.
830    let returns_value = match &func.sig.output {
831        syn::ReturnType::Default => false,
832        syn::ReturnType::Type(_, ty) => {
833            !matches!(&**ty, syn::Type::Tuple(tuple) if tuple.elems.is_empty())
834        }
835    };
836
837    let mut used_names: HashSet<String> = HashSet::new();
838    let mut tests = Vec::with_capacity(cases.len());
839    for (index, case) in cases.iter().enumerate() {
840        if case.args.len() != expected_arity {
841            return Err(syn::Error::new(
842                case.span,
843                format!(
844                    "this `#[test_case]` passes {} argument(s) but `{fn_name_str}` takes {}",
845                    case.args.len(),
846                    expected_arity,
847                ),
848            ));
849        }
850
851        let base = match &case.label {
852            Some(label) => sanitize_ident(&label.value()),
853            None => format!("case_{index}"),
854        };
855        // Two labels that sanitize to the same identifier (or a label that
856        // collides with a `case_N` default) are disambiguated by index.
857        let name = if used_names.contains(&base) {
858            format!("{base}_{index}")
859        } else {
860            base
861        };
862        used_names.insert(name.clone());
863        let test_ident = format_ident!("{name}");
864
865        let args = &case.args;
866        let args_rendered = args
867            .iter()
868            .map(|arg| quote!(#arg).to_string())
869            .collect::<Vec<_>>()
870            .join(", ");
871        let label_part = match &case.label {
872            Some(label) => format!("{:?}", label.value()),
873            None => format!("#{index}"),
874        };
875        let context_msg = format!("test case {label_part}: {fn_name_str}({args_rendered})");
876
877        let body = if returns_value {
878            quote! {
879                ::test_better::ContextExt::context(#fn_name(#(#args),*), #context_msg)
880            }
881        } else {
882            quote! { #fn_name(#(#args),*); }
883        };
884
885        // `pub(super)` so the generated test stays addressable from the file
886        // that wrote the `#[test_case]` (e.g. to drive an `#[ignore]`d failing
887        // case directly), without widening visibility any further.
888        tests.push(quote! {
889            #(#forwarded)*
890            #[test]
891            pub(super) fn #test_ident() #ret {
892                #body
893            }
894        });
895    }
896
897    Ok(quote! {
898        mod #fn_name {
899            #[allow(unused_imports)]
900            use super::*;
901
902            #func
903
904            #(#tests)*
905        }
906    })
907}
908
909/// Generates one `#[test]` per `#[test_case(..)]` line on a function.
910///
911/// Each attribute lists the positional arguments for one run, optionally
912/// followed by `; "label"`. The cases are gathered into a module named for the
913/// annotated function, so a labeled case is addressable as
914/// `the_fn::the_label`; an unlabeled case becomes `the_fn::case_N`. The
915/// original function stays callable inside that module as a helper.
916///
917/// When the function returns a value (the usual `-> TestResult` shape), each
918/// generated test wraps the call in failure context carrying the case label
919/// and the rendered arguments, so a failure names the case that produced it.
920///
921/// ```ignore
922/// use test_better::prelude::*;
923///
924/// #[test_case("alice", 30 ; "common case")]
925/// #[test_case("",      0  ; "empty name")]
926/// fn validates_user(name: &str, age: u32) -> TestResult {
927///     expect!(name.len()).to(ge(age as usize))
928/// }
929/// ```
930///
931/// A `#[test_case]` whose argument count does not match the function's
932/// parameter count is a compile error, as is trailing junk after the label.
933/// Other attributes on the function (`#[ignore]`, doc comments) are forwarded
934/// onto every generated test.
935#[proc_macro_attribute]
936pub fn test_case(attr: TokenStream, item: TokenStream) -> TokenStream {
937    let first = match syn::parse::<TestCase>(attr) {
938        Ok(case) => case,
939        Err(error) => return error.to_compile_error().into(),
940    };
941    let func = match syn::parse::<ItemFn>(item) {
942        Ok(func) => func,
943        Err(error) => return error.to_compile_error().into(),
944    };
945    match test_case_impl(first, func) {
946        Ok(tokens) => tokens.into(),
947        Err(error) => error.to_compile_error().into(),
948    }
949}
950
/// How long a fixture's value is kept: rebuilt for every test, or built once
/// and shared.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum FixtureScope {
    /// The default: the fixture body runs afresh for each test that uses it.
    Test,
    /// The fixture body runs once; every test gets a clone of the cached value.
    ///
    /// The cache is a single `static` inside the generated provider. It is
    /// therefore shared by every caller in the compiled binary, not only by
    /// tests in the module that defines the fixture.
    Module,
}

/// The parsed `#[fixture(..)]` arguments. The only knob is `scope`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
struct FixtureArgs {
    /// Which caching strategy the generated provider uses.
    scope: FixtureScope,
}
964
965impl Parse for FixtureArgs {
966    fn parse(input: ParseStream) -> syn::Result<Self> {
967        // A bare `#[fixture]` is per-test scope.
968        if input.is_empty() {
969            return Ok(Self {
970                scope: FixtureScope::Test,
971            });
972        }
973        let key: Ident = input.parse()?;
974        if key != "scope" {
975            return Err(syn::Error::new_spanned(
976                key,
977                "the only `#[fixture]` argument is `scope`",
978            ));
979        }
980        input.parse::<Token![=]>()?;
981        let value: LitStr = input.parse()?;
982        let scope = match value.value().as_str() {
983            "test" => FixtureScope::Test,
984            "module" => FixtureScope::Module,
985            other => {
986                return Err(syn::Error::new_spanned(
987                    value,
988                    format!("unknown fixture scope {other:?}, expected \"test\" or \"module\""),
989                ));
990            }
991        };
992        if !input.is_empty() {
993            return Err(input.error("unexpected tokens after the fixture scope"));
994        }
995        Ok(Self { scope })
996    }
997}
998
999/// Rewrites a `#[fixture]` function into a setup-aware provider.
1000///
1001/// The original body is kept verbatim in a nested `__tb_fixture_impl`; the
1002/// generated outer function (same name, same signature) calls it and, on the
1003/// error path, stamps the failure as `ErrorKind::Setup` so it can never read
1004/// as an assertion miss.
1005fn fixture_impl(args: FixtureArgs, mut func: ItemFn) -> syn::Result<TokenStream2> {
1006    if let Some(param) = func.sig.inputs.first() {
1007        return Err(syn::Error::new_spanned(
1008            param,
1009            "a `#[fixture]` function takes no parameters",
1010        ));
1011    }
1012    let return_ty = match &func.sig.output {
1013        syn::ReturnType::Type(_, ty) => (**ty).clone(),
1014        syn::ReturnType::Default => {
1015            return Err(syn::Error::new_spanned(
1016                &func.sig,
1017                "a `#[fixture]` function must return a `TestResult<T>`",
1018            ));
1019        }
1020    };
1021
1022    // Everything attached to the function (`#[ignore]` makes no sense here, but
1023    // `#[allow(..)]`, `#[cfg(..)]`, doc comments do) is forwarded onto the
1024    // generated provider; it also covers the nested impl, which sits inside it.
1025    let forwarded: Vec<syn::Attribute> = std::mem::take(&mut func.attrs);
1026
1027    let fn_name = func.sig.ident.clone();
1028    let fn_name_str = fn_name.to_string();
1029    let vis = func.vis.clone();
1030    let ret = func.sig.output.clone();
1031    let body = &func.block;
1032    let context_msg = format!("setting up fixture `{fn_name_str}`");
1033
1034    let impl_fn = quote! {
1035        fn __tb_fixture_impl() #ret #body
1036    };
1037
1038    let outer_body = match args.scope {
1039        // Per-test: run the body, re-categorize any failure as `Setup` and add
1040        // a frame naming the fixture. The successful value is moved straight
1041        // out, so the fixture type need not be `Clone`.
1042        FixtureScope::Test => quote! {
1043            #impl_fn
1044            ::core::result::Result::map_err(__tb_fixture_impl(), |__tb_error| {
1045                ::test_better::TestError::with_context_frame(
1046                    ::test_better::TestError::with_kind(
1047                        __tb_error,
1048                        ::test_better::ErrorKind::Setup,
1049                    ),
1050                    ::test_better::ContextFrame::new(#context_msg),
1051                )
1052            })
1053        },
1054        // Module: build once into a `LazyLock`, then hand every caller a clone.
1055        // The cached `Err` cannot be moved out, so the error path synthesizes a
1056        // fresh `Setup` failure carrying the original's rendered text.
1057        FixtureScope::Module => quote! {
1058            #impl_fn
1059            static __TB_FIXTURE_CELL: ::std::sync::LazyLock<#return_ty> =
1060                ::std::sync::LazyLock::new(__tb_fixture_impl);
1061            match &*__TB_FIXTURE_CELL {
1062                ::core::result::Result::Ok(__tb_value) => {
1063                    ::core::result::Result::Ok(::core::clone::Clone::clone(__tb_value))
1064                }
1065                ::core::result::Result::Err(__tb_error) => {
1066                    ::core::result::Result::Err(
1067                        ::test_better::TestError::with_message(
1068                            ::test_better::TestError::new(
1069                                ::test_better::ErrorKind::Setup,
1070                            ),
1071                            ::std::format!(
1072                                "module-scoped fixture `{}` failed during setup: {}",
1073                                #fn_name_str,
1074                                __tb_error,
1075                            ),
1076                        ),
1077                    )
1078                }
1079            }
1080        },
1081    };
1082
1083    Ok(quote! {
1084        #(#forwarded)*
1085        #vis fn #fn_name() #ret {
1086            #outer_body
1087        }
1088    })
1089}
1090
1091/// Marks a `fn() -> TestResult<T>` as a fixture: a reusable piece of test setup
1092/// whose failures surface as `ErrorKind::Setup`, never as assertion misses.
1093///
1094/// A fixture is consumed by [`macro@test_with_fixtures`]: a test parameter
1095/// `name: T` is filled by the same-named fixture `fn name()`.
1096///
1097/// By default a fixture is per-test (`#[fixture]` or `#[fixture(scope =
1098/// "test")]`): the body runs afresh for every test, and the value is moved
1099/// straight out, so `T` need not be `Clone`. With `#[fixture(scope = "module")]`
1100/// the body runs once and every test gets a clone of the cached value, so `T`
1101/// must be `Clone + Send + Sync + 'static`.
1102///
1103/// ```ignore
1104/// use test_better::prelude::*;
1105///
1106/// #[fixture]
1107/// fn answer() -> TestResult<i32> {
1108///     Ok(42)
1109/// }
1110///
1111/// #[test_with_fixtures]
1112/// fn uses_the_answer(answer: i32) -> TestResult {
1113///     expect!(answer).to(eq(42))
1114/// }
1115/// ```
1116#[proc_macro_attribute]
1117pub fn fixture(attr: TokenStream, item: TokenStream) -> TokenStream {
1118    let args = match syn::parse::<FixtureArgs>(attr) {
1119        Ok(args) => args,
1120        Err(error) => return error.to_compile_error().into(),
1121    };
1122    let func = match syn::parse::<ItemFn>(item) {
1123        Ok(func) => func,
1124        Err(error) => return error.to_compile_error().into(),
1125    };
1126    match fixture_impl(args, func) {
1127        Ok(tokens) => tokens.into(),
1128        Err(error) => error.to_compile_error().into(),
1129    }
1130}
1131
1132/// Rewrites a parameterized test into a zero-argument `#[test]` that resolves
1133/// each parameter from its fixture before calling the original body.
1134fn test_with_fixtures_impl(mut func: ItemFn) -> syn::Result<TokenStream2> {
1135    let mut params = Vec::with_capacity(func.sig.inputs.len());
1136    for input in &func.sig.inputs {
1137        match input {
1138            FnArg::Receiver(receiver) => {
1139                return Err(syn::Error::new_spanned(
1140                    receiver,
1141                    "a `#[test_with_fixtures]` function cannot take `self`",
1142                ));
1143            }
1144            FnArg::Typed(pat_type) => match &*pat_type.pat {
1145                Pat::Ident(pat_ident) => params.push(pat_ident.ident.clone()),
1146                other => {
1147                    return Err(syn::Error::new_spanned(
1148                        other,
1149                        "each `#[test_with_fixtures]` parameter must be a plain \
1150                         `name: Type`, where `name` is the fixture function",
1151                    ));
1152                }
1153            },
1154        }
1155    }
1156
1157    // Everything on the function (`#[ignore]`, doc comments, ...) is forwarded
1158    // onto the generated `#[test]`.
1159    let forwarded: Vec<syn::Attribute> = std::mem::take(&mut func.attrs);
1160    let fn_name = func.sig.ident.clone();
1161    let vis = func.vis.clone();
1162    let ret = func.sig.output.clone();
1163
1164    // The original function, renamed, becomes a nested helper the generated
1165    // test calls once every fixture has been resolved.
1166    let mut inner = func;
1167    inner.sig.ident = format_ident!("__tb_inner");
1168    inner.vis = syn::Visibility::Inherited;
1169
1170    Ok(quote! {
1171        #(#forwarded)*
1172        #[test]
1173        #vis fn #fn_name() #ret {
1174            #inner
1175            #( let #params = #params()?; )*
1176            __tb_inner(#(#params),*)
1177        }
1178    })
1179}
1180
1181/// Turns a test whose parameters are fixtures into a runnable `#[test]`.
1182///
1183/// Each parameter `name: T` is resolved by calling the same-named [`macro@fixture`]
1184/// function `fn name() -> TestResult<T>` and `?`-propagating its result, so a
1185/// fixture failure aborts the test as an `ErrorKind::Setup` error before the
1186/// body runs. The parameters are resolved left to right.
1187///
1188/// Because the resolved fixtures are `?`-propagated, the test must return a
1189/// type that `?` accepts, the usual `-> TestResult` shape.
1190///
1191/// ```ignore
1192/// use test_better::prelude::*;
1193///
1194/// #[fixture]
1195/// fn name() -> TestResult<String> {
1196///     Ok(String::from("alice"))
1197/// }
1198///
1199/// #[test_with_fixtures]
1200/// fn greets_by_name(name: String) -> TestResult {
1201///     expect!(name.as_str()).to(eq("alice"))
1202/// }
1203/// ```
1204#[proc_macro_attribute]
1205pub fn test_with_fixtures(_attr: TokenStream, item: TokenStream) -> TokenStream {
1206    let func = match syn::parse::<ItemFn>(item) {
1207        Ok(func) => func,
1208        Err(error) => return error.to_compile_error().into(),
1209    };
1210    match test_with_fixtures_impl(func) {
1211        Ok(tokens) => tokens.into(),
1212        Err(error) => error.to_compile_error().into(),
1213    }
1214}