test_env_helpers/
lib.rs

1//
2//! Some setup and teardown macro helpers to mimic [Jest's setup and teardown](https://jestjs.io/docs/setup-teardown)
3//! functionality. Also includes a `skip` macro that mimics the [skip](https://jestjs.io/docs/api#testskipname-fn)
4//! functionality in Jest.
5//!
6//! There are currently five macros provided: `after_all`,
7//! `after_each`, `before_all`, `before_each`, and `skip`. I would like to implement `only` to
8//! match [Jest's only](https://jestjs.io/docs/api#testonlyname-fn-timeout) functionality. I'm
9//! unsure of a great way to do that currently, however.
10//!
11//! ## Getting Started
//! Using these macros is fairly simple. Each of the four after/before attribute macros requires a
//! function with the same name as the attribute inside the module it annotates, and they are only
//! valid when applied to a module. They are all used as in the example below; replace
//! `before_each` with whichever macro you want to use. The code in the matching function is
//! inserted into every fn in the annotated module that has an attribute containing the word
//! "test". This allows use with not just plain `#[test]` attributes, but also other flavors like
//! `#[tokio::test]` and `#[test_case(0)]`.
18//! ```
19//! #[cfg(test)]
20//! use test_env_helpers::*;
21//!
22//! #[before_each]
23//! #[cfg(test)]
24//! mod my_tests{
25//!     fn before_each(){println!("I'm in every test!")}
26//!     #[test]
27//!     fn test_1(){}
28//!     #[test]
29//!     fn test_2(){}
30//!     #[test]
31//!     fn test_3(){}
32//! }
33//! ```
34//!
35//! The `skip` macro is valid on either a mod or an individual test and will remove the mod or test
36//! it is applied to. You can use it to skip tests that aren't working correctly or that you don't
37//! want to run for some reason.
38//!
39//! ```
40//! #[cfg(test)]
41//! use test_env_helpers::*;
42//!
43//! #[cfg(test)]
44//! mod my_tests{
45//!     #[skip]
46//!     #[test]
47//!     fn broken_test(){panic!("I'm hella broke")}
48//!     #[skip]
49//!     mod broken_mod{
50//!         #[test]
51//!         fn i_will_not_be_run(){panic!("I get skipped too")}
52//!     }
53//!     #[test]
54//!     fn test_2(){}
55//!     #[test]
56//!     fn test_3(){}
57//! }
58//! ```
59
60extern crate proc_macro;
61mod utils;
62
63use crate::utils::traverse_use_item;
64
65use proc_macro::TokenStream;
66use quote::quote;
67use syn::{parse_macro_input, parse_quote, Item, Stmt};
68
69/// Will run the code in the matching `after_all` function exactly once when all of the tests have
70/// run. This works by counting the number of `#[test]` attributes and decrementing a counter at
71/// the beginning of every test. Once the counter reaches 0, it will run the code in `after_all`.
72/// It uses [std::sync::Once](https://doc.rust-lang.org/std/sync/struct.Once.html) internally
73/// to ensure that the code is run at maximum one time.
74///
75/// ```
76/// #[cfg(test)]
77/// use test_env_helpers::*;
78///
79/// #[after_all]
80/// #[cfg(test)]
81/// mod my_tests{
82///     fn after_all(){println!("I only get run once at the very end")}
83///     #[test]
84///     fn test_1(){}
85///     #[test]
86///     fn test_2(){}
87///     #[test]
88///     fn test_3(){}
89/// }
90/// ```
91#[proc_macro_attribute]
92pub fn after_all(_metadata: TokenStream, input: TokenStream) -> TokenStream {
93    let input: Item = match parse_macro_input!(input as Item) {
94        Item::Mod(mut m) => {
95            let (brace, items) = m.content.unwrap();
96            let (after_all_fn, everything_else): (Vec<Item>, Vec<Item>) =
97                items.into_iter().partition(|t| match t {
98                    Item::Fn(f) => f.sig.ident == "after_all",
99                    _ => false,
100                });
101            let after_all_fn_block = if after_all_fn.len() != 1 {
102                panic!("The `after_all` macro attribute requires a single function named `after_all` in the body of the module it is called on.")
103            } else {
104                match after_all_fn.into_iter().next().unwrap() {
105                    Item::Fn(f) => f.block,
106                    _ => unreachable!(),
107                }
108            };
109            let inc_count: Stmt = parse_quote!(
110                REMAINING_TESTS.fetch_sub(1, Ordering::SeqCst);
111            );
112            let after_all_if: Stmt = parse_quote! {
113                if REMAINING_TESTS.load(Ordering::SeqCst) <= 0{
114                    AFTER_ALL.call_once(|| {
115                        #after_all_fn_block
116                    });
117                }
118            };
119
120            let mut count: usize = 0;
121            let mut has_once: bool = false;
122            let mut has_atomic_usize: bool = false;
123            let mut has_ordering: bool = false;
124
125            let mut e: Vec<Item> = everything_else
126                .into_iter()
127                .map(|t| match t {
128                    Item::Fn(mut f) => {
129                        let test_count = f
130                            .attrs
131                            .iter()
132                            .filter(|attr| {
133                                attr.path
134                                    .segments
135                                    .iter()
136                                    .any(|segment| segment.ident.to_string().contains("test"))
137                            })
138                            .count();
139                        if test_count > 0 {
140                            count += test_count;
141                            let mut stmts = vec![inc_count.clone()];
142                            stmts.append(&mut f.block.stmts);
143                            stmts.push(after_all_if.clone());
144                            f.block.stmts = stmts;
145                            Item::Fn(f)
146                        } else {
147                            Item::Fn(f)
148                        }
149                    }
150                    Item::Use(use_stmt) => {
151                        if traverse_use_item(&use_stmt.tree, vec!["std", "sync", "Once"]).is_some()
152                        {
153                            has_once = true;
154                        }
155                        if traverse_use_item(
156                            &use_stmt.tree,
157                            vec!["std", "sync", "atomic", "AtomicUsize"],
158                        )
159                        .is_some()
160                        {
161                            has_atomic_usize = true;
162                        }
163                        if traverse_use_item(
164                            &use_stmt.tree,
165                            vec!["std", "sync", "atomic", "Ordering"],
166                        )
167                        .is_some()
168                        {
169                            has_ordering = true;
170                        }
171                        Item::Use(use_stmt)
172                    }
173                    el => el,
174                })
175                .collect();
176
177            let use_once: Item = parse_quote!(
178                use std::sync::Once;
179            );
180            let use_atomic_usize: Item = parse_quote!(
181                use std::sync::atomic::AtomicUsize;
182            );
183            let use_ordering: Item = parse_quote!(
184                use std::sync::atomic::Ordering;
185            );
186            let static_once: Item = parse_quote!(
187                static AFTER_ALL: Once = Once::new();
188            );
189            let static_count: Item = parse_quote!(
190                static REMAINING_TESTS: AtomicUsize = AtomicUsize::new(#count);
191            );
192
193            let mut once_content = vec![];
194
195            if !has_once {
196                once_content.push(use_once);
197            }
198            if !has_atomic_usize {
199                once_content.push(use_atomic_usize);
200            }
201            if !has_ordering {
202                once_content.push(use_ordering);
203            }
204            once_content.append(&mut vec![static_once, static_count]);
205            once_content.append(&mut e);
206
207            m.content = Some((brace, once_content));
208            Item::Mod(m)
209        }
210        _ => {
211            panic!("The `after_all` macro attribute is only valid when called on a module.")
212        }
213    };
214    TokenStream::from(quote! (#input))
215}
216
217/// Will run the code in the matching `after_each` function at the end of every `#[test]` function.
218/// Useful if you want to cleanup after a test or reset some external state. If the test panics,
219/// this code will not be run. If you need something that is infallible, you should use
220/// `before_each` instead.
221/// ```
222/// #[cfg(test)]
223/// use test_env_helpers::*;
224///
225/// #[after_each]
226/// #[cfg(test)]
227/// mod my_tests{
228///     fn after_each(){println!("I get run at the very end of each function")}
229///     #[test]
230///     fn test_1(){}
231///     #[test]
232///     fn test_2(){}
233///     #[test]
234///     fn test_3(){}
235/// }
236/// ```
237#[proc_macro_attribute]
238pub fn after_each(_metadata: TokenStream, input: TokenStream) -> TokenStream {
239    let input: Item = match parse_macro_input!(input as Item) {
240        Item::Mod(mut m) => {
241            let (brace, items) = m.content.unwrap();
242            let (after_each_fn, everything_else): (Vec<Item>, Vec<Item>) =
243                items.into_iter().partition(|t| match t {
244                    Item::Fn(f) => f.sig.ident == "after_each",
245                    _ => false,
246                });
247            let after_each_fn_block = if after_each_fn.len() != 1 {
248                panic!("The `after_each` macro attribute requires a single function named `after_each` in the body of the module it is called on.")
249            } else {
250                match after_each_fn.into_iter().next().unwrap() {
251                    Item::Fn(f) => f.block,
252                    _ => unreachable!(),
253                }
254            };
255
256            let e: Vec<Item> = everything_else
257                .into_iter()
258                .map(|t| match t {
259                    Item::Fn(mut f) => {
260                        if f.attrs.iter().any(|attr| {
261                            attr.path
262                                .segments
263                                .iter()
264                                .any(|segment| segment.ident.to_string().contains("test"))
265                        }) {
266                            f.block.stmts.append(&mut after_each_fn_block.stmts.clone());
267                            Item::Fn(f)
268                        } else {
269                            Item::Fn(f)
270                        }
271                    }
272                    e => e,
273                })
274                .collect();
275            m.content = Some((brace, e));
276            Item::Mod(m)
277        }
278
279        _ => {
280            panic!("The `after_each` macro attribute is only valid when called on a module.")
281        }
282    };
283    TokenStream::from(quote! {#input})
284}
285
286/// Will run the code in the matching `before_all` function exactly once at the very beginning of a
287/// test run. It uses [std::sync::Once](https://doc.rust-lang.org/std/sync/struct.Once.html) internally
288/// to ensure that the code is run at maximum one time. Useful for setting up some external state
289/// that will be reused in multiple tests.
290/// ```
291/// #[cfg(test)]
292/// use test_env_helpers::*;
293///
294/// #[before_all]
295/// #[cfg(test)]
296/// mod my_tests{
297///     fn before_all(){println!("I get run at the very beginning of the test suite")}
298///     #[test]
299///     fn test_1(){}
300///     #[test]
301///     fn test_2(){}
302///     #[test]
303///     fn test_3(){}
304/// }
305/// ```
306#[proc_macro_attribute]
307pub fn before_all(_metadata: TokenStream, input: TokenStream) -> TokenStream {
308    let input: Item = match parse_macro_input!(input as Item) {
309        Item::Mod(mut m) => {
310            let (brace, items) = m.content.unwrap();
311            let (before_all_fn, everything_else): (Vec<Item>, Vec<Item>) =
312                items.into_iter().partition(|t| match t {
313                    Item::Fn(f) => f.sig.ident == "before_all",
314                    _ => false,
315                });
316            let before_all_fn_block = if before_all_fn.len() != 1 {
317                panic!("The `before_all` macro attribute requires a single function named `before_all` in the body of the module it is called on.")
318            } else {
319                match before_all_fn.into_iter().next().unwrap() {
320                    Item::Fn(f) => f.block,
321                    _ => unreachable!(),
322                }
323            };
324            let q: Stmt = parse_quote! {
325                BEFORE_ALL.call_once(|| {
326                    #before_all_fn_block
327                });
328            };
329
330            let mut has_once: bool = false;
331            let mut e: Vec<Item> = everything_else
332                .into_iter()
333                .map(|t| match t {
334                    Item::Fn(mut f) => {
335                        if f.attrs.iter().any(|attr| {
336                            attr.path
337                                .segments
338                                .iter()
339                                .any(|segment| segment.ident.to_string().contains("test"))
340                        }) {
341                            let mut stmts = vec![q.clone()];
342                            stmts.append(&mut f.block.stmts);
343                            f.block.stmts = stmts;
344                            Item::Fn(f)
345                        } else {
346                            Item::Fn(f)
347                        }
348                    }
349                    Item::Use(use_stmt) => {
350                        if traverse_use_item(&use_stmt.tree, vec!["std", "sync", "Once"]).is_some()
351                        {
352                            has_once = true;
353                        }
354                        Item::Use(use_stmt)
355                    }
356                    e => e,
357                })
358                .collect();
359            let use_once: Item = parse_quote!(
360                use std::sync::Once;
361            );
362            let static_once: Item = parse_quote!(
363                static BEFORE_ALL: Once = Once::new();
364            );
365
366            let mut once_content = vec![];
367            if !has_once {
368                once_content.push(use_once);
369            }
370            once_content.push(static_once);
371            once_content.append(&mut e);
372
373            m.content = Some((brace, once_content));
374            Item::Mod(m)
375        }
376
377        _ => {
378            panic!("The `before_all` macro attribute is only valid when called on a module.")
379        }
380    };
381    TokenStream::from(quote! (#input))
382}
383
384/// Will run the code in the matching `before_each` function at the beginning of every test. Useful
385/// to reset state to ensure that a test has a clean slate.
386/// ```
387/// #[cfg(test)]
388/// use test_env_helpers::*;
389///
390/// #[before_each]
391/// #[cfg(test)]
392/// mod my_tests{
393///     fn before_each(){println!("I get run at the very beginning of every test")}
394///     #[test]
395///     fn test_1(){}
396///     #[test]
397///     fn test_2(){}
398///     #[test]
399///     fn test_3(){}
400/// }
401/// ```
402///
403/// Can be used to reduce the amount of boilerplate setup code that needs to be copied into each test.
404/// For example, if you need to ensure that tests in a single test suite are not run in parallel, this can
405/// easily be done with a [Mutex](https://doc.rust-lang.org/std/sync/struct.Mutex.html).
406/// However, remembering to copy and paste the code to acquire a lock on the `Mutex` in every test
407/// is tedious and error prone.
408/// ```
409/// #[cfg(test)]
410/// mod without_before_each{
411///     lazy_static! {
412///         static ref MTX: Mutex<()> = Mutex::new(());
413///     }
414///     #[test]
415///     fn test_1(){let _m = MTX.lock();}
416///     #[test]
417///     fn test_2(){let _m = MTX.lock();}
418///     #[test]
419///     fn test_3(){let _m = MTX.lock();}
420/// }
421/// ```
422/// Using `before_each` removes the need to copy and paste so much and makes making changes easier
423/// because they only need to be made in a single location instead of once for every test.
424/// ```
425/// #[cfg(test)]
426/// use test_env_helpers::*;
427///
428/// #[before_each]
429/// #[cfg(test)]
430/// mod with_before_each{
431///     lazy_static! {
432///         static ref MTX: Mutex<()> = Mutex::new(());
433///     }
434///     fn before_each(){let _m = MTX.lock();}
435///     #[test]
436///     fn test_1(){}
437///     #[test]
438///     fn test_2(){}
439///     #[test]
440///     fn test_3(){}
441/// }
442/// ```
443#[proc_macro_attribute]
444pub fn before_each(_metadata: TokenStream, input: TokenStream) -> TokenStream {
445    let input: Item = match parse_macro_input!(input as Item) {
446        Item::Mod(mut m) => {
447            let (brace, items) = m.content.unwrap();
448            let (before_each_fn, everything_else): (Vec<Item>, Vec<Item>) =
449                items.into_iter().partition(|t| match t {
450                    Item::Fn(f) => f.sig.ident == "before_each",
451                    _ => false,
452                });
453            let before_each_fn_block = if before_each_fn.len() != 1 {
454                panic!("The `before_each` macro attribute requires a single function named `before_each` in the body of the module it is called on.")
455            } else {
456                match before_each_fn.into_iter().next().unwrap() {
457                    Item::Fn(f) => f.block,
458                    _ => unreachable!(),
459                }
460            };
461
462            let e: Vec<Item> = everything_else
463                .into_iter()
464                .map(|t| match t {
465                    Item::Fn(mut f) => {
466                        if f.attrs.iter().any(|attr| {
467                            attr.path
468                                .segments
469                                .iter()
470                                .any(|segment| segment.ident.to_string().contains("test"))
471                        }) {
472                            let mut b = before_each_fn_block.stmts.clone();
473                            b.append(&mut f.block.stmts);
474                            f.block.stmts = b;
475                            Item::Fn(f)
476                        } else {
477                            Item::Fn(f)
478                        }
479                    }
480                    e => e,
481                })
482                .collect();
483            m.content = Some((brace, e));
484            Item::Mod(m)
485        }
486
487        _ => {
488            panic!("The `before_each` macro attribute is only valid when called on a module.")
489        }
490    };
491    TokenStream::from(quote! {#input})
492}
493
494/// Will skip running the code it is applied on. You can use it to skip tests that aren't working
495/// correctly or that you don't want to run for some reason. There are no checks to make sure it's
496/// applied to a `#[test]` or mod. It will remove whatever it is applied to from the final AST.
497///
498/// ```
499/// #[cfg(test)]
500/// use test_env_helpers::*;
501///
502/// #[cfg(test)]
503/// mod my_tests{
504///     #[skip]
505///     #[test]
506///     fn broken_test(){panic!("I'm hella broke")}
507///     #[skip]
508///     mod broken_mod{
509///         #[test]
510///         fn i_will_not_be_run(){panic!("I get skipped too")}
511///     }
512///     #[test]
513///     fn test_2(){}
514///     #[test]
515///     fn test_3(){}
516/// }
517/// ```
518#[proc_macro_attribute]
519pub fn skip(_metadata: TokenStream, _input: TokenStream) -> TokenStream {
520    TokenStream::from(quote! {})
521}