// scoped_thread_pool/
// lib.rs

1#![cfg_attr(test, deny(warnings))]
2#![deny(missing_docs)]
3
4//! # scoped-pool
5//!
6//! A flexible thread pool providing scoped threads.
7//!
8
9extern crate crossbeam;
10extern crate variance;
11
12#[macro_use]
13extern crate scopeguard;
14
15use crossbeam::channel::{Sender, Receiver, unbounded};
16use variance::InvariantLifetime as Id;
17
18use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
19use std::sync::{Arc, Condvar, Mutex};
20use std::{mem, thread};
21
/// A thread-pool providing scoped and unscoped threads.
///
/// The primary ways of interacting with the `Pool` are
/// the `spawn` and `scoped` convenience methods or through
/// the `Scope` type directly.
#[derive(Clone, Default)]
pub struct Pool {
    // Tracks live worker threads: each worker submits itself on start and
    // completes (or poisons) on exit; `shutdown` joins on this.
    wait: Arc<WaitGroup>,
    // State shared by every clone of the pool: the job queue plus thread
    // configuration and the thread-naming counter.
    inner: Arc<PoolInner>,
}
32
impl Pool {
    /// Create a new Pool with `size` threads.
    ///
    /// If `size` is zero, no threads will be spawned. Threads can
    /// be added later via `expand`.
    ///
    /// NOTE: Since Pool can be freely cloned, it does not represent a unique
    /// handle to the thread pool. As a consequence, the thread pool is not
    /// automatically shut down; you must explicitly call `Pool::shutdown` to
    /// shut down the pool.
    #[inline]
    pub fn new(size: usize) -> Pool {
        // Create an empty pool.
        let pool = Pool::empty();

        // Start the requested number of threads.
        for _ in 0..size {
            pool.expand();
        }

        pool
    }

    /// Create a new Pool with `size` threads and given thread config.
    ///
    /// If `size` is zero, no threads will be spawned. Threads can
    /// be added later via `expand`.
    ///
    /// NOTE: Since Pool can be freely cloned, it does not represent a unique
    /// handle to the thread pool. As a consequence, the thread pool is not
    /// automatically shut down; you must explicitly call `Pool::shutdown` to
    /// shut down the pool.
    #[inline]
    pub fn with_thread_config(size: usize, thread_config: ThreadConfig) -> Pool {
        // Create an empty pool with configuration.
        let pool = Pool {
            inner: Arc::new(PoolInner::with_thread_config(thread_config)),
            ..Pool::default()
        };

        // Start the requested number of threads.
        for _ in 0..size {
            pool.expand();
        }

        pool
    }

    /// Create an empty Pool, with no threads.
    ///
    /// Note that no jobs will run until `expand` is called and
    /// worker threads are added.
    #[inline]
    pub fn empty() -> Pool {
        Pool::default()
    }

    /// How many worker threads are currently active.
    #[inline]
    pub fn workers(&self) -> usize {
        // All threads submit themselves when they start and
        // complete when they stop, so the threads we are waiting
        // for are still active.
        self.wait.waiting()
    }

