1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
//! Allows for the creation of sequential tests.
//! ```ignore
//! #[cfg(test)]
//! mod tests {
//!     #[test]
//!     #[sequential]
//!     fn test1() {
//!         // ...
//!     }
//!     #[test]
//!     #[sequential]
//!     fn test2() {
//!         // ...
//!     }
//!     #[test]
//!     #[parallel]
//!     fn test3() {
//!         // ...
//!     }
//! }
//! ```
//! - Tests with the [`macro@sequential`] attribute are guaranteed to run one at a time, never
//!   alongside any other [`macro@sequential`] or [`macro@parallel`] test.
//! - Tests with the [`macro@parallel`] attribute may run in parallel with each other but will not
//!   run at the same time as tests with the [`macro@sequential`] attribute.
//! - Tests with neither attribute may run in parallel with any test.
//!
//! This library does not support `async` tests.
//!
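//! Defining [`macro@sequential`] or [`macro@parallel`] attributes on non-test items, or on tests
//! spread across different scopes, is considered undefined behavior (UB). The shared
//! synchronization state is only emitted once, next to the first annotated test, so it must be
//! visible from every other annotated test.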
//!
//! This library is both faster[^speed] and smaller than
//! [`serial_test`](https://github.com/palfrey/serial_test), but offers less functionality.
//!
//! [^speed]: The current benchmarks show `sequential-test` covering the test set in an average
//! of ~350ms, while [`serial_test`](https://github.com/palfrey/serial_test) covers the same test
//! set in an average of ~550ms.
#![warn(clippy::pedantic)]
extern crate proc_macro;
use proc_macro::TokenStream;

use std::sync::atomic::{AtomicBool, Ordering};

/// Records whether the shared test state has already been emitted.
///
/// `inner` flips this from `false` to `true` on its first expansion. It prepends the shared
/// `__TestState`/`__PAIR` definitions only in that case, so they are defined exactly once.
///
/// NOTE(review): this assumes the macro's statics live exactly as long as one crate's expansion.
/// A long-lived macro host (e.g. an IDE proc-macro server) may keep this `true` across
/// re-expansions, and the definitions would then never be emitted. Confirm.
static SET: AtomicBool = AtomicBool::new(false);

/// Shared definitions emitted once, alongside the first annotated test (see `SET`).
///
/// - `__TestState::Sequential`: a `#[sequential]` test is currently running.
/// - `__TestState::Parallel(n)`: `n` `#[parallel]` tests are currently running. `Parallel(0)`
///   is both the initial and the idle state.
///
/// `__PAIR` couples the state mutex with the condvar that annotated tests wait on.
///
/// NOTE(review): the expansion names `lazy_static::lazy_static!` directly. The crate using these
/// attributes presumably needs `lazy_static` resolvable by that path. Confirm.
const DEFINE: &str = "
    /// The test state stored in the mutex in the pair used for ordering tests.
    enum __TestState {
        Sequential,
        Parallel(u64)
    }
    /// The mutex and condvar pair used for ordering tests.
    lazy_static::lazy_static! {
        static ref __PAIR: std::sync::Arc<(std::sync::Mutex<__TestState>, std::sync::Condvar)> = 
            std::sync::Arc::new(
                (std::sync::Mutex::new(__TestState::Parallel(0)), std::sync::Condvar::new())
            );
    }
";
/// Code injected at the start of a `#[sequential]` test.
///
/// Waits on the condvar until the state is `Parallel(0)`, meaning no test of either kind is
/// running. It then claims the state by switching it to `Sequential`. The guard returned by
/// `wait_while` is discarded with `let _`, so the mutex is not held while the test body runs.
const SEQ_PREFIX: &str = "
    let _ = __PAIR.1.wait_while(__PAIR.0.lock().unwrap(), |pending| 
        match pending {
            __TestState::Parallel(0) => {
                *pending = __TestState::Sequential;
                false
            },
            _ => true
        }
    ).unwrap();
";
/// Code injected at the start of a `#[parallel]` test.
///
/// Waits while a sequential test holds the state, then increments the count of running
/// parallel tests. The guard is discarded with `let _`, so the lock is released before the test
/// body runs.
const PAR_PREFIX: &str = "
    let _ = __PAIR.1.wait_while(__PAIR.0.lock().unwrap(), |pending|
        match pending {
            __TestState::Sequential => true,
            __TestState::Parallel(ref mut x) => {
                *x += 1;
                false
            }
        }
    ).unwrap();
";
/// Code injected after a `#[parallel]` test body, including a body that panicked (the body runs
/// under `catch_unwind`).
///
/// Decrements the count of running parallel tests. The `Sequential` arm is `unreachable!` for
/// two reasons: `PAR_PREFIX` already incremented the count for this test, and a sequential test
/// only claims the state when the count is zero.
const PAR_SUFFIX: &str = "
    match *__PAIR.0.lock().unwrap() {
        __TestState::Sequential => unreachable!(),
        __TestState::Parallel(ref mut x) => {
            *x -= 1;
        }
    }
";
/// Code injected after a `#[sequential]` test body: returns the state to idle (`Parallel(0)`).
const SEQ_SUFFIX: &str = "*__PAIR.0.lock().unwrap() = __TestState::Parallel(0);";
/// Code injected after the mode-specific suffix of every annotated test.
///
/// Wakes all waiting tests so they re-check the state. It then re-raises any panic captured in
/// `res`, the `catch_unwind` result bound by the generated code, so a failing test still fails.
const SUFFIX: &str = "
    __PAIR.1.notify_all();
    if let Err(err) = res {
        std::panic::resume_unwind(err);
    }
";

/// Marks a test that runs exclusively: it never overlaps with another `#[sequential]` or
/// `#[parallel]` test.
///
/// Arguments passed to the attribute itself are ignored.
#[proc_macro_attribute]
pub fn sequential(_attr: TokenStream, item: TokenStream) -> TokenStream {
    let (enter, exit) = (SEQ_PREFIX, SEQ_SUFFIX);
    inner(item, enter, exit)
}
/// Marks a test that may overlap with other `#[parallel]` tests but never with a
/// `#[sequential]` one.
///
/// Arguments passed to the attribute itself are ignored.
#[proc_macro_attribute]
pub fn parallel(_attr: TokenStream, item: TokenStream) -> TokenStream {
    let (enter, exit) = (PAR_PREFIX, PAR_SUFFIX);
    inner(item, enter, exit)
}
/// Expands a `#[sequential]`/`#[parallel]` test into one that synchronizes on `__PAIR`.
///
/// The annotated item is split into two parts:
///
/// - The *signature*: every token before the final one. This covers outer attributes such as
///   `#[should_panic]`/`#[ignore]`, visibility, `fn`, name, arguments and return type.
/// - The *body*: the final `{ ... }` group.
///
/// The generated function:
///
/// 1. runs `prefix` to wait for and claim the shared state;
/// 2. runs the original body inside `std::panic::catch_unwind`, so the state is released even
///    if the test panics;
/// 3. runs `suffix` to release the state, then `SUFFIX` to wake waiters and re-raise any panic;
/// 4. returns the body's value, so tests returning e.g. `Result<(), E>` keep working.
///
/// The body's tokens are moved into the output unchanged rather than re-parsed from a string.
/// Compiler errors inside a test therefore still point at the user's code.
///
/// The first expansion (tracked by `SET`) also emits the shared `DEFINE` items. If the item does
/// not end in a brace-delimited body, a `compile_error!` is emitted instead of malformed code.
///
/// # Panics
///
/// Panics if one of this crate's own code fragments fails to tokenize. That would indicate a bug
/// in this crate, not in the user's code.
fn inner(item: TokenStream, prefix: &str, suffix: &str) -> TokenStream {
    use proc_macro::{Delimiter, Group, TokenTree};

    // Every fragment parsed here is a crate constant or literal, so failure is an internal bug.
    let parse = |code: &str| -> TokenStream {
        code.parse()
            .expect("sequential-test generated code must be valid Rust tokens")
    };

    // The body is the trailing brace group. Everything before it is passed through as-is,
    // which supports extra attributes, visibility and return types that `take(3)` could not.
    let mut tokens: Vec<TokenTree> = item.into_iter().collect();
    let body = match tokens.pop() {
        Some(TokenTree::Group(group)) if group.delimiter() == Delimiter::Brace => group,
        _ => {
            return parse(
                "compile_error!(\"`#[sequential]` and `#[parallel]` may only annotate test functions\");",
            );
        }
    };
    let signature: TokenStream = tokens.into_iter().collect();

    // Build the new body token-by-token so the user's body keeps its original spans.
    let mut wrapped = parse(prefix);
    wrapped.extend(parse("let __test_body = ||"));
    wrapped.extend(std::iter::once(TokenTree::Group(body)));
    // Split the `catch_unwind` result into two bindings. `res` keeps only the panic payload,
    // which is what `SUFFIX` inspects. `__value` keeps the test's return value for the tail
    // expression.
    wrapped.extend(parse(
        ";
        let (res, __value) = match std::panic::catch_unwind(__test_body) {
            Ok(value) => (Ok(()), Some(value)),
            Err(err) => (Err(err), None),
        };",
    ));
    wrapped.extend(parse(suffix));
    wrapped.extend(parse(SUFFIX));
    // `SUFFIX` diverges via `resume_unwind` when the body panicked, so `__value` is `Some` here.
    wrapped.extend(parse(
        "match __value {
            Some(value) => value,
            None => unreachable!(),
        }",
    ));

    // Only the first annotated test emits the shared `__TestState`/`__PAIR` definitions.
    let first = SET
        .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
        .is_ok();
    let mut output = if first { parse(DEFINE) } else { TokenStream::new() };
    output.extend(signature);
    output.extend(std::iter::once(TokenTree::Group(Group::new(
        Delimiter::Brace,
        wrapped,
    ))));
    output
}