threadgroup/lib.rs
1//! Manages a group of threads that all have the same return type and can be [join()]ed as a unit.
2//!
3//! The implementation uses a [mpsc channel] internally so that children (spawned threads) can notify
4//! the parent (owner of a [`ThreadGroup`]) that they are finished without the parent having to use
5//! a blocking [`std::thread::JoinHandle.join()`] call.
6//!
7//! # Examples
8//! ```rust
9//! use std::thread::sleep;
10//! use std::time::Duration;
11//! use threadgroup::{JoinError, ThreadGroup};
12//!
13//! // Initialize a group of threads returning `u32`.
14//! let mut tg: ThreadGroup<u32> = ThreadGroup::new();
15//!
16//! // Start a bunch of threads that'll return or panic after a while
17//! tg.spawn::<_,u32>(|| {sleep(Duration::new(0,3000000));2});
18//! tg.spawn::<_,u32>(|| {sleep(Duration::new(0,1500000));panic!()});
19//! tg.spawn::<_,u32>(|| {sleep(Duration::new(10,0));3});
20//! tg.spawn::<_,u32>(|| {sleep(Duration::new(0,1000000));1});
21//!
22//! // Join them in the order they finished
23//! assert_eq!(1, tg.join().unwrap());
24//! assert_eq!(JoinError::Panicked, tg.join().unwrap_err());
25//! assert_eq!(2, tg.join().unwrap());
26//! assert_eq!(JoinError::Timeout, tg.join_timeout(Duration::new(0,10000)).unwrap_err());
27//! ```
28//! [join()]: struct.ThreadGroup.html#method.join
29//! [`ThreadGroup`]: struct.ThreadGroup.html
30//! [mpsc channel]: https://doc.rust-lang.org/stable/std/sync/mpsc/index.html
31//! [`std::thread::JoinHandle.join()`]: https://doc.rust-lang.org/stable/std/thread/struct.JoinHandle.html#method.join
32//#![doc(html_playground_url = "https://play.rust-lang.org/")]
33//#![doc(include = "README.md")]
34
35use std::sync::mpsc;
36use std::sync::mpsc::{Receiver, RecvError, RecvTimeoutError, Sender};
37use std::thread;
38use std::thread::{JoinHandle, ThreadId};
39use std::time::Duration;
40
/// Possible error returns from [`join()`] and [`join_timeout()`].
///
/// The type is a plain field-less enum, so it is `Copy` and can be compared with `==`.
/// It also implements [`std::fmt::Display`] and [`std::error::Error`], which lets callers
/// propagate it with `?` into a `Box<dyn Error>`.
///
/// [`join()`]: struct.ThreadGroup.html#method.join
/// [`join_timeout()`]: struct.ThreadGroup.html#method.join_timeout
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum JoinError {
    /// Thread list is empty, nothing to join.
    AllDone,
    /// Joined thread has panicked, no result available.
    Panicked, // FIXME: include panic::PanicInfo
    /// No thread has finished yet.
    Timeout,
    /// Internal channel got disconnected (should not happen, only included for completeness).
    Disconnected,
}

impl std::fmt::Display for JoinError {
    /// Writes a short, lowercase, human-readable description of the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Exhaustive on purpose: adding a variant must force a decision about its message.
        let msg = match *self {
            JoinError::AllDone => "no thread left to join",
            JoinError::Panicked => "joined thread panicked",
            JoinError::Timeout => "timed out waiting for a thread to finish",
            JoinError::Disconnected => "notification channel disconnected",
        };
        f.write_str(msg)
    }
}

