shutter_gate/
lib.rs

//! Simple test utility that panics when threads spawned through it stay blocked for too long.
2
3#![deny(missing_docs)]
4
use std::cell::RefCell;
use std::collections::HashSet;
use std::ops::Drop;
use std::sync::mpsc::{self, Receiver, RecvTimeoutError, Sender};
use std::thread::{self, JoinHandle, ThreadId};
use std::time::{Duration, Instant};
11
/// Tracks threads spawned through it and panics if they stay blocked too long.
///
/// Spawn threads with [`Shutter::spawn`], then call [`Shutter::timeout`] to
/// wait for them; it panics if any of them is still running once the given
/// duration has elapsed.
///
/// # Examples
///
/// ```
/// # extern crate shutter_gate;
/// # use shutter_gate::Shutter;
/// use std::thread;
/// use std::time::Duration;
///
/// let shutter = Shutter::new();
///
/// shutter.spawn(|| println!("testing"));
/// shutter.spawn(|| thread::sleep(Duration::from_millis(300)));
///
/// // Panics if either thread is still running after 500ms.
/// shutter.timeout(Duration::from_millis(500));
/// ```
#[derive(Debug)]
pub struct Shutter {
    // Cloned into each spawned thread so it can report its own termination.
    sender: Sender<ThreadId>,
    // Receiving end of those termination reports.
    receiver: Receiver<ThreadId>,
    // Ids of threads spawned by this shutter that have not yet been seen to terminate.
    ids: RefCell<HashSet<ThreadId>>,
}
36
/// Sends the id of the thread that drops it over the wrapped channel.
///
/// Meant to be created at the start of a spawned thread's closure, so the id
/// is sent whether the closure returns normally or unwinds from a panic
/// (assuming the default unwinding panic strategy).
#[derive(Debug)]
struct Guard(Sender<ThreadId>);

impl Drop for Guard {
    fn drop(&mut self) {
        // `send` only fails once the receiving end is gone; with nobody left
        // to notify, the error is intentionally discarded.
        let _ = self.0.send(thread::current().id());
    }
}
46
47impl Shutter {
48    /// Create a shutter.
49    pub fn new() -> Self {
50        let (sender, receiver) = mpsc::channel();
51
52        Shutter {
53            sender,
54            receiver,
55            ids: RefCell::default(),
56        }
57    }
58
59    /// Spawn a thread and track it.
60    pub fn spawn<F, T>(&self, f: F) -> JoinHandle<T> where
61        F: FnOnce() -> T,
62        F: Send + 'static,
63        T: Send + 'static,
64    {
65        let sender = self.sender.clone();
66
67        let handle = thread::spawn(move|| {
68            let _guard = Guard(sender);
69            f()
70        });
71
72        assert!(self.ids.borrow_mut().insert(handle.thread().id()));
73
74        handle
75    }
76
77    /// Ensure every spawned threads are not blocked after given duration.
78    ///
79    /// Returns early if every spawned threads are terminated.
80    ///
81    /// # Panics
82    ///
83    /// It panics after given duration if at least one thread spawned by this shutter is blocked.
84    pub fn timeout(&self, dur: Duration) {
85        let sender = self.sender.clone();
86
87        let timer = thread::spawn(move|| {
88            let _guard = Guard(sender);
89            thread::park_timeout(dur);
90        });
91        let timer = timer.thread();
92
93        for msg in self.receiver.iter() {
94            let mut ids = self.ids.borrow_mut();
95
96            if msg == timer.id() {
97                for msg in self.receiver.try_iter() {
98                    assert!(ids.remove(&msg));
99                }
100
101                if ids.is_empty() {
102                    return
103                } else {
104                    panic!("Timeout")
105                }
106            }
107
108            assert!(ids.remove(&msg));
109
110            if ids.is_empty() {
111                timer.unpark();
112                return
113            }
114        }
115
116        unreachable!()
117    }
118}
119
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for a duration of `num` milliseconds.
    fn ms(num: u64) -> Duration {
        Duration::from_millis(num)
    }

    #[test]
    fn test_success() {
        let shutter = Shutter::new();

        shutter.spawn(|| thread::sleep(ms(200)));
        shutter.spawn(|| thread::sleep(ms(300)));

        shutter.timeout(ms(500));
    }

    #[test]
    #[should_panic(expected = "Timeout")]
    fn test_timeout() {
        let shutter = Shutter::new();

        shutter.spawn(|| { thread::sleep(ms(500)); });
        shutter.timeout(ms(100));
    }

    #[test]
    fn test_child_panic() {
        let shutter = Shutter::new();

        let handle = shutter.spawn(|| { panic!("unbeleavable!"); });
        shutter.timeout(ms(100));

        // `timeout` must treat the unwound thread as finished, and the child's
        // panic must still be visible through its join handle.
        assert!(handle.join().is_err());
    }

    #[test]
    fn test_no_threads() {
        // With nothing spawned there is nothing that can time out.
        Shutter::new().timeout(ms(50));
    }
}