threadgroup/
lib.rs

1//! Manages a group of threads that all have the same return type and can be [join()]ed as a unit.
2//!
//! The implementation uses an [mpsc channel] internally so that children (spawned threads) can notify
//! the parent (owner of a [`ThreadGroup`]) that they are finished without the parent having to use
//! a blocking [`std::thread::JoinHandle.join()`] call.
6//!
7//! # Examples
8//! ```rust
9//! use std::thread::sleep;
10//! use std::time::Duration;
11//! use threadgroup::{JoinError, ThreadGroup};
12//!
13//! // Initialize a group of threads returning `u32`.
14//! let mut tg: ThreadGroup<u32> = ThreadGroup::new();
15//!
16//! // Start a bunch of threads that'll return or panic after a while
17//! tg.spawn::<_,u32>(|| {sleep(Duration::new(0,3000000));2});
18//! tg.spawn::<_,u32>(|| {sleep(Duration::new(0,1500000));panic!()});
19//! tg.spawn::<_,u32>(|| {sleep(Duration::new(10,0));3});
20//! tg.spawn::<_,u32>(|| {sleep(Duration::new(0,1000000));1});
21//!
22//! // Join them in the order they finished
23//! assert_eq!(1,                   tg.join().unwrap());
24//! assert_eq!(JoinError::Panicked, tg.join().unwrap_err());
25//! assert_eq!(2,                   tg.join().unwrap());
26//! assert_eq!(JoinError::Timeout,  tg.join_timeout(Duration::new(0,10000)).unwrap_err());
27//! ```
28//! [join()]:                           struct.ThreadGroup.html#method.join
29//! [`ThreadGroup`]:                    struct.ThreadGroup.html
30//! [mpsc channel]:                     https://doc.rust-lang.org/stable/std/sync/mpsc/index.html
31//! [`std::thread::JoinHandle.join()`]: https://doc.rust-lang.org/stable/std/thread/struct.JoinHandle.html#method.join
32//#![doc(html_playground_url = "https://play.rust-lang.org/")]
33//#![doc(include = "README.md")]
34
35use std::sync::mpsc;
36use std::sync::mpsc::{Receiver, RecvError, RecvTimeoutError, Sender};
37use std::thread;
38use std::thread::{JoinHandle, ThreadId};
39use std::time::Duration;
40
/// Possible error returns from [`join()`] and [`join_timeout()`].
///
/// [`join()`]:         struct.ThreadGroup.html#method.join
/// [`join_timeout()`]: struct.ThreadGroup.html#method.join_timeout
// The blank `///` line above is required: a Markdown link reference definition
// cannot interrupt a paragraph, so without it the links are not resolved.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum JoinError {
    /// Thread list is empty, nothing to join.
    AllDone,
    /// Joined thread has panicked, no result available.
    Panicked, // FIXME: include panic::PanicInfo
    /// No thread has finished yet.
    Timeout,
    /// Internal channel got disconnected (should not happen, only included for completeness).
    Disconnected,
}

impl std::fmt::Display for JoinError {
    /// Writes a short, human-readable description of the error.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let msg = match *self {
            JoinError::AllDone => "no thread left to join",
            JoinError::Panicked => "joined thread panicked",
            JoinError::Timeout => "timed out waiting for a thread to finish",
            JoinError::Disconnected => "notification channel disconnected",
        };
        f.write_str(msg)
    }
}

// Lets callers propagate `JoinError` with `?` into `Box<dyn Error>` and similar.
impl std::error::Error for JoinError {}
55
/// Holds the collection of threads and the notification channel.
/// All public functions operate on this struct.
#[derive(Debug)]
pub struct ThreadGroup<T> {
    /// Sending half of the notification channel; cloned into every spawned thread.
    tx: Sender<ThreadId>,
    /// Receiving half of the notification channel; yields the id of each thread that finishes.
    rx: Receiver<ThreadId>,
    /// Handles of the threads that have been spawned but not yet joined.
    handles: Vec<JoinHandle<T>>,
}
63
/// Sends the current thread's id on its channel when it goes out of scope.
///
/// Because the notification is sent from `drop()`, it is emitted both when the
/// owning scope returns normally and when it unwinds because of a panic.
struct SendOnDrop {
    tx: Sender<ThreadId>,
}

