1#![deny(missing_docs)]
4
5use std::thread::{self, ThreadId, JoinHandle};
6use std::sync::mpsc::{self, Sender, Receiver};
7use std::time::Duration;
8use std::cell::RefCell;
9use std::collections::HashSet;
10use std::ops::Drop;
11
/// Spawns threads and waits, with a deadline, for all of them to finish.
///
/// Threads are started with [`Shutter::spawn`]; [`Shutter::timeout`] then
/// blocks until each of them has exited (by returning or by panicking) and
/// panics with `"Timeout"` if the deadline passes first.
///
/// The set of tracked threads lives in a [`RefCell`], so a `Shutter` is not
/// `Sync` and is meant to be driven from a single thread.
#[derive(Debug)]
pub struct Shutter {
    /// Cloned into every spawned thread so it can report its id on exit.
    sender: Sender<ThreadId>,
    /// Receives the ids of threads that have finished.
    receiver: Receiver<ThreadId>,
    /// Ids of spawned threads whose completion has not been observed yet.
    ids: RefCell<HashSet<ThreadId>>,
}
36
/// Sends the id of whichever thread drops it over the wrapped channel.
///
/// It is created as the first local inside each thread body started by
/// `Shutter`. Its `Drop` therefore runs when the body returns, and also when
/// it unwinds from a panic (assuming the default unwinding panic strategy).
#[derive(Debug)]
struct Guard(Sender<ThreadId>);

impl Drop for Guard {
    fn drop(&mut self) {
        // `send` only fails once the receiver is gone, i.e. nobody is
        // listening any more, so the error is deliberately discarded.
        let _ = self.0.send(thread::current().id());
    }
}
46
47impl Shutter {
48 pub fn new() -> Self {
50 let (sender, receiver) = mpsc::channel();
51
52 Shutter {
53 sender,
54 receiver,
55 ids: RefCell::default(),
56 }
57 }
58
59 pub fn spawn<F, T>(&self, f: F) -> JoinHandle<T> where
61 F: FnOnce() -> T,
62 F: Send + 'static,
63 T: Send + 'static,
64 {
65 let sender = self.sender.clone();
66
67 let handle = thread::spawn(move|| {
68 let _guard = Guard(sender);
69 f()
70 });
71
72 assert!(self.ids.borrow_mut().insert(handle.thread().id()));
73
74 handle
75 }
76
77 pub fn timeout(&self, dur: Duration) {
85 let sender = self.sender.clone();
86
87 let timer = thread::spawn(move|| {
88 let _guard = Guard(sender);
89 thread::park_timeout(dur);
90 });
91 let timer = timer.thread();
92
93 for msg in self.receiver.iter() {
94 let mut ids = self.ids.borrow_mut();
95
96 if msg == timer.id() {
97 for msg in self.receiver.try_iter() {
98 assert!(ids.remove(&msg));
99 }
100
101 if ids.is_empty() {
102 return
103 } else {
104 panic!("Timeout")
105 }
106 }
107
108 assert!(ids.remove(&msg));
109
110 if ids.is_empty() {
111 timer.unpark();
112 return
113 }
114 }
115
116 unreachable!()
117 }
118}
119
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for a `Duration` of `count` milliseconds.
    fn millis(count: u64) -> Duration {
        Duration::from_millis(count)
    }

    /// Two children that both finish well inside the deadline.
    #[test]
    fn test_success() {
        let shutter = Shutter::new();

        for &delay in &[200, 300] {
            shutter.spawn(move || thread::sleep(millis(delay)));
        }

        shutter.timeout(millis(500));
    }

    /// A child that outlives the deadline must trigger the timeout panic.
    #[test]
    #[should_panic(expected = "Timeout")]
    fn test_timeout() {
        let shutter = Shutter::new();
        let slow_child = || {
            thread::sleep(millis(500));
        };

        shutter.spawn(slow_child);
        shutter.timeout(millis(100));
    }

    /// A child that panics still counts as finished.
    #[test]
    fn test_child_panic() {
        let shutter = Shutter::new();
        let doomed_child = || {
            panic!("unbeleavable!");
        };

        shutter.spawn(doomed_child);
        shutter.timeout(millis(100));
    }
}