    /// Spawn a `'static'` job to be run on this pool.
    ///
    /// We do not wait on the job to complete.
    ///
    /// Panics in the job will propagate to the calling thread.
    #[inline]
    pub fn spawn<F: FnOnce() + Send + 'static>(&self, job: F) {
        // Run the job on a scope which lasts forever, and won't block.
        Scope::forever(self.clone()).execute(job)
    }

    /// Create a Scope for scheduling a group of jobs in `'scope'`.
    ///
    /// `scoped` will return only when the `scheduler` function and
    /// all jobs queued on the given Scope have been run.
    ///
    /// Panics in any of the jobs or in the scheduler function itself
    /// will propagate to the calling thread.
    #[inline]
    pub fn scoped<'scope, F, R>(&self, scheduler: F) -> R
    where
        F: FnOnce(&Scope<'scope>) -> R,
    {
        // Zoom to the correct scope, then run the scheduler.
        Scope::forever(self.clone()).zoom(scheduler)
    }

    /// Shutdown the Pool.
    ///
    /// WARNING: Extreme care should be taken to not call shutdown concurrently
    /// with any scoped calls, or deadlock can occur.
    ///
    /// All threads will be shut down eventually, but only threads started before the
    /// call to shutdown are guaranteed to be shut down before the call to shutdown
    /// returns.
    #[inline]
    pub fn shutdown(&self) {
        // Start the shutdown process. Each worker that receives Quit
        // re-pushes it (see `run_thread`), so it reaches every worker.
        self.inner.queue.push(PoolMessage::Quit);

        // Wait for it to complete.
        self.wait.join()
    }

    /// Expand the Pool by spawning an additional thread.
    ///
    /// Can accelerate the completion of running jobs.
    #[inline]
    pub fn expand(&self) {
        let pool = self.clone();

        // Submit the new thread to the thread waitgroup.
        pool.wait.submit();

        // Grab a unique number for naming this thread; the counter
        // starts at 1, so the first thread is "<prefix>1".
        let thread_number = self.inner.thread_counter.fetch_add(1, Ordering::SeqCst);

        // Deal with thread configuration.
        let mut builder = thread::Builder::new();
        if let Some(ref prefix) = self.inner.thread_config.prefix {
            let name = format!("{}{}", prefix, thread_number);
            builder = builder.name(name);
        }
        if let Some(stack_size) = self.inner.thread_config.stack_size {
            builder = builder.stack_size(stack_size);
        }

        // Start the actual thread.
        builder.spawn(move || pool.run_thread()).unwrap();
    }

    // Worker-thread main loop: pop messages until Quit is received.
    fn run_thread(self) {
        // Create a sentinel to capture panics on this thread: if a job
        // panic unwinds past this frame, `ThreadSentinel::drop` restarts
        // a replacement worker and poisons the pool's WaitGroup.
        let mut thread_sentinel = ThreadSentinel(Some(self.clone()));

        loop {
            match self.inner.queue.pop() {
                // On Quit, re-propagate and quit.
                PoolMessage::Quit => {
                    // Re-propagate the Quit message to other threads.
                    self.inner.queue.push(PoolMessage::Quit);

                    // Cancel the thread sentinel so we don't panic waiting
                    // shutdown threads, and don't restart the thread.
                    thread_sentinel.cancel();

                    // Terminate the thread.
                    break;
                }

                // On Task, run the task then complete the WaitGroup.
                PoolMessage::Task(job, wait) => {
                    // The sentinel completes the scope's WaitGroup on
                    // success, or poisons it if `job.run()` panics.
                    let sentinel = Sentinel(self.clone(), Some(wait.clone()));
                    job.run();
                    sentinel.cancel();
                }
            }
        }
    }
}
198
199struct BlockingQueue<T> {
200    sender: Sender<T>,
201    receiver: Receiver<T>,
202}
203
204impl<T> BlockingQueue<T> {
205    fn new() -> BlockingQueue<T> {
206        let (tx, rx) = unbounded();
207        BlockingQueue {
208            sender: tx,
209            receiver: rx,
210        }
211    }
212
213    fn pop(&self) -> T {
214        self.receiver.recv().unwrap()
215    }
216
217    fn push(&self, message: T) {
218        self.sender.send(message).unwrap();
219    }
220}
221
222struct PoolInner {
223    queue: BlockingQueue<PoolMessage>,
224    thread_config: ThreadConfig,
225    thread_counter: AtomicUsize,
226}
227
228impl PoolInner {
229    fn with_thread_config(thread_config: ThreadConfig) -> Self {
230        PoolInner {
231            thread_config,
232            ..Self::default()
233        }
234    }
235}
236
237impl Default for PoolInner {
238    fn default() -> Self {
239        PoolInner {
240            queue: BlockingQueue::new(),
241            thread_config: ThreadConfig::default(),
242            thread_counter: AtomicUsize::new(1),
243        }
244    }
245}
246
/// Thread configuration. Provides detailed control over the properties and behavior of new
/// threads.
#[derive(Clone, Debug, Default)]
pub struct ThreadConfig {
    // Optional thread-name prefix; the thread number is appended to it.
    prefix: Option<String>,
    // Optional stack size, in bytes, for each spawned thread.
    stack_size: Option<usize>,
}

impl ThreadConfig {
    /// Generates the base configuration for spawning a thread, from which configuration methods
    /// can be chained.
    pub fn new() -> ThreadConfig {
        // Identical to `Default`; kept as the conventional entry point
        // for the builder chain.
        ThreadConfig::default()
    }

