test_helpers/
lib.rs

1//
2//! Some setup and teardown macro helpers to mimic [Jest's setup and teardown](https://jestjs.io/docs/setup-teardown)
3//! functionality. Also includes a `skip` macro that mimics the [skip](https://jestjs.io/docs/api#testskipname-fn)
4//! functionality in Jest.
5//!
6//! There are currently five macros provided: `after_all`,
7//! `after_each`, `before_all`, `before_each`, and `skip`. I would like to implement `only` to
8//! match [Jest's only](https://jestjs.io/docs/api#testonlyname-fn-timeout) functionality. I'm
9//! unsure of a great way to do that currently, however.
10//!
11//! ## Getting Started
12//! Using these macros is fairly simple. The four after/before functions all require a function
13//! with the same name as the attribute and are only valid when applied to a mod. They are all used
14//! like in the below example. Replace `before_each` with whichever method you want to use. The
15//! code in the matching function will be inserted into every fn in the containing mod that has an
16//! attribute with the word "test" in it. This is to allow for use with not just normal `#[test]`
17//! attributes, but also other flavors like `#[tokio::test]` and `#[test_case(0)]`.
18//! ```
19//! #[cfg(test)]
20//! use test_env_helpers::*;
21//!
22//! #[before_each]
23//! #[cfg(test)]
24//! mod my_tests{
25//!     fn before_each(){println!("I'm in every test!")}
26//!     #[test]
27//!     fn test_1(){}
28//!     #[test]
29//!     fn test_2(){}
30//!     #[test]
31//!     fn test_3(){}
32//! }
33//! ```
34//!
35//! The `skip` macro is valid on either a mod or an individual test and will remove the mod or test
36//! it is applied to. You can use it to skip tests that aren't working correctly or that you don't
37//! want to run for some reason.
38//!
39//! ```
40//! #[cfg(test)]
41//! use test_env_helpers::*;
42//!
43//! #[cfg(test)]
44//! mod my_tests{
45//!     #[skip]
46//!     #[test]
47//!     fn broken_test(){panic!("I'm hella broke")}
48//!     #[skip]
49//!     mod broken_mod{
50//!         #[test]
51//!         fn i_will_not_be_run(){panic!("I get skipped too")}
52//!     }
53//!     #[test]
54//!     fn test_2(){}
55//!     #[test]
56//!     fn test_3(){}
57//! }
58//! ```
59
60extern crate proc_macro;
61mod utils;
62
63use crate::utils::traverse_use_item;
64
65use proc_macro::TokenStream;
66use quote::quote;
67use syn::Stmt;
68use syn::Item;
69use syn::parse_quote;
70use syn::parse_macro_input;
71
72/// Will run the code in the matching `after_all` function exactly once when all of the tests have
73/// run. This works by counting the number of `#[test]` attributes and decrementing a counter at
74/// the beginning of every test. Once the counter reaches 0, it will run the code in `after_all`.
75/// It uses [std::sync::Once](https://doc.rust-lang.org/std/sync/struct.Once.html) internally
76/// to ensure that the code is run at maximum one time.
77///
78/// ```
79/// #[cfg(test)]
80/// use test_env_helpers::*;
81///
82/// #[after_all]
83/// #[cfg(test)]
84/// mod my_tests{
85///     fn after_all(){println!("I only get run once at the very end")}
86///     #[test]
87///     fn test_1(){}
88///     #[test]
89///     fn test_2(){}
90///     #[test]
91///     fn test_3(){}
92/// }
93/// ```
94#[proc_macro_attribute]
95pub fn after_all(_metadata: TokenStream, input: TokenStream) -> TokenStream {
96    let input: Item = match parse_macro_input!(input as Item) {
97        Item::Mod(mut m) => {
98            let (brace, items) = m.content.unwrap();
99            let (after_all_fn, everything_else): (Vec<Item>, Vec<Item>) =
100                items.into_iter().partition(|t| match t {
101                    Item::Fn(f) => f.sig.ident == "after_all",
102                    _ => false,
103                });
104            let after_all_fn_block = if after_all_fn.len() != 1 {
105                panic!("The `after_all` macro attribute requires a single function named `after_all` in the body of the module it is called on.")
106            } else {
107                match after_all_fn.into_iter().next().unwrap() {
108                    Item::Fn(f) => f.block,
109                    _ => unreachable!(),
110                }
111            };
112            let after_all_if: Stmt = parse_quote! {
113                if REMAINING_TESTS.fetch_sub(1, Ordering::SeqCst) == 1 {
114                    AFTER_ALL.call_once(|| {
115                        #after_all_fn_block
116                    });
117                }
118            };
119            let resume: Stmt = parse_quote! {
120                if let Err(err) = result {
121                    panic::resume_unwind(err);
122                }
123            };
124            let mut count: usize = 0;
125            let mut has_once: bool = false;
126            let mut has_atomic_usize: bool = false;
127            let mut has_ordering: bool = false;
128            let mut has_panic: bool = false;
129
130            let mut e: Vec<Item> = everything_else
131                .into_iter()
132                .map(|t| match t {
133                    Item::Fn(mut f) => {
134                        let test_count = f
135                            .attrs
136                            .iter()
137                            .filter(|attr| {
138                                attr.path
139                                    .segments
140                                    .iter()
141                                    .any(|segment| segment.ident.to_string().contains("test"))
142                            })
143                            .count();
144                        if test_count > 0 {
145                            count += test_count;
146                            let block = f.block.clone();
147                            let catch_unwind: Stmt = parse_quote! {
148                                let result = panic::catch_unwind(|| {
149                                    #block
150                                });
151                            };
152                            f.block.stmts = vec![catch_unwind, after_all_if.clone(), resume.clone()];
153                            Item::Fn(f)
154                        } else {
155                            Item::Fn(f)
156                        }
157                    }
158                    Item::Use(use_stmt) => {
159                        if traverse_use_item(&use_stmt.tree, vec!["std", "sync", "Once"]).is_some()
160                        {
161                            has_once = true;
162                        }
163                        if traverse_use_item(
164                            &use_stmt.tree,
165                            vec!["std", "sync", "atomic", "AtomicUsize"],
166                        )
167                        .is_some()
168                        {
169                            has_atomic_usize = true;
170                        }
171                        if traverse_use_item(
172                            &use_stmt.tree,
173                            vec!["std", "sync", "atomic", "Ordering"],
174                        )
175                        .is_some()
176                        {
177                            has_ordering = true;
178                        }
179                        if traverse_use_item(
180                            &use_stmt.tree,
181                            vec!["std", "panic"],
182                        )
183                        .is_some() {
184                            has_panic = true;
185                        }
186                        Item::Use(use_stmt)
187                    }
188                    el => el,
189                })
190                .collect();
191
192            let use_once: Item = parse_quote!(
193                use std::sync::Once;
194            );
195            let use_atomic_usize: Item = parse_quote!(
196                use std::sync::atomic::AtomicUsize;
197            );
198            let use_ordering: Item = parse_quote!(
199                use std::sync::atomic::Ordering;
200            );
201            let use_panic: Item = parse_quote!(
202                use std::panic;
203            );
204            let static_once: Item = parse_quote!(
205                static AFTER_ALL: Once = Once::new();
206            );
207            let static_count: Item = parse_quote!(
208                static REMAINING_TESTS: AtomicUsize = AtomicUsize::new(#count);
209            );
210
211            let mut once_content = vec![];
212
213            if !has_once {
214                once_content.push(use_once);
215            }
216            if !has_atomic_usize {
217                once_content.push(use_atomic_usize);
218            }
219            if !has_ordering {
220                once_content.push(use_ordering);
221            }
222            if !has_panic {
223                once_content.push(use_panic);
224            }
225            once_content.append(&mut vec![static_once, static_count]);
226            once_content.append(&mut e);
227
228            m.content = Some((brace, once_content));
229            Item::Mod(m)
230        }
231        _ => {
232            panic!("The `after_all` macro attribute is only valid when called on a module.")
233        }
234    };
235    TokenStream::from(quote! (#input))
236}
237
238/// Will run the code in the matching `after_each` function at the end of every `#[test]` function.
239/// Useful if you want to cleanup after a test or reset some external state. If the test panics,
240/// this code will not be run. If you need something that is infallible, you should use
241/// `before_each` instead.
242/// ```
243/// #[cfg(test)]
244/// use test_env_helpers::*;
245///
246/// #[after_each]
247/// #[cfg(test)]
248/// mod my_tests{
249///     fn after_each(){println!("I get run at the very end of each function")}
250///     #[test]
251///     fn test_1(){}
252///     #[test]
253///     fn test_2(){}
254///     #[test]
255///     fn test_3(){}
256/// }
257/// ```
258#[proc_macro_attribute]
259pub fn after_each(_metadata: TokenStream, input: TokenStream) -> TokenStream {
260    let input: Item = match parse_macro_input!(input as Item) {
261        Item::Mod(mut m) => {
262            let (brace, items) = m.content.unwrap();
263            let (after_each_fn, everything_else): (Vec<Item>, Vec<Item>) =
264                items.into_iter().partition(|t| match t {
265                    Item::Fn(f) => f.sig.ident == "after_each",
266                    _ => false,
267                });
268            let after_each_fn_block = if after_each_fn.len() != 1 {
269                panic!("The `after_each` macro attribute requires a single function named `after_each` in the body of the module it is called on.")
270            } else {
271                match after_each_fn.into_iter().next().unwrap() {
272                    Item::Fn(f) => f.block,
273                    _ => unreachable!(),
274                }
275            };
276
277            let e: Vec<Item> = everything_else
278                .into_iter()
279                .map(|t| match t {
280                    Item::Fn(mut f) => {
281                        if f.attrs.iter().any(|attr| {
282                            attr.path
283                                .segments
284                                .iter()
285                                .any(|segment| segment.ident.to_string().contains("test"))
286                        }) {
287                            f.block.stmts.append(&mut after_each_fn_block.stmts.clone());
288                            Item::Fn(f)
289                        } else {
290                            Item::Fn(f)
291                        }
292                    }
293                    e => e,
294                })
295                .collect();
296            m.content = Some((brace, e));
297            Item::Mod(m)
298        }
299
300        _ => {
301            panic!("The `after_each` macro attribute is only valid when called on a module.")
302        }
303    };
304    TokenStream::from(quote! {#input})
305}
306
307/// Will run the code in the matching `before_all` function exactly once at the very beginning of a
308/// test run. It uses [std::sync::Once](https://doc.rust-lang.org/std/sync/struct.Once.html) internally
309/// to ensure that the code is run at maximum one time. Useful for setting up some external state
310/// that will be reused in multiple tests.
311/// ```
312/// #[cfg(test)]
313/// use test_env_helpers::*;
314///
315/// #[before_all]
316/// #[cfg(test)]
317/// mod my_tests{
318///     fn before_all(){println!("I get run at the very beginning of the test suite")}
319///     #[test]
320///     fn test_1(){}
321///     #[test]
322///     fn test_2(){}
323///     #[test]
324///     fn test_3(){}
325/// }
326/// ```
327#[proc_macro_attribute]
328pub fn before_all(_metadata: TokenStream, input: TokenStream) -> TokenStream {
329    let input: Item = match parse_macro_input!(input as Item) {
330        Item::Mod(mut m) => {
331            let (brace, items) = m.content.unwrap();
332            let (before_all_fn, everything_else): (Vec<Item>, Vec<Item>) =
333                items.into_iter().partition(|t| match t {
334                    Item::Fn(f) => f.sig.ident == "before_all",
335                    _ => false,
336                });
337            let before_all_fn_block = if before_all_fn.len() != 1 {
338                panic!("The `before_all` macro attribute requires a single function named `before_all` in the body of the module it is called on.")
339            } else {
340                match before_all_fn.into_iter().next().unwrap() {
341                    Item::Fn(f) => f.block,
342                    _ => unreachable!(),
343                }
344            };
345            let q: Stmt = parse_quote! {
346                BEFORE_ALL.call_once(|| {
347                    #before_all_fn_block
348                });
349            };
350
351            let mut has_once: bool = false;
352            let mut e: Vec<Item> = everything_else
353                .into_iter()
354                .map(|t| match t {
355                    Item::Fn(mut f) => {
356                        if f.attrs.iter().any(|attr| {
357                            attr.path
358                                .segments
359                                .iter()
360                                .any(|segment| segment.ident.to_string().contains("test"))
361                        }) {
362                            let mut stmts = vec![q.clone()];
363                            stmts.append(&mut f.block.stmts);
364                            f.block.stmts = stmts;
365                            Item::Fn(f)
366                        } else {
367                            Item::Fn(f)
368                        }
369                    }
370                    Item::Use(use_stmt) => {
371                        if traverse_use_item(&use_stmt.tree, vec!["std", "sync", "Once"]).is_some()
372                        {
373                            has_once = true;
374                        }
375                        Item::Use(use_stmt)
376                    }
377                    e => e,
378                })
379                .collect();
380            let use_once: Item = parse_quote!(
381                use std::sync::Once;
382            );
383            let static_once: Item = parse_quote!(
384                static BEFORE_ALL: Once = Once::new();
385            );
386
387            let mut once_content = vec![];
388            if !has_once {
389                once_content.push(use_once);
390            }
391            once_content.push(static_once);
392            once_content.append(&mut e);
393
394            m.content = Some((brace, once_content));
395            Item::Mod(m)
396        }
397
398        _ => {
399            panic!("The `before_all` macro attribute is only valid when called on a module.")
400        }
401    };
402    TokenStream::from(quote! (#input))
403}
404
405/// Will run the code in the matching `before_each` function at the beginning of every test. Useful
406/// to reset state to ensure that a test has a clean slate.
407/// ```
408/// #[cfg(test)]
409/// use test_env_helpers::*;
410///
411/// #[before_each]
412/// #[cfg(test)]
413/// mod my_tests{
414///     fn before_each(){println!("I get run at the very beginning of every test")}
415///     #[test]
416///     fn test_1(){}
417///     #[test]
418///     fn test_2(){}
419///     #[test]
420///     fn test_3(){}
421/// }
422/// ```
423///
424/// Can be used to reduce the amount of boilerplate setup code that needs to be copied into each test.
425/// For example, if you need to ensure that tests in a single test suite are not run in parallel, this can
426/// easily be done with a [Mutex](https://doc.rust-lang.org/std/sync/struct.Mutex.html).
427/// However, remembering to copy and paste the code to acquire a lock on the `Mutex` in every test
428/// is tedious and error prone.
429/// ```
430/// #[cfg(test)]
431/// mod without_before_each{
432///     lazy_static! {
433///         static ref MTX: Mutex<()> = Mutex::new(());
434///     }
435///     #[test]
436///     fn test_1(){let _m = MTX.lock();}
437///     #[test]
438///     fn test_2(){let _m = MTX.lock();}
439///     #[test]
440///     fn test_3(){let _m = MTX.lock();}
441/// }
442/// ```
443/// Using `before_each` removes the need to copy and paste so much and makes making changes easier
444/// because they only need to be made in a single location instead of once for every test.
445/// ```
446/// #[cfg(test)]
447/// use test_env_helpers::*;
448///
449/// #[before_each]
450/// #[cfg(test)]
451/// mod with_before_each{
452///     lazy_static! {
453///         static ref MTX: Mutex<()> = Mutex::new(());
454///     }
455///     fn before_each(){let _m = MTX.lock();}
456///     #[test]
457///     fn test_1(){}
458///     #[test]
459///     fn test_2(){}
460///     #[test]
461///     fn test_3(){}
462/// }
463/// ```
464#[proc_macro_attribute]
465pub fn before_each(_metadata: TokenStream, input: TokenStream) -> TokenStream {
466    let input: Item = match parse_macro_input!(input as Item) {
467        Item::Mod(mut m) => {
468            let (brace, items) = m.content.unwrap();
469            let (before_each_fn, everything_else): (Vec<Item>, Vec<Item>) =
470                items.into_iter().partition(|t| match t {
471                    Item::Fn(f) => f.sig.ident == "before_each",
472                    _ => false,
473                });
474            let before_each_fn_block = if before_each_fn.len() != 1 {
475                panic!("The `before_each` macro attribute requires a single function named `before_each` in the body of the module it is called on.")
476            } else {
477                match before_each_fn.into_iter().next().unwrap() {
478                    Item::Fn(f) => f.block,
479                    _ => unreachable!(),
480                }
481            };
482
483            let e: Vec<Item> = everything_else
484                .into_iter()
485                .map(|t| match t {
486                    Item::Fn(mut f) => {
487                        if f.attrs.iter().any(|attr| {
488                            attr.path
489                                .segments
490                                .iter()
491                                .any(|segment| segment.ident.to_string().contains("test"))
492                        }) {
493                            let mut b = before_each_fn_block.stmts.clone();
494                            b.append(&mut f.block.stmts);
495                            f.block.stmts = b;
496                            Item::Fn(f)
497                        } else {
498                            Item::Fn(f)
499                        }
500                    }
501                    e => e,
502                })
503                .collect();
504            m.content = Some((brace, e));
505            Item::Mod(m)
506        }
507
508        _ => {
509            panic!("The `before_each` macro attribute is only valid when called on a module.")
510        }
511    };
512    TokenStream::from(quote! {#input})
513}
514
515/// Will skip running the code it is applied on. You can use it to skip tests that aren't working
516/// correctly or that you don't want to run for some reason. There are no checks to make sure it's
517/// applied to a `#[test]` or mod. It will remove whatever it is applied to from the final AST.
518///
519/// ```
520/// #[cfg(test)]
521/// use test_env_helpers::*;
522///
523/// #[cfg(test)]
524/// mod my_tests{
525///     #[skip]
526///     #[test]
527///     fn broken_test(){panic!("I'm hella broke")}
528///     #[skip]
529///     mod broken_mod{
530///         #[test]
531///         fn i_will_not_be_run(){panic!("I get skipped too")}
532///     }
533///     #[test]
534///     fn test_2(){}
535///     #[test]
536///     fn test_3(){}
537/// }
538/// ```
539#[proc_macro_attribute]
540pub fn skip(_metadata: TokenStream, _input: TokenStream) -> TokenStream {
541    TokenStream::from(quote! {})
542}