threadpool_crossbeam_channel/
lib.rs

// Copyright 2014 The Rust Project Developers. See the COPYRIGHT
// file at the top-level directory of this distribution and at
// http://rust-lang.org/COPYRIGHT.
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.

//! A thread pool used to execute functions in parallel.
//!
//! Spawns a specified number of worker threads and replenishes the pool if any worker threads
//! panic.
//!
//! # Examples
//!
//! ## Synchronized with a channel
//!
//! Every thread sends one message over the channel, and the messages are then collected with `take()`.
//!
//! ```
//! use threadpool::ThreadPool;
//! use std::sync::mpsc::channel;
//!
//! let n_workers = 4;
//! let n_jobs = 8;
//! let pool = ThreadPool::new(n_workers);
//!
//! let (tx, rx) = channel();
//! for _ in 0..n_jobs {
//!     let tx = tx.clone();
//!     pool.execute(move|| {
//!         tx.send(1).expect("channel will be there waiting for the pool");
//!     });
//! }
//!
//! assert_eq!(rx.iter().take(n_jobs).fold(0, |a, b| a + b), 8);
//! ```
//!
//! ## Synchronized with a barrier
//!
//! Keep in mind, if a barrier synchronizes more jobs than you have workers in the pool,
//! you will end up with a [deadlock](https://en.wikipedia.org/wiki/Deadlock)
//! at the barrier which is
//! [not considered unsafe](https://doc.rust-lang.org/reference/behavior-not-considered-unsafe.html).
//!
//! ```
//! use threadpool::ThreadPool;
//! use std::sync::{Arc, Barrier};
//! use std::sync::atomic::{AtomicUsize, Ordering};
//!
//! // create at least as many workers as jobs or you will deadlock yourself
//! let n_workers = 42;
//! let n_jobs = 23;
//! let pool = ThreadPool::new(n_workers);
//! let an_atomic = Arc::new(AtomicUsize::new(0));
//!
//! assert!(n_jobs <= n_workers, "too many jobs, will deadlock");
//!
//! // create a barrier that waits for all jobs plus the starter thread
//! let barrier = Arc::new(Barrier::new(n_jobs + 1));
//! for _ in 0..n_jobs {
//!     let barrier = barrier.clone();
//!     let an_atomic = an_atomic.clone();
//!
//!     pool.execute(move|| {
//!         // do the heavy work
//!         an_atomic.fetch_add(1, Ordering::Relaxed);
//!
//!         // then wait for the other threads
//!         barrier.wait();
//!     });
//! }
//!
//! // wait for the threads to finish the work
//! barrier.wait();
//! assert_eq!(an_atomic.load(Ordering::SeqCst), /* n_jobs = */ 23);
//! ```

extern crate num_cpus;
extern crate crossbeam_channel;

use crossbeam_channel::{unbounded, Receiver, Sender};

use std::fmt;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Condvar, Mutex};
use std::thread;

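// A boxed `FnOnce()` closure could not be invoked directly on the older stable
// compilers this crate targets, so jobs are stored behind this helper trait and
// run through `call_box`, which takes `self: Box<Self>` by value.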
trait FnBox {
    fn call_box(self: Box<Self>);
}

impl<F: FnOnce()> FnBox for F {
    fn call_box(self: Box<F>) {
        (*self)()
    }
}

type Thunk<'a> = Box<FnBox + Send + 'a>;

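// Per-worker guard. If a worker thread exits without calling `cancel()`
// (typically because a job panicked), the `Drop` impl below repairs the shared
// counters, records the panic, wakes any joiners, and spawns a replacement
// worker.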
struct Sentinel<'a> {
    shared_data: &'a Arc<ThreadPoolSharedData>,
    active: bool,
}

impl<'a> Sentinel<'a> {
    fn new(shared_data: &'a Arc<ThreadPoolSharedData>) -> Sentinel<'a> {
        Sentinel {
            shared_data: shared_data,
            active: true,
        }
    }

    /// Cancel and destroy this sentinel.
    fn cancel(mut self) {
        self.active = false;
    }
}

impl<'a> Drop for Sentinel<'a> {
    fn drop(&mut self) {
        if self.active {
            self.shared_data.active_count.fetch_sub(1, Ordering::SeqCst);
            if thread::panicking() {
                self.shared_data.panic_count.fetch_add(1, Ordering::SeqCst);
            }
            self.shared_data.no_work_notify_all();
            spawn_in_pool(self.shared_data.clone())
        }
    }
}

/// [`ThreadPool`] factory, which can be used to configure the properties of the
/// [`ThreadPool`].
///
/// The three configuration options available:
///
/// * `num_threads`: maximum number of threads that will be alive at any given moment by the built
///   [`ThreadPool`]
/// * `thread_name`: thread name for each of the threads spawned by the built [`ThreadPool`]
/// * `thread_stack_size`: stack size (in bytes) for each of the threads spawned by the built
///   [`ThreadPool`]
///
/// [`ThreadPool`]: struct.ThreadPool.html
///
/// # Examples
///
/// Build a [`ThreadPool`] that uses a maximum of eight threads simultaneously and each thread has
/// an 8 MB stack size:
///
/// ```
/// let pool = threadpool::Builder::new()
///     .num_threads(8)
///     .thread_stack_size(8_000_000)
///     .build();
/// ```
#[derive(Clone, Default)]
pub struct Builder {
    num_threads: Option<usize>,
    thread_name: Option<String>,
    thread_stack_size: Option<usize>,
}

impl Builder {
    /// Initiate a new [`Builder`].
    ///
    /// [`Builder`]: struct.Builder.html
    ///
    /// # Examples
    ///
    /// ```
    /// let builder = threadpool::Builder::new();
    /// ```
    pub fn new() -> Builder {
        Builder {
            num_threads: None,
            thread_name: None,
            thread_stack_size: None,
        }
    }

    /// Set the maximum number of worker-threads that will be alive at any given moment by the built
    /// [`ThreadPool`]. If not specified, the number of threads defaults to the number of CPUs.
    ///
    /// [`ThreadPool`]: struct.ThreadPool.html
    ///
    /// # Panics
    ///
    /// This method will panic if `num_threads` is 0.
    ///
    /// # Examples
    ///
    /// No more than eight threads will be alive simultaneously for this pool:
    ///
    /// ```
    /// use std::thread;
    ///
    /// let pool = threadpool::Builder::new()
    ///     .num_threads(8)
    ///     .build();
    ///
    /// for _ in 0..100 {
    ///     pool.execute(|| {
    ///         println!("Hello from a worker thread!")
    ///     })
    /// }
    /// ```
    pub fn num_threads(mut self, num_threads: usize) -> Builder {
        assert!(num_threads > 0);
        self.num_threads = Some(num_threads);
        self
    }

    /// Set the thread name for each of the threads spawned by the built [`ThreadPool`]. If not
    /// specified, threads spawned by the thread pool will be unnamed.
    ///
    /// [`ThreadPool`]: struct.ThreadPool.html
    ///
    /// # Examples
    ///
    /// Each thread spawned by this pool will have the name "foo":
    ///
    /// ```
    /// use std::thread;
    ///
    /// let pool = threadpool::Builder::new()
    ///     .thread_name("foo".into())
    ///     .build();
    ///
    /// for _ in 0..100 {
    ///     pool.execute(|| {
    ///         assert_eq!(thread::current().name(), Some("foo"));
    ///     })
    /// }
    /// ```
    pub fn thread_name(mut self, name: String) -> Builder {
        self.thread_name = Some(name);
        self
    }

    /// Set the stack size (in bytes) for each of the threads spawned by the built [`ThreadPool`].
    /// If not specified, threads spawned by the threadpool will have a stack size [as specified in
    /// the `std::thread` documentation][thread].
    ///
    /// [thread]: https://doc.rust-lang.org/nightly/std/thread/index.html#stack-size
    /// [`ThreadPool`]: struct.ThreadPool.html
    ///
    /// # Examples
    ///
    /// Each thread spawned by this pool will have a 4 MB stack:
    ///
    /// ```
    /// let pool = threadpool::Builder::new()
    ///     .thread_stack_size(4_000_000)
    ///     .build();
    ///
    /// for _ in 0..100 {
    ///     pool.execute(|| {
    ///         println!("This thread has a 4 MB stack size!");
    ///     })
    /// }
    /// ```
    pub fn thread_stack_size(mut self, size: usize) -> Builder {
        self.thread_stack_size = Some(size);
        self
    }

    /// Finalize the [`Builder`] and build the [`ThreadPool`].
    ///
    /// [`Builder`]: struct.Builder.html
    /// [`ThreadPool`]: struct.ThreadPool.html
    ///
    /// # Examples
    ///
    /// ```
    /// let pool = threadpool::Builder::new()
    ///     .num_threads(8)
    ///     .thread_stack_size(4_000_000)
    ///     .build();
    /// ```
    pub fn build(self) -> ThreadPool {
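        // Jobs travel over an unbounded crossbeam channel, so `execute` never
        // blocks the caller; any back-pressure has to be handled by the user.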
        let (tx, rx) = unbounded::<Thunk<'static>>();

        let num_threads = self.num_threads.unwrap_or_else(num_cpus::get);

        let shared_data = Arc::new(ThreadPoolSharedData {
            name: self.thread_name,
            job_receiver: rx,
            empty_condvar: Condvar::new(),
            empty_trigger: Mutex::new(()),
            join_generation: AtomicUsize::new(0),
            queued_count: AtomicUsize::new(0),
            active_count: AtomicUsize::new(0),
            max_thread_count: AtomicUsize::new(num_threads),
            panic_count: AtomicUsize::new(0),
            stack_size: self.thread_stack_size,
        });

        // Threadpool threads
        for _ in 0..num_threads {
            spawn_in_pool(shared_data.clone());
        }

        ThreadPool {
            jobs: tx,
            shared_data: shared_data,
        }
    }
}

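// State shared by every pool handle and worker thread. `queued_count` is
// incremented by `execute` and decremented when a worker picks a job up,
// `active_count` tracks jobs currently running, and `join` sleeps on
// `empty_condvar` (guarded by `empty_trigger`) until both drop to zero or the
// join generation advances.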
struct ThreadPoolSharedData {
    name: Option<String>,
    job_receiver: Receiver<Thunk<'static>>,
    empty_trigger: Mutex<()>,
    empty_condvar: Condvar,
    join_generation: AtomicUsize,
    queued_count: AtomicUsize,
    active_count: AtomicUsize,
    max_thread_count: AtomicUsize,
    panic_count: AtomicUsize,
    stack_size: Option<usize>,
}

impl ThreadPoolSharedData {
    fn has_work(&self) -> bool {
        self.queued_count.load(Ordering::SeqCst) > 0 || self.active_count.load(Ordering::SeqCst) > 0
    }

    /// Notify all observers joining this pool if there is no more work to do.
    fn no_work_notify_all(&self) {
        if !self.has_work() {
            // Take the mutex guarding the condvar so a joining thread cannot
            // observe `has_work() == true` and then start waiting after this
            // notification has already fired (a missed wakeup).
            *self.empty_trigger
                .lock()
                .expect("Unable to notify all joining threads");
            self.empty_condvar.notify_all();
        }
    }
}

/// Abstraction of a thread pool for basic parallelism.
pub struct ThreadPool {
    // How the threadpool communicates with subthreads.
    //
    // This is the only such Sender, so when it is dropped all subthreads will
    // quit.
    jobs: Sender<Thunk<'static>>,
    shared_data: Arc<ThreadPoolSharedData>,
}

impl ThreadPool {
    /// Creates a new thread pool capable of executing `num_threads` jobs concurrently.
    ///
    /// # Panics
    ///
    /// This function will panic if `num_threads` is 0.
    ///
    /// # Examples
    ///
    /// Create a new thread pool capable of executing four jobs concurrently:
    ///
    /// ```
    /// use threadpool::ThreadPool;
    ///
    /// let pool = ThreadPool::new(4);
    /// ```
    pub fn new(num_threads: usize) -> ThreadPool {
        Builder::new().num_threads(num_threads).build()
    }

    /// Creates a new thread pool capable of executing `num_threads` jobs concurrently.
    /// Each thread will have the [name][thread name] `name`.
    ///
    /// # Panics
    ///
    /// This function will panic if `num_threads` is 0.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use std::thread;
    /// use threadpool::ThreadPool;
    ///
    /// let pool = ThreadPool::with_name("worker".into(), 2);
    /// for _ in 0..2 {
    ///     pool.execute(|| {
    ///         assert_eq!(
    ///             thread::current().name(),
    ///             Some("worker")
    ///         );
    ///     });
    /// }
    /// pool.join();
    /// ```
    ///
    /// [thread name]: https://doc.rust-lang.org/std/thread/struct.Thread.html#method.name
    pub fn with_name(name: String, num_threads: usize) -> ThreadPool {
        Builder::new()
            .num_threads(num_threads)
            .thread_name(name)
            .build()
    }

    /// **Deprecated: Use [`ThreadPool::with_name`](#method.with_name)**
    #[inline(always)]
    #[deprecated(since = "1.4.0", note = "use ThreadPool::with_name")]
    pub fn new_with_name(name: String, num_threads: usize) -> ThreadPool {
        Self::with_name(name, num_threads)
    }

    /// Executes the function `job` on a thread in the pool.
    ///
    /// # Examples
    ///
    /// Execute four jobs on a thread pool that can run two jobs concurrently:
    ///
    /// ```
    /// use threadpool::ThreadPool;
    ///
    /// let pool = ThreadPool::new(2);
    /// pool.execute(|| println!("hello"));
    /// pool.execute(|| println!("world"));
    /// pool.execute(|| println!("foo"));
    /// pool.execute(|| println!("bar"));
    /// pool.join();
    /// ```
    pub fn execute<F>(&self, job: F)
    where
        F: FnOnce() + Send + 'static,
    {
        self.shared_data.queued_count.fetch_add(1, Ordering::SeqCst);
        self.jobs
            .send(Box::new(job))
            .expect("ThreadPool::execute unable to send job into queue.");
    }

    /// Returns the number of jobs waiting to be executed in the pool.
    ///
    /// # Examples
    ///
    /// ```
    /// use threadpool::ThreadPool;
    /// use std::time::Duration;
    /// use std::thread::sleep;
    ///
    /// let pool = ThreadPool::new(2);
    /// for _ in 0..10 {
    ///     pool.execute(|| {
    ///         sleep(Duration::from_secs(100));
    ///     });
    /// }
    ///
    /// sleep(Duration::from_secs(1)); // wait for threads to start
    /// assert_eq!(8, pool.queued_count());
    /// ```
    pub fn queued_count(&self) -> usize {
        self.shared_data.queued_count.load(Ordering::Relaxed)
    }

    /// Returns the number of currently active threads.
    ///
    /// # Examples
    ///
    /// ```
    /// use threadpool::ThreadPool;
    /// use std::time::Duration;
    /// use std::thread::sleep;
    ///
    /// let pool = ThreadPool::new(4);
    /// for _ in 0..10 {
    ///     pool.execute(move || {
    ///         sleep(Duration::from_secs(100));
    ///     });
    /// }
    ///
    /// sleep(Duration::from_secs(1)); // wait for threads to start
    /// assert_eq!(4, pool.active_count());
    /// ```
    pub fn active_count(&self) -> usize {
        self.shared_data.active_count.load(Ordering::SeqCst)
    }

    /// Returns the maximum number of threads the pool will execute concurrently.
    ///
    /// # Examples
    ///
    /// ```
    /// use threadpool::ThreadPool;
    ///
    /// let mut pool = ThreadPool::new(4);
    /// assert_eq!(4, pool.max_count());
    ///
    /// pool.set_num_threads(8);
    /// assert_eq!(8, pool.max_count());
    /// ```
    pub fn max_count(&self) -> usize {
        self.shared_data.max_thread_count.load(Ordering::Relaxed)
    }

    /// Returns the number of panicked threads over the lifetime of the pool.
    ///
    /// # Examples
    ///
    /// ```
    /// use threadpool::ThreadPool;
    ///
    /// let pool = ThreadPool::new(4);
    /// for n in 0..10 {
    ///     pool.execute(move || {
    ///         // simulate a panic
    ///         if n % 2 == 0 {
    ///             panic!()
    ///         }
    ///     });
    /// }
    /// pool.join();
    ///
    /// assert_eq!(5, pool.panic_count());
    /// ```
    pub fn panic_count(&self) -> usize {
        self.shared_data.panic_count.load(Ordering::Relaxed)
    }

    /// **Deprecated: Use [`ThreadPool::set_num_threads`](#method.set_num_threads)**
    #[deprecated(since = "1.3.0", note = "use ThreadPool::set_num_threads")]
    pub fn set_threads(&mut self, num_threads: usize) {
        self.set_num_threads(num_threads)
    }

    /// Sets the number of worker-threads to use as `num_threads`.
    /// Can be used to change the threadpool size during runtime.
    /// Will not abort already running or waiting threads.
    ///
    /// # Panics
    ///
    /// This function will panic if `num_threads` is 0.
    ///
    /// # Examples
    ///
    /// ```
    /// use threadpool::ThreadPool;
    /// use std::time::Duration;
    /// use std::thread::sleep;
    ///
    /// let mut pool = ThreadPool::new(4);
    /// for _ in 0..10 {
    ///     pool.execute(move || {
    ///         sleep(Duration::from_secs(100));
    ///     });
    /// }
    ///
    /// sleep(Duration::from_secs(1)); // wait for threads to start
    /// assert_eq!(4, pool.active_count());
    /// assert_eq!(6, pool.queued_count());
    ///
    /// // Increase thread capacity of the pool
    /// pool.set_num_threads(8);
    ///
    /// sleep(Duration::from_secs(1)); // wait for new threads to start
    /// assert_eq!(8, pool.active_count());
    /// assert_eq!(2, pool.queued_count());
    ///
    /// // Decrease thread capacity of the pool
    /// // No active threads are killed
    /// pool.set_num_threads(4);
    ///
    /// assert_eq!(8, pool.active_count());
    /// assert_eq!(2, pool.queued_count());
    /// ```
    pub fn set_num_threads(&mut self, num_threads: usize) {
        assert!(num_threads >= 1);
        let prev_num_threads = self.shared_data
            .max_thread_count
            .swap(num_threads, Ordering::Release);
        if let Some(num_spawn) = num_threads.checked_sub(prev_num_threads) {
            // Spawn new threads
            for _ in 0..num_spawn {
                spawn_in_pool(self.shared_data.clone());
            }
        }
    }

    /// Block the current thread until all jobs in the pool have been executed.
    ///
    /// Calling `join` on an empty pool will cause an immediate return.
    /// `join` may be called from multiple threads concurrently.
    /// A `join` is an atomic point in time. All threads joining before the join
    /// event will exit together even if the pool is processing new jobs by the
    /// time they get scheduled.
    ///
    /// Calling `join` from a thread within the pool will cause a deadlock. This
    /// behavior is considered safe.
    ///
    /// # Examples
    ///
    /// ```
    /// use threadpool::ThreadPool;
    /// use std::sync::Arc;
    /// use std::sync::atomic::{AtomicUsize, Ordering};
    ///
    /// let pool = ThreadPool::new(8);
    /// let test_count = Arc::new(AtomicUsize::new(0));
    ///
    /// for _ in 0..42 {
    ///     let test_count = test_count.clone();
    ///     pool.execute(move || {
    ///         test_count.fetch_add(1, Ordering::Relaxed);
    ///     });
    /// }
    ///
    /// pool.join();
    /// assert_eq!(42, test_count.load(Ordering::Relaxed));
    /// ```
    pub fn join(&self) {
        // fast path requires no mutex
        if !self.shared_data.has_work() {
            return;
        }

        let generation = self.shared_data.join_generation.load(Ordering::SeqCst);
        let mut lock = self.shared_data.empty_trigger.lock().unwrap();

        while generation == self.shared_data.join_generation.load(Ordering::Relaxed)
            && self.shared_data.has_work()
        {
            lock = self.shared_data.empty_condvar.wait(lock).unwrap();
        }

        // increase generation if we are the first thread to come out of the loop
        self.shared_data.join_generation.compare_and_swap(
            generation,
            generation.wrapping_add(1),
            Ordering::SeqCst,
        );
    }
}

impl Clone for ThreadPool {
    /// Cloning a pool will create a new handle to the pool.
    /// The behavior is similar to [Arc](https://doc.rust-lang.org/stable/std/sync/struct.Arc.html).
    ///
    /// We could, for example, submit jobs from multiple threads concurrently.
    ///
    /// ```
    /// use threadpool::ThreadPool;
    /// use std::thread;
    /// use std::sync::mpsc::channel;
    ///
    /// let pool = ThreadPool::with_name("clone example".into(), 2);
    ///
    /// let results = (0..2)
    ///     .map(|i| {
    ///         let pool = pool.clone();
    ///         thread::spawn(move || {
    ///             let (tx, rx) = channel();
    ///             for i in 1..12 {
    ///                 let tx = tx.clone();
    ///                 pool.execute(move || {
    ///                     tx.send(i).expect("channel will be waiting");
    ///                 });
    ///             }
    ///             drop(tx);
    ///             if i == 0 {
    ///                 rx.iter().fold(0, |accumulator, element| accumulator + element)
    ///             } else {
    ///                 rx.iter().fold(1, |accumulator, element| accumulator * element)
    ///             }
    ///         })
    ///     })
    ///     .map(|join_handle| join_handle.join().expect("collect results from threads"))
    ///     .collect::<Vec<usize>>();
    ///
    /// assert_eq!(vec![66, 39916800], results);
    /// ```
    fn clone(&self) -> ThreadPool {
        ThreadPool {
            jobs: self.jobs.clone(),
            shared_data: self.shared_data.clone(),
        }
    }
}

/// Create a thread pool with one thread per CPU.
/// On machines with hyperthreading, this will create one thread per hyperthread.
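///
/// # Examples
///
/// A minimal sketch: `ThreadPool::default()` behaves like
/// `ThreadPool::new(num_cpus::get())`.
///
/// ```
/// use threadpool::ThreadPool;
///
/// let pool = ThreadPool::default();
/// assert!(pool.max_count() >= 1);
/// ```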
impl Default for ThreadPool {
    fn default() -> Self {
        ThreadPool::new(num_cpus::get())
    }
}

impl fmt::Debug for ThreadPool {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.debug_struct("ThreadPool")
            .field("name", &self.shared_data.name)
            .field("queued_count", &self.queued_count())
            .field("active_count", &self.active_count())
            .field("max_count", &self.max_count())
            .finish()
    }
}

impl PartialEq for ThreadPool {
    /// Check if you are working with the same pool
    ///
    /// ```
    /// use threadpool::ThreadPool;
    ///
    /// let a = ThreadPool::new(2);
    /// let b = ThreadPool::new(2);
    ///
    /// assert_eq!(a, a);
    /// assert_eq!(b, b);
    ///
    /// # // TODO: change this to assert_ne in the future
    /// assert!(a != b);
    /// assert!(b != a);
    /// ```
    fn eq(&self, other: &ThreadPool) -> bool {
        let a: &ThreadPoolSharedData = &*self.shared_data;
        let b: &ThreadPoolSharedData = &*other.shared_data;
        a as *const ThreadPoolSharedData == b as *const ThreadPoolSharedData
        // with rust 1.17 and later:
        // Arc::ptr_eq(&self.shared_data, &other.shared_data)
    }
}
impl Eq for ThreadPool {}

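// Spawn a single worker thread. The worker loops, pulling jobs off the shared
// channel, until either the pool has been shrunk below the number of busy
// workers or the job `Sender` has been dropped (the `ThreadPool` and all of
// its clones are gone).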
fn spawn_in_pool(shared_data: Arc<ThreadPoolSharedData>) {
    let mut builder = thread::Builder::new();
    if let Some(ref name) = shared_data.name {
        builder = builder.name(name.clone());
    }
    if let Some(ref stack_size) = shared_data.stack_size {
        builder = builder.stack_size(stack_size.to_owned());
    }
    builder
        .spawn(move || {
            // Will spawn a new thread on panic unless it is cancelled.
            let sentinel = Sentinel::new(&shared_data);

            loop {
                // Shut down this thread if the pool has become smaller
                let thread_counter_val = shared_data.active_count.load(Ordering::Acquire);
                let max_thread_count_val = shared_data.max_thread_count.load(Ordering::Relaxed);
                if thread_counter_val >= max_thread_count_val {
                    break;
                }
                let message = shared_data.job_receiver.recv();

                let job = match message {
                    Ok(job) => job,
                    // The ThreadPool was dropped.
                    Err(..) => break,
                };
                // Do not allow instruction reordering (IR) around the job execution
                shared_data.active_count.fetch_add(1, Ordering::SeqCst);
                shared_data.queued_count.fetch_sub(1, Ordering::SeqCst);

                job.call_box();

                shared_data.active_count.fetch_sub(1, Ordering::SeqCst);
                shared_data.no_work_notify_all();
            }

            sentinel.cancel();
        })
        .unwrap();
}

#[cfg(test)]
mod test {
    use super::{Builder, ThreadPool};
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::mpsc::{channel, sync_channel};
    use std::sync::{Arc, Barrier};
    use std::thread::{self, sleep};
    use std::time::Duration;

    const TEST_TASKS: usize = 4;

    #[test]
    fn test_set_num_threads_increasing() {
        let new_thread_amount = TEST_TASKS + 8;
        let mut pool = ThreadPool::new(TEST_TASKS);
        for _ in 0..TEST_TASKS {
            pool.execute(move || sleep(Duration::from_secs(23)));
        }
        sleep(Duration::from_secs(1));
        assert_eq!(pool.active_count(), TEST_TASKS);

        pool.set_num_threads(new_thread_amount);

        for _ in 0..(new_thread_amount - TEST_TASKS) {
            pool.execute(move || sleep(Duration::from_secs(23)));
        }
        sleep(Duration::from_secs(1));
        assert_eq!(pool.active_count(), new_thread_amount);

        pool.join();
    }

    #[test]
    fn test_set_num_threads_decreasing() {
        let new_thread_amount = 2;
        let mut pool = ThreadPool::new(TEST_TASKS);
        for _ in 0..TEST_TASKS {
            pool.execute(move || {
                assert_eq!(1, 1);
            });
        }
        pool.set_num_threads(new_thread_amount);
        for _ in 0..new_thread_amount {
            pool.execute(move || sleep(Duration::from_secs(23)));
        }
        sleep(Duration::from_secs(1));
        assert_eq!(pool.active_count(), new_thread_amount);

        pool.join();
    }

    #[test]
    fn test_active_count() {
        let pool = ThreadPool::new(TEST_TASKS);
        for _ in 0..2 * TEST_TASKS {
            pool.execute(move || loop {
                sleep(Duration::from_secs(10))
            });
        }
        sleep(Duration::from_secs(1));
        let active_count = pool.active_count();
        assert_eq!(active_count, TEST_TASKS);
        let initialized_count = pool.max_count();
        assert_eq!(initialized_count, TEST_TASKS);
    }

    #[test]
    fn test_works() {
        let pool = ThreadPool::new(TEST_TASKS);

        let (tx, rx) = channel();
        for _ in 0..TEST_TASKS {
            let tx = tx.clone();
            pool.execute(move || {
                tx.send(1).unwrap();
            });
        }

        assert_eq!(rx.iter().take(TEST_TASKS).fold(0, |a, b| a + b), TEST_TASKS);
    }

    #[test]
    #[should_panic]
    fn test_zero_tasks_panic() {
        ThreadPool::new(0);
    }

    #[test]
    fn test_recovery_from_subtask_panic() {
        let pool = ThreadPool::new(TEST_TASKS);

        // Panic all the existing threads.
        for _ in 0..TEST_TASKS {
            pool.execute(move || panic!("Ignore this panic, it must!"));
        }
        pool.join();

        assert_eq!(pool.panic_count(), TEST_TASKS);

        // Ensure new threads were spawned to compensate.
        let (tx, rx) = channel();
        for _ in 0..TEST_TASKS {
            let tx = tx.clone();
            pool.execute(move || {
                tx.send(1).unwrap();
            });
        }

        assert_eq!(rx.iter().take(TEST_TASKS).fold(0, |a, b| a + b), TEST_TASKS);
    }

    #[test]
    fn test_should_not_panic_on_drop_if_subtasks_panic_after_drop() {
        let pool = ThreadPool::new(TEST_TASKS);
        let waiter = Arc::new(Barrier::new(TEST_TASKS + 1));

        // Panic all the existing threads in a bit.
        for _ in 0..TEST_TASKS {
            let waiter = waiter.clone();
            pool.execute(move || {
                waiter.wait();
                panic!("Ignore this panic, it should!");
            });
        }

        drop(pool);

        // Kick off the failure.
        waiter.wait();
    }

    #[test]
    fn test_massive_task_creation() {
        let test_tasks = 4_200_000;

        let pool = ThreadPool::new(TEST_TASKS);
        let b0 = Arc::new(Barrier::new(TEST_TASKS + 1));
        let b1 = Arc::new(Barrier::new(TEST_TASKS + 1));

        let (tx, rx) = channel();

        for i in 0..test_tasks {
            let tx = tx.clone();
            let (b0, b1) = (b0.clone(), b1.clone());

            pool.execute(move || {
                // Wait until the pool has been filled once.
                if i < TEST_TASKS {
                    b0.wait();
                    // wait so the pool can be measured
                    b1.wait();
                }

                tx.send(1).unwrap();
            });
        }

        b0.wait();
        assert_eq!(pool.active_count(), TEST_TASKS);
        b1.wait();

        assert_eq!(rx.iter().take(test_tasks).fold(0, |a, b| a + b), test_tasks);
        pool.join();

        let atomic_active_count = pool.active_count();
        assert!(
            atomic_active_count == 0,
            "atomic_active_count: {}",
            atomic_active_count
        );
    }

    #[test]
    fn test_shrink() {
        let test_tasks_begin = TEST_TASKS + 2;

        let mut pool = ThreadPool::new(test_tasks_begin);
        let b0 = Arc::new(Barrier::new(test_tasks_begin + 1));
        let b1 = Arc::new(Barrier::new(test_tasks_begin + 1));

        for _ in 0..test_tasks_begin {
            let (b0, b1) = (b0.clone(), b1.clone());
            pool.execute(move || {
                b0.wait();
                b1.wait();
            });
        }

        let b2 = Arc::new(Barrier::new(TEST_TASKS + 1));
        let b3 = Arc::new(Barrier::new(TEST_TASKS + 1));

        for _ in 0..TEST_TASKS {
            let (b2, b3) = (b2.clone(), b3.clone());
            pool.execute(move || {
                b2.wait();
                b3.wait();
            });
        }

        b0.wait();
        pool.set_num_threads(TEST_TASKS);

        assert_eq!(pool.active_count(), test_tasks_begin);
        b1.wait();

        b2.wait();
        assert_eq!(pool.active_count(), TEST_TASKS);
        b3.wait();
    }

    #[test]
    fn test_name() {
        let name = "test";
        let mut pool = ThreadPool::with_name(name.to_owned(), 2);
        let (tx, rx) = sync_channel(0);

        // initial threads should share the name "test"
        for _ in 0..2 {
            let tx = tx.clone();
            pool.execute(move || {
                let name = thread::current().name().unwrap().to_owned();
                tx.send(name).unwrap();
            });
        }

        // newly spawned thread should share the name "test" too.
        pool.set_num_threads(3);
        let tx_clone = tx.clone();
        pool.execute(move || {
            let name = thread::current().name().unwrap().to_owned();
            tx_clone.send(name).unwrap();
            panic!();
        });

        // recovery thread should share the name "test" too.
        pool.execute(move || {
            let name = thread::current().name().unwrap().to_owned();
            tx.send(name).unwrap();
        });

        for thread_name in rx.iter().take(4) {
            assert_eq!(name, thread_name);
        }
    }

    #[test]
    fn test_debug() {
        let pool = ThreadPool::new(4);
        let debug = format!("{:?}", pool);
        assert_eq!(
            debug,
            "ThreadPool { name: None, queued_count: 0, active_count: 0, max_count: 4 }"
        );

        let pool = ThreadPool::with_name("hello".into(), 4);
        let debug = format!("{:?}", pool);
        assert_eq!(
            debug,
            "ThreadPool { name: Some(\"hello\"), queued_count: 0, active_count: 0, max_count: 4 }"
        );

        let pool = ThreadPool::new(4);
        pool.execute(move || sleep(Duration::from_secs(5)));
        sleep(Duration::from_secs(1));
        let debug = format!("{:?}", pool);
        assert_eq!(
            debug,
            "ThreadPool { name: None, queued_count: 0, active_count: 1, max_count: 4 }"
        );
    }

    #[test]
    fn test_repeated_join() {
        let pool = ThreadPool::with_name("repeated join test".into(), 8);
        let test_count = Arc::new(AtomicUsize::new(0));

        for _ in 0..42 {
            let test_count = test_count.clone();
            pool.execute(move || {
                sleep(Duration::from_secs(2));
                test_count.fetch_add(1, Ordering::Release);
            });
        }

        println!("{:?}", pool);
        pool.join();
        assert_eq!(42, test_count.load(Ordering::Acquire));

        for _ in 0..42 {
            let test_count = test_count.clone();
            pool.execute(move || {
                sleep(Duration::from_secs(2));
                test_count.fetch_add(1, Ordering::Relaxed);
            });
        }
        pool.join();
        assert_eq!(84, test_count.load(Ordering::Relaxed));
    }

    #[test]
    fn test_multi_join() {
        use std::sync::mpsc::TryRecvError::*;

        // Toggle the following lines to debug the deadlock
        fn error(_s: String) {
            //use ::std::io::Write;
            //let stderr = ::std::io::stderr();
            //let mut stderr = stderr.lock();
            //stderr.write(&_s.as_bytes()).is_ok();
        }

        let pool0 = ThreadPool::with_name("multi join pool0".into(), 4);
        let pool1 = ThreadPool::with_name("multi join pool1".into(), 4);
        let (tx, rx) = channel();

        for i in 0..8 {
            let pool1 = pool1.clone();
            let pool0_ = pool0.clone();
            let tx = tx.clone();
            pool0.execute(move || {
                pool1.execute(move || {
                    error(format!("p1: {} -=- {:?}\n", i, pool0_));
                    pool0_.join();
                    error(format!("p1: send({})\n", i));
                    tx.send(i).expect("send i from pool1 -> main");
                });
                error(format!("p0: {}\n", i));
            });
        }
        drop(tx);

        assert_eq!(rx.try_recv(), Err(Empty));
        error(format!("{:?}\n{:?}\n", pool0, pool1));
        pool0.join();
        error(format!("pool0.join() complete =-= {:?}", pool1));
        pool1.join();
        error("pool1.join() complete\n".into());
        assert_eq!(
            rx.iter().fold(0, |acc, i| acc + i),
            0 + 1 + 2 + 3 + 4 + 5 + 6 + 7
        );
    }

    #[test]
    fn test_empty_pool() {
        // Joining an empty pool must return immediately
        let pool = ThreadPool::new(4);

        pool.join();

        assert!(true);
    }

    #[test]
    fn test_no_fun_or_joy() {
        // What happens when you keep adding jobs after a join

        fn sleepy_function() {
            sleep(Duration::from_secs(6));
        }

        let pool = ThreadPool::with_name("no fun or joy".into(), 8);

        pool.execute(sleepy_function);

        let p_t = pool.clone();
        thread::spawn(move || {
            (0..23).map(|_| p_t.execute(sleepy_function)).count();
        });

        pool.join();
    }

    #[test]
    fn test_clone() {
        let pool = ThreadPool::with_name("clone example".into(), 2);

        // This batch of jobs will occupy the pool for some time
        for _ in 0..6 {
            pool.execute(move || {
                sleep(Duration::from_secs(2));
            });
        }

        // The following jobs will be inserted into the pool in a random fashion
        let t0 = {
            let pool = pool.clone();
            thread::spawn(move || {
                // wait for the first batch of tasks to finish
                pool.join();

                let (tx, rx) = channel();
                for i in 0..42 {
                    let tx = tx.clone();
                    pool.execute(move || {
                        tx.send(i).expect("channel will be waiting");
                    });
                }
                drop(tx);
                rx.iter()
                    .fold(0, |accumulator, element| accumulator + element)
            })
        };
        let t1 = {
            let pool = pool.clone();
            thread::spawn(move || {
                // wait for the first batch of tasks to finish
                pool.join();

                let (tx, rx) = channel();
                for i in 1..12 {
                    let tx = tx.clone();
                    pool.execute(move || {
                        tx.send(i).expect("channel will be waiting");
                    });
                }
                drop(tx);
                rx.iter()
                    .fold(1, |accumulator, element| accumulator * element)
            })
        };

        assert_eq!(
            861,
            t0.join()
                .expect("thread 0 will return after calculating additions",)
        );
        assert_eq!(
            39916800,
            t1.join()
                .expect("thread 1 will return after calculating multiplications",)
        );
    }

    #[test]
    fn test_sync_shared_data() {
        fn assert_sync<T: Sync>() {}
        assert_sync::<super::ThreadPoolSharedData>();
    }

    #[test]
    fn test_send_shared_data() {
        fn assert_send<T: Send>() {}
        assert_send::<super::ThreadPoolSharedData>();
    }

    #[test]
    fn test_send() {
        fn assert_send<T: Send>() {}
        assert_send::<ThreadPool>();
    }

    #[test]
    fn test_cloned_eq() {
        let a = ThreadPool::new(2);

        assert_eq!(a, a.clone());
    }

    #[test]
    /// The scenario: joining threads should not stay stuck once their wave
    /// of joins has completed. So once one thread joining on a pool has
    /// succeeded, other threads joining on the same pool must get out even if
    /// the pool is used for other jobs while the first group is finishing
    /// their join.
    ///
    /// In this example this means the waiting threads will exit the join in
    /// groups of four because the waiter pool has four workers.
    fn test_join_wavesurfer() {
        let n_cycles = 4;
        let n_workers = 4;
        let (tx, rx) = channel();
        let builder = Builder::new()
            .num_threads(n_workers)
            .thread_name("join wavesurfer".into());
        let p_waiter = builder.clone().build();
        let p_clock = builder.build();

        let barrier = Arc::new(Barrier::new(3));
        let wave_clock = Arc::new(AtomicUsize::new(0));
        let clock_thread = {
            let barrier = barrier.clone();
            let wave_clock = wave_clock.clone();
            thread::spawn(move || {
                barrier.wait();
                for wave_num in 0..n_cycles {
                    wave_clock.store(wave_num, Ordering::SeqCst);
                    sleep(Duration::from_secs(1));
                }
            })
        };

        {
            let barrier = barrier.clone();
            p_clock.execute(move || {
                barrier.wait();
                // this sleep is for stabilisation on weaker platforms
                sleep(Duration::from_millis(100));
            });
        }

        // prepare three waves of jobs
        for i in 0..3 * n_workers {
            let p_clock = p_clock.clone();
            let tx = tx.clone();
            let wave_clock = wave_clock.clone();
            p_waiter.execute(move || {
                let now = wave_clock.load(Ordering::SeqCst);
                p_clock.join();
                // submit jobs for the second wave
                p_clock.execute(|| sleep(Duration::from_secs(1)));
                let clock = wave_clock.load(Ordering::SeqCst);
                tx.send((now, clock, i)).unwrap();
            });
        }
        println!("all scheduled at {}", wave_clock.load(Ordering::SeqCst));
        barrier.wait();

        p_clock.join();
        //p_waiter.join();

        drop(tx);
        let mut hist = vec![0; n_cycles];
        let mut data = vec![];
        for (now, after, i) in rx.iter() {
            let mut dur = after - now;
            if dur >= n_cycles - 1 {
                dur = n_cycles - 1;
            }
            hist[dur] += 1;

            data.push((now, after, i));
        }
        for (i, n) in hist.iter().enumerate() {
            println!(
                "\t{}: {} {}",
                i,
                n,
                &*(0..*n).fold("".to_owned(), |s, _| s + "*")
            );
        }
        assert!(data.iter().all(|&(cycle, stop, i)| if i < n_workers {
            cycle == stop
        } else {
            cycle < stop
        }));

        clock_thread.join().unwrap();
    }
}