impl std::error::Error for JoinError {}
55
/// A collection of spawned threads together with the channel used for notifications.
///
/// Every public function of this file operates on this struct.
pub struct ThreadGroup<T> {
    /// Sending half of the notification channel.
    tx: mpsc::Sender<ThreadId>,
    /// Receiving half of the notification channel; carries thread ids.
    rx: mpsc::Receiver<ThreadId>,
    /// Join handles of the threads held by the group, each yielding a `T`.
    handles: Vec<thread::JoinHandle<T>>,
}
63
/// Guard that sends the id of the current thread on its channel when it goes out of scope.
///
/// The notification is emitted from `drop()`, so it is sent whether the owning scope is left
/// by a normal return or by a panic unwinding through it.
struct SendOnDrop {
    /// Channel on which the id of the dropping thread is reported.
    tx: Sender<ThreadId>,
}
impl Drop for SendOnDrop {
    fn drop(&mut self) {
        // `send()` only fails once the receiving half has been dropped, i.e. when nobody is
        // left to notify. That is not an error for the sender, and panicking here would be
        // harmful: if this guard is dropped while the thread is already unwinding, a second
        // panic aborts the whole process. So the result is deliberately ignored.
        let _ = self.tx.send(thread::current().id());
    }
}
73
74// TODO: Allow passing something during spawn() that'll be returned during join()
75// TODO: check threads.len() on drop()
76// TODO: join_all()
77// TODO: iter() or into_iter()
78impl<T> ThreadGroup<T> {
79 /// Initialize a group of threads returning `T`.
80 /// # Examples
81 /// ```rust
82 /// use threadgroup::ThreadGroup;
83 /// // spawning and joining require the struct to be mutable, and you'll need to provide type hints.
84 /// let mut tg: ThreadGroup<u32> = ThreadGroup::new();
85 /// ```
86 pub fn new() -> ThreadGroup<T> {
87 let (tx, rx): (Sender<ThreadId>, Receiver<ThreadId>) = mpsc::channel();
88 ThreadGroup::<T>{tx: tx, rx: rx, handles: vec![]}
89 }
90
91 /// Spawn a new thread, adding it to the ThreadGroup.
92 ///
93 /// Operates like [`std::thread::spawn()`], but modifies the ThreadGroup instead of returning a [`JoinHandle`].
94 /// # Examples
95 /// ```rust
96 /// use std::time::Duration;
97 /// use std::thread::sleep;
98 /// use threadgroup::{JoinError, ThreadGroup};
99 /// let mut tg: ThreadGroup<u32> = ThreadGroup::new();
100 /// tg.spawn::<_,u32>(|| {sleep(Duration::new(0,1000000));1});
101 /// ```
102 /// [`std::thread::spawn()`]: https://doc.rust-lang.org/stable/std/thread/fn.spawn.html
103 /// [`JoinHandle`]: https://doc.rust-lang.org/stable/std/thread/struct.JoinHandle.html
104 // FIXME: Is there a way to remove the need to specify the type when calling spawn() ?
105 // The ThreadGroup has type T and we know that R is the same type, but the compiler doesn't see that.
106 pub fn spawn<F, R>(&mut self, f: F)
107 where
108 F: FnOnce() -> T,
109 F: Send + 'static,
110 R: Send + 'static,
111 T: Send + 'static,
112 {
113 let thread_tx = self.tx.clone();
114 let jh: JoinHandle<T> = thread::spawn(move || {
115 let _sender = SendOnDrop{tx: thread_tx.clone()};
116 f()
117 });
118 self.handles.push(jh);
119 }
120
121 /// Return count of threads that have been [`spawn()`]ed but not yet [`join()`]ed.
122 /// [`spawn()`]: struct.ThreadGroup.html#method.spawn
123 /// [`join()`]: struct.ThreadGroup.html#method.join
124 pub fn len(&self) -> usize {
125 self.handles.len()
126 }
127
128 /// Check if there is any thread not yet [`join()`]ed.
129 /// [`join()`]: struct.ThreadGroup.html#method.join
130 pub fn is_empty(&self) -> bool {
131 self.handles.is_empty()
132 }
133
134 /// Join one thread of the ThreadGroup.
135 ///
136 /// Operates like [`std::thread::JoinHandle.join()`], but picks the first thread that
137 /// terminates.
138 /// # Examples
139 /// ```rust
140 /// use threadgroup::ThreadGroup;
141 /// let mut tg: ThreadGroup<u32> = ThreadGroup::new();
142 /// while !tg.is_empty() {
143 /// match tg.join() {
144 /// Ok(ret) => println!("Thread returned {}", ret),
145 /// Err(e) => panic!("Oh noes !"),
146 /// }
147 /// }
148 /// ```
149 /// [`std::thread::JoinHandle.join()`]: https://doc.rust-lang.org/stable/std/thread/struct.JoinHandle.html#method.join
150 pub fn join(&mut self) -> Result<T, JoinError> {
151 match self.handles.is_empty() {
152 true => Err(JoinError::AllDone),
153 false => match self.rx.recv() {
154 Ok(id) => self.do_join(id),
155 Err(RecvError{}) => Err(JoinError::Disconnected)
156 }
157 }
158 }
159
160 /// Try to join one thread of the ThreadGroup.
161 ///
162 /// Operates like [`std::thread::JoinHandle.join()`], but picks the first thread that terminates
163 /// and gives up if the timeout is reached.
164 /// # Examples
165 /// ```rust
166 /// use std::time::Duration;
167 /// use threadgroup::{JoinError, ThreadGroup};
168 /// let mut tg: ThreadGroup<u32> = ThreadGroup::new();
169 /// for _ in 0..10 {
170 /// if let Err(JoinError::Timeout) = tg.join_timeout(Duration::new(0,10000)) {
171 /// println!("Still working...");
172 /// }
173 /// }
174 /// ```
175 /// [`std::thread::JoinHandle.join()`]: https://doc.rust-lang.org/stable/std/thread/struct.JoinHandle.html#method.join
176 pub fn join_timeout(&mut self, timeout: Duration) -> Result<T, JoinError> {
177 match self.handles.is_empty() {
178 true => Err(JoinError::AllDone),
179 false => match self.rx.recv_timeout(timeout) {
180 Ok(id) => self.do_join(id),
181 Err(RecvTimeoutError::Timeout) => Err(JoinError::Timeout),
182 Err(RecvTimeoutError::Disconnected) => Err(JoinError::Disconnected)
183 }
184 }
185 }
186
187 /// Find a thread by its id.
188 // TODO: replace with https://doc.rust-lang.org/nightly/std/vec/struct.Vec.html#method.remove_item
189 fn find(&self, id: ThreadId) -> Option<usize> {
190 for (i,jh) in self.handles.iter().enumerate() {
191 if jh.thread().id() == id {
192 return Some(i)
193 }
194 }
195 None
196 }
197
198 /// Actual [`JoinHandle.join()`] once we know that the thread has finished.
199 /// [`JoinHandle.join()`]: https://doc.rust-lang.org/stable/std/thread/struct.JoinHandle.html#method.join
200 fn do_join(&mut self, id: ThreadId) -> Result<T, JoinError> {
201 // We need to separately find and remove the JoinHandle from the vector in order to not upset the borrow checker
202 let i = self.find(id).unwrap();
203 match self.handles.remove(i).join() {
204 Ok(ret) => Ok(ret),
205 Err(_) => Err(JoinError::Panicked),
206 }
207 }
208}
209
#[cfg(test)]
mod tests {
    // `super::` resolves on every edition; the former `use ::{..}` form is rejected from 2018 on.
    use super::{JoinError, ThreadGroup};
    use std::thread::sleep;
    use std::time::Duration;