    /// Name prefix of spawned threads. Thread number will be appended to this prefix to form each
    /// thread's unique name. Currently the name is used for identification only in panic
    /// messages.
    pub fn prefix<S: Into<String>>(self, prefix: S) -> ThreadConfig {
        ThreadConfig {
            prefix: Some(prefix.into()),
            ..self
        }
    }

    /// Sets the size of the stack for the new thread.
    pub fn stack_size(self, stack_size: usize) -> ThreadConfig {
        ThreadConfig {
            stack_size: Some(stack_size),
            ..self
        }
    }
}
283
/// An execution scope, represents a set of jobs running on a Pool.
///
/// ## Understanding Scope lifetimes
///
/// Besides `Scope<'static>`, all `Scope` objects are accessed behind a
/// reference of the form `&'scheduler Scope<'scope>`.
///
/// `'scheduler` is the lifetime associated with the *body* of the
/// "scheduler" function (functions passed to `zoom`/`scoped`).
///
/// `'scope` is the lifetime which data captured in `execute` or `recurse`
/// closures must outlive - in other words, `'scope` is the maximum lifetime
/// of all jobs scheduled on a `Scope`.
///
/// Note that since `'scope: 'scheduler` (`'scope` outlives `'scheduler`)
/// `&'scheduler Scope<'scope>` can't be captured in an `execute` closure;
/// this is the reason for the existence of the `recurse` API, which will
/// inject the same scope with a new `'scheduler` lifetime (this time set
/// to the body of the function passed to `recurse`).
pub struct Scope<'scope> {
    // Pool the scope's jobs are executed on.
    pool: Pool,
    // Tracks jobs submitted on this scope; `join` waits on it.
    wait: Arc<WaitGroup>,
    // Zero-sized marker making `'scope` invariant (variance::InvariantLifetime).
    _scope: Id<'scope>,
}
308
impl<'scope> Scope<'scope> {
    /// Create a Scope which lasts forever.
    #[inline]
    pub fn forever(pool: Pool) -> Scope<'static> {
        Scope {
            pool,
            wait: Arc::new(WaitGroup::new()),
            _scope: Id::default(),
        }
    }

    /// Add a job to this scope.
    ///
    /// Subsequent calls to `join` will wait for this job to complete.
    pub fn execute<F>(&self, job: F)
    where
        F: FnOnce() + Send + 'scope,
    {
        // Count the job in the WaitGroup *before* pushing it onto the
        // queue, so a concurrent `join` can never observe a queued but
        // uncounted job.
        self.wait.submit();

        let task = unsafe {
            // SAFETY: erases `'scope` to `'static` so the job can be sent
            // through the (lifetime-free) queue. Sound because the scope
            // is joined before `'scope` ends (see the deferred join in
            // `zoom`), so the job finishes executing before any data it
            // borrows can be dropped.
            mem::transmute::<Box<dyn Task + Send + 'scope>, Box<dyn Task + Send + 'static>>(
                Box::new(job),
            )
        };

        // Submit the task to be executed.
        self.pool
            .inner
            .queue
            .push(PoolMessage::Task(task, self.wait.clone()));
    }

    /// Add a job to this scope which itself will get access to the scope.
    ///
    /// Like with `execute`, subsequent calls to `join` will wait for this
    /// job (and all jobs scheduled on the scope it receives) to complete.
    pub fn recurse<F>(&self, job: F)
    where
        F: FnOnce(&Self) + Send + 'scope,
    {
        // Create another scope with the *same* lifetime; it shares our
        // WaitGroup, so its jobs are joined together with ours.
        let this = unsafe { self.clone() };

        self.execute(move || job(&this));
    }

    /// Create a new subscope, bound to a lifetime smaller than our existing Scope.
    ///
    /// The subscope has a different job set, and is joined before zoom returns.
    pub fn zoom<'smaller, F, R>(&self, scheduler: F) -> R
    where
        F: FnOnce(&Scope<'smaller>) -> R,
        'scope: 'smaller,
    {
        let scope = unsafe { self.refine() };

        // Join the scope either on completion of the scheduler or panic;
        // this deferred join is what upholds the safety contract relied
        // on by `execute`.
        defer!(scope.join());

        // Schedule all tasks then join all tasks
        scheduler(&scope)
    }

    /// Awaits all jobs submitted on this Scope to be completed.
    ///
    /// Only guaranteed to join jobs which were `execute`d logically
    /// prior to `join`. Jobs `execute`d concurrently with `join` may
    /// or may not be completed before `join` returns.
    #[inline]
    pub fn join(&self) {
        self.wait.join()
    }

    // Duplicate this scope handle with the same `'scope` and the *same*
    // WaitGroup. Unsafe because the caller must ensure the copy does not
    // outlive the data `'scope` refers to.
    #[inline]
    unsafe fn clone(&self) -> Self {
        Scope {
            pool: self.pool.clone(),
            wait: self.wait.clone(),
            _scope: Id::default(),
        }
    }

    // Create a new scope with a smaller lifetime on the same pool.
    // The refined scope gets a *fresh* WaitGroup, so its job set is
    // joined independently of the parent's.
    #[inline]
    unsafe fn refine<'other>(&self) -> Scope<'other>
    where
        'scope: 'other,
    {
        Scope {
            pool: self.pool.clone(),
            wait: Arc::new(WaitGroup::new()),
            _scope: Id::default(),
        }
    }
}
408
// Messages consumed by worker threads (see `Pool::run_thread`).
enum PoolMessage {
    // Shut down; re-pushed by each worker so every worker sees it.
    Quit,
    // Run the boxed job, then complete (or, on panic, poison) the WaitGroup.
    Task(Box<dyn Task + Send>, Arc<WaitGroup>),
}
413
/// A synchronization primitive for awaiting a set of actions.
///
/// Adding new jobs is done with `submit`, jobs are completed with `complete`,
/// and any thread may wait for all jobs to be `complete`d with `join`.
pub struct WaitGroup {
    // Count of submitted-but-not-yet-completed actions.
    pending: AtomicUsize,
    // Set by `poison`; makes `join` panic once the count drains to zero.
    poisoned: AtomicBool,
    // Mutex/condvar pair used only to park and wake `join`ers; the
    // counter itself is atomic and is never guarded by this lock.
    lock: Mutex<()>,
    cond: Condvar,
}

