test_that_macro/lib.rs
1// Copyright 2022 Google LLC
2// Copyright 2026 Bradford Hovinen <bradford@hovinen.me>
3//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8// http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16use quote::quote;
17use syn::{Attribute, ItemFn, ReturnType, parse_macro_input};
18
19/// Marks a test which may have non fatal assertions.
20///
21/// Annotate tests the same way ordinary Rust tests are annotated:
22///
23/// ```ignore
24/// #[test_that::test]
25/// fn should_work() {
26/// ...
27/// }
28/// ```
29///
30/// The test function is not required to have a return type. If it does have a
31/// return type, that type must be [`test_that::Result`]. One may do this if
32/// one wishes to use both fatal and non-fatal assertions in the same test. For
33/// example:
34///
35/// ```
36/// # use test_that::prelude::*;
37/// #[test_that::test]
38/// fn should_work() -> TestResult<()> {
39/// let value = 2;
40/// expect_that!(value, gt(0));
41/// verify_that!(value, eq(2))
42/// }
43/// ```
44///
45/// This macro can be used with `#[should_panic]` to indicate that the test is
46/// expected to panic. For example:
47///
48/// ```
49/// # use test_that::prelude::*;
50/// #[test_that::test]
51/// #[should_panic]
52/// fn passes_due_to_should_panic() {
53/// let value = 2;
54/// expect_that!(value, gt(0));
55/// panic!("This panics");
56/// }
57/// ```
58///
59/// Using `#[should_panic]` modifies the behaviour of `#[test_that::test]` so
60/// that the test panics (and passes) if any non-fatal assertion occurs.
61/// For example, the following test passes:
62///
63/// ```
64/// # use test_that::prelude::*;
65/// #[test_that::test]
66/// #[should_panic]
67/// fn passes_due_to_should_panic_and_failing_assertion() {
68/// let value = 2;
69/// expect_that!(value, eq(0));
70/// }
71/// ```
72///
73/// This integrates with other common test attribute macros such as
74/// [`tokio::test`] and [`rstest`]. Just apply both attribute macros to your
75/// test.
76///
77/// ```ignore
78/// #[test_that::test]
79/// #[rstest]
80/// #[case(1)]
81/// #[case(2)]
82/// #[case(3)]
83/// fn rstest_works_with_test_that(#[case] value: u32) -> Result<()> {
84/// verify_that!(value, gt(0))
85/// }
86///
87/// #[test_that::test]
88/// #[tokio::test]
89/// async fn tokio_works_with_test_that() -> Result<()> {
90/// verify_that!(get_some_value_async().await, gt(0))
91/// }
92/// ```
93///
94/// > **Note:**
95/// > In the case of rstest, make sure to put `#[test_that::test]` *before*
96/// > `#[rstest]`. Otherwise the annotated test will run twice, since both
97/// > macros will
98/// > attempt to register a test with the Rust test harness.
99///
100/// [`test_that::Result`]: type.Result.html
101/// [`tokio::test`]: https://docs.rs/tokio/latest/tokio/attr.test.html
102/// [`rstest`]: https://docs.rs/rstest/latest/rstest/attr.rstest.html
103#[proc_macro_attribute]
104pub fn test(
105 _args: proc_macro::TokenStream,
106 input: proc_macro::TokenStream,
107) -> proc_macro::TokenStream {
108 let mut parsed_fn = parse_macro_input!(input as ItemFn);
109 let attrs = parsed_fn.attrs.drain(..).collect::<Vec<_>>();
110 let (mut sig, block) = (parsed_fn.sig, parsed_fn.block);
111 let (outer_return_type, trailer) =
112 if attrs.iter().any(|attr| attr.path().is_ident("should_panic")) {
113 (quote! { () }, quote! { .unwrap(); })
114 } else {
115 (
116 quote! { std::result::Result<(), test_that::internal::test_outcome::TestFailure> },
117 quote! {},
118 )
119 };
120 let output_type = match sig.output.clone() {
121 ReturnType::Type(_, output_type) => Some(output_type),
122 ReturnType::Default => None,
123 };
124 sig.output = ReturnType::Default;
125 let (maybe_closure, invocation) = if sig.asyncness.is_some() {
126 (
127 // In the async case, the ? operator returns from the *block* rather than the
128 // surrounding function. So we just put the test content in an async block. Async
129 // closures are still unstable (see https://github.com/rust-lang/rust/issues/62290),
130 // so we can't use the same solution as the sync case below.
131 quote! {},
132 quote! {
133 async { #block }.await
134 },
135 )
136 } else {
137 (
138 // In the sync case, the ? operator returns from the surrounding function. So we must
139 // create a separate closure from which the ? operator can return in order to capture
140 // the output.
141 quote! {
142 let test = move || #block;
143 },
144 quote! {
145 test()
146 },
147 )
148 };
149 let function = if let Some(output_type) = output_type {
150 quote! {
151 #(#attrs)*
152 #sig -> #outer_return_type {
153 #maybe_closure
154 test_that::internal::test_outcome::TestOutcome::init_current_test_outcome();
155 let result: #output_type = #invocation;
156 test_that::internal::test_outcome::TestOutcome::close_current_test_outcome(result)
157 #trailer
158 }
159 }
160 } else {
161 quote! {
162 #(#attrs)*
163 #sig -> #outer_return_type {
164 #maybe_closure
165 test_that::internal::test_outcome::TestOutcome::init_current_test_outcome();
166 #invocation;
167 test_that::internal::test_outcome::TestOutcome::close_current_test_outcome(test_that::TestResult::Ok(()))
168 #trailer
169 }
170 }
171 };
172 let output = if attrs.iter().any(is_test_attribute) {
173 function
174 } else {
175 quote! {
176 #[::core::prelude::v1::test]
177 #function
178 }
179 };
180 output.into()
181}
182
183fn is_test_attribute(attr: &Attribute) -> bool {
184 let first_segment = match attr.path().segments.first() {
185 Some(first_segment) => first_segment,
186 None => return false,
187 };
188 let last_segment = match attr.path().segments.last() {
189 Some(last_segment) => last_segment,
190 None => return false,
191 };
192 last_segment.ident == "test"
193 || (first_segment.ident == "rstest"
194 && last_segment.ident == "rstest"
195 && attr.path().segments.len() <= 2)
196}