1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
//! **Do not use this**
#![warn(clippy::pedantic)]
extern crate proc_macro;
use proc_macro::TokenStream;

use std::sync::atomic::{AtomicBool, Ordering};

/// Claimed (set to `true`) by the first attribute expansion that emits the shared
/// `__TestState` enum and `__PAIR` static, so those definitions are generated only once.
///
/// This is plain process-global state: it stays `true` for as long as this proc-macro
/// library remains loaded. NOTE(review): a host that keeps the library loaded across
/// several crates (e.g. an IDE proc-macro server) would emit the definitions only for
/// the first crate it expands — confirm whether that matters for your tooling.
static SET: AtomicBool = AtomicBool::new(false);

/// Source for the state shared by every annotated test. It is emitted once (guarded by
/// `SET`) in front of whichever annotated function is expanded first.
///
/// `__TestState::Sequential` marks that a `#[sequential]` test currently has exclusive
/// access. `__TestState::Parallel(n)` counts the `#[parallel]` tests running right now,
/// so `Parallel(0)` (the initial value) means no annotated test is running. `__PAIR`
/// couples the mutex guarding that state with the condvar used to wake waiting tests.
///
/// NOTE(review): these items land in the module of the first-expanded test, so annotated
/// tests in other modules may be unable to resolve `__PAIR` by name — confirm.
const DEFINE: &str = "
    /// The test state stored in the mutex in the pair used for ordering tests.
    enum __TestState {
        Sequential,
        Parallel(u64)
    }
    /// The mutex and condvar pair used for ordering tests.
    static __PAIR: (std::sync::Mutex<__TestState>, std::sync::Condvar) = (
        std::sync::Mutex::new(__TestState::Parallel(0)), std::sync::Condvar::new()
    );