impl Default for WaitGroup {
    fn default() -> Self {
        WaitGroup {
            pending: AtomicUsize::new(0),
            poisoned: AtomicBool::new(false),
            lock: Mutex::new(()),
            cond: Condvar::new(),
        }
    }
}

impl WaitGroup {
    /// Create a new empty WaitGroup.
    #[inline]
    pub fn new() -> Self {
        WaitGroup::default()
    }

    /// How many submitted tasks are waiting for completion.
    #[inline]
    pub fn waiting(&self) -> usize {
        self.pending.load(Ordering::SeqCst)
    }

    /// Submit to this WaitGroup, causing `join` to wait
    /// for an additional `complete`.
    #[inline]
    pub fn submit(&self) {
        self.pending.fetch_add(1, Ordering::SeqCst);
    }

    /// Complete a previous `submit`.
    #[inline]
    pub fn complete(&self) {
        self.finish_one();
    }

    /// Poison the WaitGroup so all `join`ing threads panic.
    ///
    /// Note that, like `complete`, this also finishes one previous
    /// `submit`; it is used *in place of* `complete` when an action
    /// fails (e.g. a job panics).
    #[inline]
    pub fn poison(&self) {
        // Set the flag before releasing the final count, so a `join`er
        // woken by the decrement is guaranteed to observe it.
        self.poisoned.store(true, Ordering::SeqCst);

        self.finish_one();
    }

    /// Wait for `submit`s to this WaitGroup to be `complete`d.
    ///
    /// Submits occurring completely before joins will always be waited on.
    ///
    /// Submits occurring concurrently with a `join` may or may not
    /// be waited for.
    ///
    /// Before submitting, `join` will always return immediately.
    ///
    /// # Panics
    ///
    /// Panics if the WaitGroup was `poison`ed.
    #[inline]
    pub fn join(&self) {
        let mut lock = self.lock.lock().unwrap();

        // Re-check the count after every wakeup: condvar waits may wake
        // spuriously, and several joiners can be woken at once.
        while self.pending.load(Ordering::SeqCst) > 0 {
            lock = self.cond.wait(lock).unwrap();
        }

        if self.poisoned.load(Ordering::SeqCst) {
            panic!("WaitGroup explicitly poisoned!")
        }
    }