impl Drop for SendOnDrop {
    fn drop(&mut self) {
        // `send` only fails when the receiving end has been dropped, in which case
        // nobody is waiting for the notification. Ignoring the error is deliberate:
        // panicking here would turn an already-unwinding thread into a process abort.
        let _ = self.tx.send(thread::current().id());
    }
}
73
74// TODO: Allow passing something during spawn() that'll be returned during join()
75// TODO: check threads.len() on drop()
76// TODO: join_all()
77// TODO: iter() or into_iter()
78impl<T> ThreadGroup<T> {
79    /// Initialize a group of threads returning `T`.
80    /// # Examples
81    /// ```rust
82    /// use threadgroup::ThreadGroup;
83    /// // spawning and joining require the struct to be mutable, and you'll need to provide type hints.
84    /// let mut tg: ThreadGroup<u32> = ThreadGroup::new();
85    /// ```
86    pub fn new() -> ThreadGroup<T> {
87        let (tx, rx): (Sender<ThreadId>, Receiver<ThreadId>) = mpsc::channel();
88        ThreadGroup::<T>{tx: tx, rx: rx, handles: vec![]}
89    }
90
91    /// Spawn a new thread, adding it to the ThreadGroup.
92    ///
93    /// Operates like [`std::thread::spawn()`], but modifies the ThreadGroup instead of returning a [`JoinHandle`].
94    /// # Examples
95    /// ```rust
96    /// use std::time::Duration;
97    /// use std::thread::sleep;
98    /// use threadgroup::{JoinError, ThreadGroup};
99    /// let mut tg: ThreadGroup<u32> = ThreadGroup::new();
100    /// tg.spawn::<_,u32>(|| {sleep(Duration::new(0,1000000));1});
101    /// ```
102    /// [`std::thread::spawn()`]: https://doc.rust-lang.org/stable/std/thread/fn.spawn.html
103    /// [`JoinHandle`]:           https://doc.rust-lang.org/stable/std/thread/struct.JoinHandle.html
104    // FIXME: Is there a way to remove the need to specify the type when calling spawn() ?
105    //        The ThreadGroup has type T and we know that R is the same type, but the compiler doesn't see that.
106    pub fn spawn<F, R>(&mut self, f: F)
107        where
108        F: FnOnce() -> T,
109        F: Send + 'static,
110        R: Send + 'static,
111        T: Send + 'static,
112    {
113        let thread_tx = self.tx.clone();
114        let jh: JoinHandle<T> = thread::spawn(move || {
115            let _sender = SendOnDrop{tx: thread_tx.clone()};
116            f()
117        });
118        self.handles.push(jh);
119    }
120
121    /// Return count of threads that have been [`spawn()`]ed but not yet [`join()`]ed.
122    /// [`spawn()`]: struct.ThreadGroup.html#method.spawn
123    /// [`join()`]: struct.ThreadGroup.html#method.join
124    pub fn len(&self) -> usize {
125        self.handles.len()
126    }
127
128    /// Check if there is any thread not yet [`join()`]ed.
129    /// [`join()`]: struct.ThreadGroup.html#method.join
130    pub fn is_empty(&self) -> bool {
131        self.handles.is_empty()
132    }
133
134    /// Join one thread of the ThreadGroup.
135    ///
136    /// Operates like [`std::thread::JoinHandle.join()`], but picks the first thread that
137    /// terminates.
138    /// # Examples
139    /// ```rust
140    /// use threadgroup::ThreadGroup;
141    /// let mut tg: ThreadGroup<u32> = ThreadGroup::new();
142    /// while !tg.is_empty() {
143    ///     match tg.join() {
144    ///         Ok(ret) => println!("Thread returned {}", ret),
145    ///         Err(e) => panic!("Oh noes !"),
146    ///     }
147    /// }
148    /// ```
149    /// [`std::thread::JoinHandle.join()`]: https://doc.rust-lang.org/stable/std/thread/struct.JoinHandle.html#method.join
150    pub fn join(&mut self) -> Result<T, JoinError> {
151        match self.handles.is_empty() {
152            true => Err(JoinError::AllDone),
153            false => match self.rx.recv() {
154                Ok(id) => self.do_join(id),
155                Err(RecvError{}) => Err(JoinError::Disconnected)
156            }
157        }
158    }
159
160    /// Try to join one thread of the ThreadGroup.
161    ///
162    /// Operates like [`std::thread::JoinHandle.join()`], but picks the first thread that terminates
163    /// and gives up if the timeout is reached.
164    /// # Examples
165    /// ```rust
166    /// use std::time::Duration;
167    /// use threadgroup::{JoinError, ThreadGroup};
168    /// let mut tg: ThreadGroup<u32> = ThreadGroup::new();
169    /// for _ in 0..10 {
170    ///    if let Err(JoinError::Timeout) = tg.join_timeout(Duration::new(0,10000)) {
171    ///        println!("Still working...");
172    ///    }
173    /// }
174    /// ```
175    /// [`std::thread::JoinHandle.join()`]: https://doc.rust-lang.org/stable/std/thread/struct.JoinHandle.html#method.join
176    pub fn join_timeout(&mut self, timeout: Duration) -> Result<T, JoinError> {
177        match self.handles.is_empty() {
178            true => Err(JoinError::AllDone),
179            false => match self.rx.recv_timeout(timeout) {
180                Ok(id) => self.do_join(id),
181                Err(RecvTimeoutError::Timeout) => Err(JoinError::Timeout),
182                Err(RecvTimeoutError::Disconnected) => Err(JoinError::Disconnected)
183            }
184        }
185    }
186
187    /// Find a thread by its id.
188    // TODO: replace with https://doc.rust-lang.org/nightly/std/vec/struct.Vec.html#method.remove_item
189    fn find(&self, id: ThreadId) -> Option<usize> {
190        for (i,jh) in self.handles.iter().enumerate() {
191            if jh.thread().id() == id {
192                return Some(i)
193            }
194        }
195        None
196    }
197
198    /// Actual [`JoinHandle.join()`] once we know that the thread has finished.
199    /// [`JoinHandle.join()`]: https://doc.rust-lang.org/stable/std/thread/struct.JoinHandle.html#method.join
200    fn do_join(&mut self, id: ThreadId) -> Result<T, JoinError> {
201        // We need to separately find and remove the JoinHandle from the vector in order to not upset the borrow checker
202        let i = self.find(id).unwrap();
203        match self.handles.remove(i).join() {
204            Ok(ret) => Ok(ret),
205            Err(_) => Err(JoinError::Panicked),
206        }
207    }
208}
209
#[cfg(test)]
mod tests {
    // `super::` resolves on every edition, unlike the 2015-only `use ::{...}` form.
    use super::{JoinError, ThreadGroup};
    use std::thread::sleep;
    use std::time::Duration;