    /// Sleep for `millis` milliseconds, then hand back `ret`.
    ///
    /// The tests space their threads tens of milliseconds apart so that the expected
    /// completion order does not depend on sub-millisecond scheduling luck.
    fn after(millis: u64, ret: u32) -> u32 {
        sleep(Duration::from_millis(millis));
        ret
    }

    /// An empty group reports itself as empty and has nothing to join.
    #[test]
    fn empty_group() {
        let mut tg: ThreadGroup<u32> = ThreadGroup::new();
        assert!(tg.is_empty());
        assert_eq!(tg.len(), 0);
        assert_eq!(JoinError::AllDone, tg.join().unwrap_err());
    }

    /// Threads are joined in the order they finish, not the order they were spawned.
    #[test]
    fn basic_join() {
        let mut tg: ThreadGroup<u32> = ThreadGroup::new();
        tg.spawn::<_, u32>(|| after(20, 1));
        tg.spawn::<_, u32>(|| after(60, 3));
        tg.spawn::<_, u32>(|| after(40, 2));
        assert_eq!(1, tg.join().unwrap());
        assert_eq!(2, tg.join().unwrap());
        assert_eq!(3, tg.join().unwrap());
        assert_eq!(JoinError::AllDone, tg.join().unwrap_err());
    }

    /// A panicking thread is still reported and joins as `Panicked`.
    #[test]
    fn panic_join() {
        let mut tg: ThreadGroup<u32> = ThreadGroup::new();
        tg.spawn::<_, u32>(|| {
            sleep(Duration::from_millis(60));
            panic!()
        });
        tg.spawn::<_, u32>(|| after(20, 1));
        assert_eq!(1, tg.join().unwrap());
        assert_eq!(JoinError::Panicked, tg.join().unwrap_err());
        assert_eq!(JoinError::AllDone, tg.join().unwrap_err());
    }

    /// `join_timeout()` gives up on a still-running thread and leaves it in the group.
    #[test]
    fn timeout_join() {
        let mut tg: ThreadGroup<u32> = ThreadGroup::new();
        // Never finishes within the lifetime of the test.
        tg.spawn::<_, u32>(|| {
            sleep(Duration::from_secs(1_000_000));
            2
        });
        tg.spawn::<_, u32>(|| after(1, 1));
        // Generous upper bound: returns as soon as the short thread reports in.
        assert_eq!(1, tg.join_timeout(Duration::from_secs(1)).unwrap());
        // Short wait: the remaining thread cannot possibly be done.
        assert_eq!(
            JoinError::Timeout,
            tg.join_timeout(Duration::from_millis(50)).unwrap_err()
        );
        assert!(!tg.is_empty());
        assert_eq!(tg.len(), 1);
    }
}