    // Shared tail of `complete` and `poison`: decrement the pending
    // count and, if this was the last outstanding action, wake all
    // joiners. The lock is acquired before notifying so a `join`er
    // cannot check the count and park *between* our decrement and the
    // notification, which would lose the wakeup.
    fn finish_one(&self) {
        let old = self.pending.fetch_sub(1, Ordering::SeqCst);

        if old == 1 {
            let _lock = self.lock.lock().unwrap();
            self.cond.notify_all()
        }
    }
}
506
// Poisons a job's WaitGroup on drop unless canceled.
//
// Created around each job in `Pool::run_thread`; holds a clone of the
// Pool so the pool's shared state stays alive while the job runs.
//
// Used to ensure panic propagation between jobs and waiting threads.
struct Sentinel(Pool, Option<Arc<WaitGroup>>);

impl Sentinel {
    // Normal-completion path: mark the job complete and disarm `Drop`
    // by emptying the WaitGroup slot.
    fn cancel(mut self) {
        if let Some(wait) = self.1.take() {
            wait.complete()
        }
    }
}

impl Drop for Sentinel {
    // Runs during unwinding when the job panicked before `cancel`:
    // poison the WaitGroup so `join`ers panic too. After `cancel`, the
    // slot is `None` and this is a no-op.
    fn drop(&mut self) {
        if let Some(wait) = self.1.take() {
            wait.poison()
        }
    }
}
527
// Per-worker sentinel: if a worker thread dies without being canceled
// (i.e. a job panic unwinds it), restart a replacement thread and
// poison the pool's thread WaitGroup.
struct ThreadSentinel(Option<Pool>);

impl ThreadSentinel {
    // Disarm the sentinel during orderly shutdown: count this thread as
    // complete and do not restart it.
    fn cancel(&mut self) {
        if let Some(pool) = self.0.take() {
            pool.wait.complete();
        }
    }
}

impl Drop for ThreadSentinel {
    // No-op if `cancel` already took the pool (normal shutdown path).
    fn drop(&mut self) {
        if let Some(pool) = self.0.take() {
            // NOTE: We restart the thread first so we don't accidentally
            // hit zero threads before restarting.

            // Restart the thread.
            pool.expand();

            // Poison the pool.
            pool.wait.poison();
        }
    }
}
552
// A type-erased, boxed job. Implemented for all `FnOnce()` closures so
// they can be stored as `Box<dyn Task + Send>` inside `PoolMessage::Task`.
trait Task {
    // Consume the box and run the job exactly once.
    fn run(self: Box<Self>);
}

impl<F: FnOnce()> Task for F {
    fn run(self: Box<Self>) {
        // Move the closure out of the box and invoke it.
        (*self)()
    }
}
562
#[cfg(test)]
mod test {
    use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
    use std::thread::sleep;
    use std::time::Duration;

    use {crate::Pool, crate::Scope, crate::ThreadConfig};

    #[test]
    fn test_simple_use() {
        let pool = Pool::new(4);

        let mut buf = [0, 0, 0, 0];

        pool.scoped(|scope| {
            for i in &mut buf {
                scope.execute(move || *i += 1);
            }
        });

        assert_eq!(&buf, &[1, 1, 1, 1]);
    }

    #[test]
    fn test_zoom() {
        let pool = Pool::new(4);

        let mut outer = 0;

        pool.scoped(|scope| {
            let mut inner = 0;
            scope.zoom(|scope2| scope2.execute(|| inner = 1));
            assert_eq!(inner, 1);

            outer = 1;
        });

        assert_eq!(outer, 1);
    }

    #[test]
    fn test_recurse() {
        let pool = Pool::new(12);

        let mut buf = [0, 0, 0, 0];

        pool.scoped(|next| {
            next.recurse(|next| {
                buf[0] = 1;

                next.execute(|| {
                    buf[1] = 1;
                });
            });
        });

        assert_eq!(&buf, &[1, 1, 0, 0]);
    }

    // `spawn` must not wait for the job; the infinite loop would hang
    // this test otherwise.
    #[test]
    fn test_spawn_doesnt_hang() {
        let pool = Pool::new(1);
        pool.spawn(move || loop {});
    }

    #[test]
    fn test_forever_zoom() {
        let pool = Pool::new(16);
        let forever = Scope::forever(pool.clone());

        let ran = AtomicBool::new(false);

        forever.zoom(|scope| scope.execute(|| ran.store(true, Ordering::SeqCst)));

        assert!(ran.load(Ordering::SeqCst));
    }

