Skip to main content

test_that_macro/
lib.rs

1// Copyright 2022 Google LLC
2// Copyright 2026 Bradford Hovinen <bradford@hovinen.me>
3//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8//      http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16use quote::quote;
17use syn::{Attribute, ItemFn, ReturnType, parse_macro_input};
18
19/// Marks a test which may have non fatal assertions.
20///
21/// Annotate tests the same way ordinary Rust tests are annotated:
22///
23/// ```ignore
24/// #[test_that::test]
25/// fn should_work() {
26///    ...
27/// }
28/// ```
29///
30/// The test function is not required to have a return type. If it does have a
31/// return type, that type must be [`test_that::Result`]. One may do this if
32/// one wishes to use both fatal and non-fatal assertions in the same test. For
33/// example:
34///
35/// ```
36/// # use test_that::prelude::*;
37/// #[test_that::test]
38/// fn should_work() -> TestResult<()> {
39///     let value = 2;
40///     expect_that!(value, gt(0));
41///     verify_that!(value, eq(2))
42/// }
43/// ```
44///
45/// This macro can be used with `#[should_panic]` to indicate that the test is
46/// expected to panic. For example:
47///
48/// ```
49/// # use test_that::prelude::*;
50/// #[test_that::test]
51/// #[should_panic]
52/// fn passes_due_to_should_panic() {
53///     let value = 2;
54///     expect_that!(value, gt(0));
55///     panic!("This panics");
56/// }
57/// ```
58///
59/// Using `#[should_panic]` modifies the behaviour of `#[test_that::test]` so
60/// that the test panics (and passes) if any non-fatal assertion occurs.
61/// For example, the following test passes:
62///
63/// ```
64/// # use test_that::prelude::*;
65/// #[test_that::test]
66/// #[should_panic]
67/// fn passes_due_to_should_panic_and_failing_assertion() {
68///     let value = 2;
69///     expect_that!(value, eq(0));
70/// }
71/// ```
72///
73/// This integrates with other common test attribute macros such as
74/// [`tokio::test`] and [`rstest`]. Just apply both attribute macros to your
75/// test.
76///
77/// ```ignore
78/// #[test_that::test]
79/// #[rstest]
80/// #[case(1)]
81/// #[case(2)]
82/// #[case(3)]
83/// fn rstest_works_with_test_that(#[case] value: u32) -> Result<()> {
84///     verify_that!(value, gt(0))
85/// }
86///
87/// #[test_that::test]
88/// #[tokio::test]
89/// async fn tokio_works_with_test_that() -> Result<()> {
90///     verify_that!(get_some_value_async().await, gt(0))
91/// }
92/// ```
93///
94/// > **Note:**
95/// > In the case of rstest, make sure to put `#[test_that::test]` *before*
96/// > `#[rstest]`. Otherwise the annotated test will run twice, since both
97/// > macros will
98/// > attempt to register a test with the Rust test harness.
99///
100/// [`test_that::Result`]: type.Result.html
101/// [`tokio::test`]: https://docs.rs/tokio/latest/tokio/attr.test.html
102/// [`rstest`]: https://docs.rs/rstest/latest/rstest/attr.rstest.html
103#[proc_macro_attribute]
104pub fn test(
105    _args: proc_macro::TokenStream,
106    input: proc_macro::TokenStream,
107) -> proc_macro::TokenStream {
108    let mut parsed_fn = parse_macro_input!(input as ItemFn);
109    let attrs = parsed_fn.attrs.drain(..).collect::<Vec<_>>();
110    let (mut sig, block) = (parsed_fn.sig, parsed_fn.block);
111    let (outer_return_type, trailer) =
112        if attrs.iter().any(|attr| attr.path().is_ident("should_panic")) {
113            (quote! { () }, quote! { .unwrap(); })
114        } else {
115            (
116                quote! { std::result::Result<(), test_that::internal::test_outcome::TestFailure> },
117                quote! {},
118            )
119        };
120    let output_type = match sig.output.clone() {
121        ReturnType::Type(_, output_type) => Some(output_type),
122        ReturnType::Default => None,
123    };
124    sig.output = ReturnType::Default;
125    let (maybe_closure, invocation) = if sig.asyncness.is_some() {
126        (
127            // In the async case, the ? operator returns from the *block* rather than the
128            // surrounding function. So we just put the test content in an async block. Async
129            // closures are still unstable (see https://github.com/rust-lang/rust/issues/62290),
130            // so we can't use the same solution as the sync case below.
131            quote! {},
132            quote! {
133                async { #block }.await
134            },
135        )
136    } else {
137        (
138            // In the sync case, the ? operator returns from the surrounding function. So we must
139            // create a separate closure from which the ? operator can return in order to capture
140            // the output.
141            quote! {
142                let test = move || #block;
143            },
144            quote! {
145                test()
146            },
147        )
148    };
149    let function = if let Some(output_type) = output_type {
150        quote! {
151            #(#attrs)*
152            #sig -> #outer_return_type {
153                #maybe_closure
154                test_that::internal::test_outcome::TestOutcome::init_current_test_outcome();
155                let result: #output_type = #invocation;
156                test_that::internal::test_outcome::TestOutcome::close_current_test_outcome(result)
157                #trailer
158            }
159        }
160    } else {
161        quote! {
162            #(#attrs)*
163            #sig -> #outer_return_type {
164                #maybe_closure
165                test_that::internal::test_outcome::TestOutcome::init_current_test_outcome();
166                #invocation;
167                test_that::internal::test_outcome::TestOutcome::close_current_test_outcome(test_that::TestResult::Ok(()))
168                #trailer
169            }
170        }
171    };
172    let output = if attrs.iter().any(is_test_attribute) {
173        function
174    } else {
175        quote! {
176            #[::core::prelude::v1::test]
177            #function
178        }
179    };
180    output.into()
181}
182
183fn is_test_attribute(attr: &Attribute) -> bool {
184    let first_segment = match attr.path().segments.first() {
185        Some(first_segment) => first_segment,
186        None => return false,
187    };
188    let last_segment = match attr.path().segments.last() {
189        Some(last_segment) => last_segment,
190        None => return false,
191    };
192    last_segment.ident == "test"
193        || (first_segment.ident == "rstest"
194            && last_segment.ident == "rstest"
195            && attr.path().segments.len() <= 2)
196}