    /// Add a thread to `tg` that sleeps for `millis` milliseconds and then returns `ret`.
    ///
    /// The delays used by the tests are tens of milliseconds apart so that the
    /// completion order stays deterministic even when thread start-up is slow.
    fn spawn_sleeper(tg: &mut ThreadGroup<u32>, millis: u64, ret: u32) {
        tg.spawn::<_, u32>(move || {
            sleep(Duration::from_millis(millis));
            ret
        });
    }

    /// An empty group has no length and refuses to join.
    #[test]
    fn empty_group() {
        let mut tg: ThreadGroup<u32> = ThreadGroup::new();
        assert!(tg.is_empty());
        assert_eq!(tg.len(), 0);
        assert_eq!(JoinError::AllDone, tg.join().unwrap_err());
    }

    /// Threads are joined in the order they finish, not the order they were spawned.
    #[test]
    fn basic_join() {
        let mut tg: ThreadGroup<u32> = ThreadGroup::new();
        spawn_sleeper(&mut tg, 20, 1);
        spawn_sleeper(&mut tg, 100, 3);
        spawn_sleeper(&mut tg, 60, 2);
        assert_eq!(1, tg.join().unwrap());
        assert_eq!(2, tg.join().unwrap());
        assert_eq!(3, tg.join().unwrap());
        assert_eq!(JoinError::AllDone, tg.join().unwrap_err());
    }

    /// A panicking thread is reported as `Panicked` and still removed from the group.
    #[test]
    fn panic_join() {
        let mut tg: ThreadGroup<u32> = ThreadGroup::new();
        tg.spawn::<_, u32>(|| {
            sleep(Duration::from_millis(60));
            panic!()
        });
        spawn_sleeper(&mut tg, 20, 1);
        assert_eq!(1, tg.join().unwrap());
        assert_eq!(JoinError::Panicked, tg.join().unwrap_err());
        assert_eq!(JoinError::AllDone, tg.join().unwrap_err());
    }

    /// `join_timeout()` gives up on a thread that is still running and leaves it in the group.
    #[test]
    fn timeout_join() {
        let mut tg: ThreadGroup<u32> = ThreadGroup::new();
        // Effectively never finishes within the lifetime of the test.
        tg.spawn::<_, u32>(|| {
            sleep(Duration::new(1000000, 0));
            2
        });
        spawn_sleeper(&mut tg, 1, 1);
        let t = Duration::new(1, 0);
        assert_eq!(1, tg.join_timeout(t).unwrap());
        assert_eq!(JoinError::Timeout, tg.join_timeout(t).unwrap_err());
        assert!(!tg.is_empty());
        assert_eq!(tg.len(), 1);
    }
}
254}