    #[test]
    fn test_shutdown() {
        let pool = Pool::new(4);
        pool.shutdown();
    }

    #[test]
    #[should_panic]
    fn test_scheduler_panic() {
        let pool = Pool::new(4);
        pool.scoped(|_| panic!());
    }

    #[test]
    #[should_panic]
    fn test_scoped_execute_panic() {
        let pool = Pool::new(4);
        pool.scoped(|scope| scope.execute(|| panic!()));
    }

    #[test]
    #[should_panic]
    fn test_pool_panic() {
        let _pool = Pool::new(1);
        panic!();
    }

    #[test]
    #[should_panic]
    fn test_zoomed_scoped_execute_panic() {
        let pool = Pool::new(4);
        pool.scoped(|scope| scope.zoom(|scope2| scope2.execute(|| panic!())));
    }

    #[test]
    #[should_panic]
    fn test_recurse_scheduler_panic() {
        let pool = Pool::new(4);
        pool.scoped(|scope| scope.recurse(|_| panic!()));
    }

    #[test]
    #[should_panic]
    fn test_recurse_execute_panic() {
        let pool = Pool::new(4);
        pool.scoped(|scope| scope.recurse(|scope2| scope2.execute(|| panic!())));
    }

    // Asserts, in its own Drop during unwinding, that exactly `expected`
    // DropCounters were dropped by then.
    struct Canary<'a> {
        drops: DropCounter<'a>,
        expected: usize,
    }

    // Increments the shared counter every time a clone is dropped.
    #[derive(Clone)]
    struct DropCounter<'a>(&'a AtomicUsize);

    impl<'a> Drop for DropCounter<'a> {
        fn drop(&mut self) {
            self.0.fetch_add(1, Ordering::SeqCst);
        }
    }

    impl<'a> Drop for Canary<'a> {
        fn drop(&mut self) {
            let drops = self.drops.0.load(Ordering::SeqCst);
            assert_eq!(drops, self.expected);
        }
    }

    #[test]
    #[should_panic]
    fn test_scoped_panic_waits_for_all_tasks() {
        let tasks = 50;
        let panicking_task_fraction = 10;
        let panicking_tasks = tasks / panicking_task_fraction;
        let expected_drops = tasks + panicking_tasks;

        let counter = Box::new(AtomicUsize::new(0));
        let drops = DropCounter(&*counter);

        // Actual check occurs on drop of this during unwinding.
        let _canary = Canary {
            drops: drops.clone(),
            expected: expected_drops,
        };

        let pool = Pool::new(12);

        pool.scoped(|scope| {
            for task in 0..tasks {
                let drop_counter = drops.clone();

                scope.execute(move || {
                    sleep(Duration::from_millis(10));

                    drop::<DropCounter>(drop_counter);
                });

                if task % panicking_task_fraction == 0 {
                    let drop_counter = drops.clone();

                    scope.execute(move || {
                        // Just make sure we capture it.
                        let _drops = drop_counter;
                        panic!();
                    });
                }
            }
        });
    }

    #[test]
    #[should_panic]
    fn test_scheduler_panic_waits_for_tasks() {
        let tasks = 50;
        let counter = Box::new(AtomicUsize::new(0));
        let drops = DropCounter(&*counter);

        // Actual check occurs on drop of this during unwinding.
        let _canary = Canary {
            drops: drops.clone(),
            expected: tasks,
        };

        let pool = Pool::new(12);

        pool.scoped(|scope| {
            for _ in 0..tasks {
                let drop_counter = drops.clone();

                scope.execute(move || {
                    sleep(Duration::from_millis(25));
                    drop::<DropCounter>(drop_counter);
                });
            }

            panic!();
        });
    }

    #[test]
    fn test_no_thread_config() {
        let pool = Pool::new(1);

        pool.scoped(|scope| {
            scope.execute(|| {
                assert!(::std::thread::current().name().is_none());
            });
        });
    }

    // The thread counter starts at 1, so the single worker is "pool-1".
    #[test]
    fn test_with_thread_config() {
        let config = ThreadConfig::new().prefix("pool-");

        let pool = Pool::with_thread_config(1, config);

        pool.scoped(|scope| {
            scope.execute(|| {
                assert_eq!(::std::thread::current().name().unwrap(), "pool-1");
            });
        });
    }
}