";
/// Prologue injected before the body of a `#[sequential]` test.
///
/// Waits on `__PAIR`'s condvar until the state is `Parallel(0)` (no annotated test
/// running), then switches it to `Sequential` so every other annotated test waits. The
/// guard returned by `wait_while` is discarded via `let _ =`, so the mutex itself is
/// released immediately; exclusivity is carried by the `Sequential` state, not the lock.
const SEQ_PREFIX: &str = "
    let _ = __PAIR.1.wait_while(__PAIR.0.lock().expect(\"sequential-test error\"), |pending| 
        match pending {
            __TestState::Parallel(0) => {
                *pending = __TestState::Sequential;
                false
            },
            _ => true
        }
    ).expect(\"sequential-test error\");
";
/// Prologue injected before the body of a `#[parallel]` test.
///
/// Waits on `__PAIR`'s condvar while the state is `Sequential`, then increments the
/// `Parallel` counter to register this test as running. The mutex guard returned by
/// `wait_while` is discarded via `let _ =`, so the lock is not held while the body runs.
const PAR_PREFIX: &str = "
    let _ = __PAIR.1.wait_while(__PAIR.0.lock().expect(\"sequential-test error\"), |pending|
        match pending {
            __TestState::Sequential => true,
            __TestState::Parallel(ref mut x) => {
                *x += 1;
                false
            }
        }
    ).expect(\"sequential-test error\");
";
/// Epilogue injected after the (unwind-caught) body of a `#[parallel]` test.
///
/// Decrements the `Parallel` counter incremented in the prologue. The state cannot be
/// `Sequential` here: this test's own increment keeps the counter above zero, and a
/// sequential test only takes over from `Parallel(0)`, hence the `unreachable!`.
const PAR_SUFFIX: &str = "
    match *__PAIR.0.lock().expect(\"sequential-test error\") {
        __TestState::Sequential => unreachable!(\"sequential-test error\"),
        __TestState::Parallel(ref mut x) => {
            *x -= 1;
        }
    }
";
/// Epilogue injected after the (unwind-caught) body of a `#[sequential]` test: releases
/// exclusive access by resetting the shared state to `Parallel(0)`.
const SEQ_SUFFIX: &str = "*__PAIR.0.lock().expect(\"sequential-test error\") = __TestState::Parallel(0);";
/// Epilogue shared by both attributes, emitted last in every generated function body.
///
/// Wakes every test waiting on `__PAIR`. It then either re-raises the panic caught from
/// the test body, or yields the body's value as the function's tail expression. Yielding
/// the value (instead of discarding it) keeps tests declared `-> Result<(), E>`
/// compiling, and lets a returned `Err` fail the test as the test harness expects.
const SUFFIX: &str = "
    __PAIR.1.notify_all();
    match res {
        Ok(value) => value,
        Err(err) => std::panic::resume_unwind(err),
    }
";

/// Marks a test that must run with no other annotated test alongside it.
///
/// Before its body runs, the expanded test waits until no `#[sequential]` or
/// `#[parallel]` test is active, then claims exclusive access. It gives that access
/// back once the body finishes, panicking or not. Any arguments passed to the attribute
/// are ignored.
#[proc_macro_attribute]
pub fn sequential(_attr: TokenStream, item: TokenStream) -> TokenStream {
    let (before, after) = (SEQ_PREFIX, SEQ_SUFFIX);
    inner(item, before, after)
}
/// Marks a test that may overlap with other `#[parallel]` tests.
///
/// The expanded test waits only while a `#[sequential]` test holds exclusive access. It
/// counts itself as running for the duration of its body, and uncounts itself
/// afterwards, panicking or not. Any arguments passed to the attribute are ignored.
#[proc_macro_attribute]
pub fn parallel(_attr: TokenStream, item: TokenStream) -> TokenStream {
    let (before, after) = (PAR_PREFIX, PAR_SUFFIX);
    inner(item, before, after)
}
/// Shared expansion behind `#[sequential]` and `#[parallel]`.
///
/// The annotated item is split into its signature and its body:
/// - the signature is every token before the first brace-delimited group;
/// - the body is that group plus anything after it.
///
/// The function is then re-emitted as:
///
/// ```text
/// <signature> {
///     <prefix>
///     let res = std::panic::catch_unwind(|| <body>);
///     <suffix>
///     <SUFFIX>
/// }
/// ```
///
/// Catching the unwind means `suffix` runs even when the test body panics, so the
/// shared state is restored before the panic is re-raised. The first expansion that
/// emits a function also emits `DEFINE` ahead of it (tracked through `SET`).
///
/// If the item has no brace-delimited body (e.g. the attribute was placed on something
/// other than a function), a `compile_error!` is emitted instead of panicking.
///
/// # Panics
///
/// Panics if the generated source cannot be tokenized. That would indicate a bug in the
/// templates rather than in user code.
fn inner(item: TokenStream, prefix: &str, suffix: &str) -> TokenStream {
    let mut tokens = item.into_iter().peekable();
    // Take tokens up to, but not including, the first `{ ... }` group: the signature.
    let signature = std::iter::from_fn(|| {
        tokens.next_if(|token| {
            !matches!(
                token,
                proc_macro::TokenTree::Group(group)
                    if group.delimiter() == proc_macro::Delimiter::Brace
            )
        })
    })
    .collect::<TokenStream>();
    let block = tokens.collect::<TokenStream>();
    if block.is_empty() {
        // Without a body the template below would produce `catch_unwind(|| )`, which
        // fails to tokenize; report a clear compile error at the call site instead.
        return "compile_error!(\"sequential-test: expected a function with a body\");"
            .parse()
            .expect("sequential-test error");
    }
    let item = format!(
        "
        {signature} {{
            {prefix}
            let res = std::panic::catch_unwind(|| {block} );
            {suffix}
            {SUFFIX}
        }}"
    );
    // Only the first emitting expansion defines the mutex/condvar pair; every later
    // expansion refers to `__PAIR` and `__TestState` by name.
    let first = SET
        .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
        .is_ok();
    let source = if first {
        format!("{DEFINE}\n{item}")
    } else {
        item
    };
    source
        .parse()
        .expect("sequential-test: generated code failed to tokenize")
}