1#![cfg_attr(test, deny(warnings))]
2#![deny(missing_docs)]
3
4extern crate crossbeam;
10extern crate variance;
11
12#[macro_use]
13extern crate scopeguard;
14
15use crossbeam::channel::{Sender, Receiver, unbounded};
16use variance::InvariantLifetime as Id;
17
18use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
19use std::sync::{Arc, Condvar, Mutex};
20use std::{mem, thread};
21
22#[derive(Clone, Default)]
28pub struct Pool {
29 wait: Arc<WaitGroup>,
30 inner: Arc<PoolInner>,
31}
32
33impl Pool {
34 #[inline]
44 pub fn new(size: usize) -> Pool {
45 let pool = Pool::empty();
47
48 for _ in 0..size {
50 pool.expand();
51 }
52
53 pool
54 }
55
56 #[inline]
66 pub fn with_thread_config(size: usize, thread_config: ThreadConfig) -> Pool {
67 let pool = Pool {
69 inner: Arc::new(PoolInner::with_thread_config(thread_config)),
70 ..Pool::default()
71 };
72
73 for _ in 0..size {
75 pool.expand();
76 }
77
78 pool
79 }
80
81 #[inline]
86 pub fn empty() -> Pool {
87 Pool::default()
88 }
89
90 #[inline]
92 pub fn workers(&self) -> usize {
93 self.wait.waiting()
97 }
98
99 #[inline]
105 pub fn spawn<F: FnOnce() + Send + 'static>(&self, job: F) {
106 Scope::forever(self.clone()).execute(job)
108 }
109
110 #[inline]
118 pub fn scoped<'scope, F, R>(&self, scheduler: F) -> R
119 where
120 F: FnOnce(&Scope<'scope>) -> R,
121 {
122 Scope::forever(self.clone()).zoom(scheduler)
124 }
125
126 #[inline]
135 pub fn shutdown(&self) {
136 self.inner.queue.push(PoolMessage::Quit);
138
139 self.wait.join()
141 }
142
143 #[inline]
147 pub fn expand(&self) {
148 let pool = self.clone();
149
150 pool.wait.submit();
152
153 let thread_number = self.inner.thread_counter.fetch_add(1, Ordering::SeqCst);
154
155 let mut builder = thread::Builder::new();
157 if let Some(ref prefix) = self.inner.thread_config.prefix {
158 let name = format!("{}{}", prefix, thread_number);
159 builder = builder.name(name);
160 }
161 if let Some(stack_size) = self.inner.thread_config.stack_size {
162 builder = builder.stack_size(stack_size);
163 }
164
165 builder.spawn(move || pool.run_thread()).unwrap();
167 }
168
169 fn run_thread(self) {
170 let mut thread_sentinel = ThreadSentinel(Some(self.clone()));
172
173 loop {
174 match self.inner.queue.pop() {
175 PoolMessage::Quit => {
177 self.inner.queue.push(PoolMessage::Quit);
179
180 thread_sentinel.cancel();
183
184 break;
186 }
187
188 PoolMessage::Task(job, wait) => {
190 let sentinel = Sentinel(self.clone(), Some(wait.clone()));
191 job.run();
192 sentinel.cancel();
193 }
194 }
195 }
196 }
197}
198
199struct BlockingQueue<T> {
200 sender: Sender<T>,
201 receiver: Receiver<T>,
202}
203
204impl<T> BlockingQueue<T> {
205 fn new() -> BlockingQueue<T> {
206 let (tx, rx) = unbounded();
207 BlockingQueue {
208 sender: tx,
209 receiver: rx,
210 }
211 }
212
213 fn pop(&self) -> T {
214 self.receiver.recv().unwrap()
215 }
216
217 fn push(&self, message: T) {
218 self.sender.send(message).unwrap();
219 }
220}
221
222struct PoolInner {
223 queue: BlockingQueue<PoolMessage>,
224 thread_config: ThreadConfig,
225 thread_counter: AtomicUsize,
226}
227
228impl PoolInner {
229 fn with_thread_config(thread_config: ThreadConfig) -> Self {
230 PoolInner {
231 thread_config,
232 ..Self::default()
233 }
234 }
235}
236
237impl Default for PoolInner {
238 fn default() -> Self {
239 PoolInner {
240 queue: BlockingQueue::new(),
241 thread_config: ThreadConfig::default(),
242 thread_counter: AtomicUsize::new(1),
243 }
244 }
245}
246
/// Settings applied to each worker thread a pool spawns.
///
/// Built by chaining calls, e.g.
/// `ThreadConfig::new().prefix("pool-").stack_size(1 << 20)`.
#[derive(Clone, Debug, Default)]
pub struct ThreadConfig {
    // Thread-name prefix; the worker's sequence number is appended to it.
    prefix: Option<String>,
    // Stack size in bytes for new workers; `None` keeps the std default.
    stack_size: Option<usize>,
}

impl ThreadConfig {
    /// Creates a configuration with no name prefix and the default stack size.
    pub fn new() -> ThreadConfig {
        ThreadConfig::default()
    }

    /// Names each worker thread `prefix` followed by its sequence number.
    pub fn prefix<S: Into<String>>(self, prefix: S) -> ThreadConfig {
        ThreadConfig {
            prefix: Some(prefix.into()),
            ..self
        }
    }

    /// Sets the stack size, in bytes, of each worker thread.
    pub fn stack_size(self, stack_size: usize) -> ThreadConfig {
        ThreadConfig {
            stack_size: Some(stack_size),
            ..self
        }
    }
}
283
284pub struct Scope<'scope> {
304 pool: Pool,
305 wait: Arc<WaitGroup>,
306 _scope: Id<'scope>,
307}
308
309impl<'scope> Scope<'scope> {
310 #[inline]
312 pub fn forever(pool: Pool) -> Scope<'static> {
313 Scope {
314 pool,
315 wait: Arc::new(WaitGroup::new()),
316 _scope: Id::default(),
317 }
318 }
319
320 pub fn execute<F>(&self, job: F)
324 where
325 F: FnOnce() + Send + 'scope,
326 {
327 self.wait.submit();
329
330 let task = unsafe {
331 mem::transmute::<Box<dyn Task + Send + 'scope>, Box<dyn Task + Send + 'static>>(
334 Box::new(job),
335 )
336 };
337
338 self.pool
340 .inner
341 .queue
342 .push(PoolMessage::Task(task, self.wait.clone()));
343 }
344
345 pub fn recurse<F>(&self, job: F)
350 where
351 F: FnOnce(&Self) + Send + 'scope,
352 {
353 let this = unsafe { self.clone() };
355
356 self.execute(move || job(&this));
357 }
358
359 pub fn zoom<'smaller, F, R>(&self, scheduler: F) -> R
363 where
364 F: FnOnce(&Scope<'smaller>) -> R,
365 'scope: 'smaller,
366 {
367 let scope = unsafe { self.refine() };
368
369 defer!(scope.join());
371
372 scheduler(&scope)
374 }
375
376 #[inline]
382 pub fn join(&self) {
383 self.wait.join()
384 }
385
386 #[inline]
387 unsafe fn clone(&self) -> Self {
388 Scope {
389 pool: self.pool.clone(),
390 wait: self.wait.clone(),
391 _scope: Id::default(),
392 }
393 }
394
395 #[inline]
397 unsafe fn refine<'other>(&self) -> Scope<'other>
398 where
399 'scope: 'other,
400 {
401 Scope {
402 pool: self.pool.clone(),
403 wait: Arc::new(WaitGroup::new()),
404 _scope: Id::default(),
405 }
406 }
407}
408
409enum PoolMessage {
410 Quit,
411 Task(Box<dyn Task + Send>, Arc<WaitGroup>),
412}
413
/// Counts outstanding units of work and lets threads block until the count
/// reaches zero.
///
/// If any unit finishes through [`WaitGroup::poison`] instead of
/// [`WaitGroup::complete`], the group becomes poisoned and
/// [`WaitGroup::join`] panics once the count reaches zero.
#[derive(Default)]
pub struct WaitGroup {
    // Units submitted but not yet completed or poisoned.
    pending: AtomicUsize,
    // Set by `poison`; never cleared.
    poisoned: AtomicBool,
    // Pairs with `cond`; protects no data of its own.
    lock: Mutex<()>,
    // Notified when `pending` drops to zero.
    cond: Condvar,
}

impl WaitGroup {
    /// Creates a wait group with no pending work.
    #[inline]
    pub fn new() -> Self {
        WaitGroup::default()
    }

    /// Returns the number of units submitted but not yet finished.
    #[inline]
    pub fn waiting(&self) -> usize {
        self.pending.load(Ordering::SeqCst)
    }

    /// Registers one more unit of pending work.
    #[inline]
    pub fn submit(&self) {
        self.pending.fetch_add(1, Ordering::SeqCst);
    }

    /// Marks one unit of work as finished successfully.
    #[inline]
    pub fn complete(&self) {
        self.finish_one();
    }

    /// Marks one unit of work as finished by failure and poisons the group,
    /// so that [`WaitGroup::join`] will panic.
    #[inline]
    pub fn poison(&self) {
        // Set the flag before decrementing so a joiner woken by this
        // decrement is guaranteed to observe it.
        self.poisoned.store(true, Ordering::SeqCst);
        self.finish_one();
    }

    /// Blocks until no work is pending.
    ///
    /// # Panics
    ///
    /// Panics if the group was poisoned, unless the calling thread is
    /// already panicking. `Scope::zoom` runs `join` from a drop guard, and a
    /// second panic during unwinding would abort the whole process instead
    /// of propagating the first one.
    #[inline]
    pub fn join(&self) {
        // The mutex guards no data, so a poisoned lock carries no meaning.
        let mut guard = self.lock.lock().unwrap_or_else(std::sync::PoisonError::into_inner);

        while self.pending.load(Ordering::SeqCst) > 0 {
            guard = self
                .cond
                .wait(guard)
                .unwrap_or_else(std::sync::PoisonError::into_inner);
        }

        // Release the lock before a possible panic so the panic does not
        // poison the mutex for later `complete`/`poison`/`join` calls.
        drop(guard);

        if self.poisoned.load(Ordering::SeqCst) && !thread::panicking() {
            panic!("WaitGroup explicitly poisoned!")
        }
    }

    /// Decrements the pending count and wakes all joiners if it hit zero.
    fn finish_one(&self) {
        let previous = self.pending.fetch_sub(1, Ordering::SeqCst);

        if previous == 1 {
            // Taking the lock stops the notification from slipping in between
            // a joiner's check of `pending` and its call to `wait`.
            let _guard = self.lock.lock().unwrap_or_else(std::sync::PoisonError::into_inner);
            self.cond.notify_all()
        }
    }
}
506
507struct Sentinel(Pool, Option<Arc<WaitGroup>>);
511
512impl Sentinel {
513 fn cancel(mut self) {
514 if let Some(wait) = self.1.take() {
515 wait.complete()
516 }
517 }
518}
519
520impl Drop for Sentinel {
521 fn drop(&mut self) {
522 if let Some(wait) = self.1.take() {
523 wait.poison()
524 }
525 }
526}
527
528struct ThreadSentinel(Option<Pool>);
529
530impl ThreadSentinel {
531 fn cancel(&mut self) {
532 if let Some(pool) = self.0.take() {
533 pool.wait.complete();
534 }
535 }
536}
537
538impl Drop for ThreadSentinel {
539 fn drop(&mut self) {
540 if let Some(pool) = self.0.take() {
541 pool.expand();
546
547 pool.wait.poison();
549 }
550 }
551}
552
/// A job that can be invoked exactly once through a `Box`. This lets
/// `FnOnce` closures be stored and called as trait objects.
trait Task {
    /// Consumes the boxed job and runs it.
    fn run(self: Box<Self>);
}

impl<F> Task for F
where
    F: FnOnce(),
{
    fn run(self: Box<Self>) {
        // Move the closure out of its box; calling it consumes it.
        let job = *self;
        job();
    }
}
562
#[cfg(test)]
mod test {
    use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
    use std::thread::sleep;
    use std::time::Duration;

    use crate::{Pool, Scope, ThreadConfig};

    #[test]
    fn test_simple_use() {
        let pool = Pool::new(4);

        let mut buf = [0, 0, 0, 0];

        pool.scoped(|scope| {
            for i in &mut buf {
                scope.execute(move || *i += 1);
            }
        });

        assert_eq!(&buf, &[1, 1, 1, 1]);
    }

    #[test]
    fn test_zoom() {
        let pool = Pool::new(4);

        let mut outer = 0;

        pool.scoped(|scope| {
            let mut inner = 0;
            scope.zoom(|scope2| scope2.execute(|| inner = 1));
            assert_eq!(inner, 1);

            outer = 1;
        });

        assert_eq!(outer, 1);
    }

    #[test]
    fn test_recurse() {
        let pool = Pool::new(12);

        let mut buf = [0, 0, 0, 0];

        pool.scoped(|next| {
            next.recurse(|next| {
                buf[0] = 1;

                next.execute(|| {
                    buf[1] = 1;
                });
            });
        });

        assert_eq!(&buf, &[1, 1, 0, 0]);
    }

    #[test]
    fn test_spawn_doesnt_hang() {
        let pool = Pool::new(1);
        // The job never finishes; `spawn` must return without waiting for it.
        // Sleep rather than spin so the stuck worker does not burn a CPU core
        // for the rest of the test run.
        pool.spawn(move || loop {
            sleep(Duration::from_secs(1));
        });
    }

    #[test]
    fn test_forever_zoom() {
        let pool = Pool::new(16);
        let forever = Scope::forever(pool.clone());

        let ran = AtomicBool::new(false);

        forever.zoom(|scope| scope.execute(|| ran.store(true, Ordering::SeqCst)));

        assert!(ran.load(Ordering::SeqCst));
    }

    #[test]
    fn test_shutdown() {
        let pool = Pool::new(4);
        pool.shutdown();
    }

    #[test]
    #[should_panic]
    fn test_scheduler_panic() {
        let pool = Pool::new(4);
        pool.scoped(|_| panic!());
    }

    #[test]
    #[should_panic]
    fn test_scoped_execute_panic() {
        let pool = Pool::new(4);
        pool.scoped(|scope| scope.execute(|| panic!()));
    }

    #[test]
    #[should_panic]
    fn test_pool_panic() {
        let _pool = Pool::new(1);
        panic!();
    }

    #[test]
    #[should_panic]
    fn test_zoomed_scoped_execute_panic() {
        let pool = Pool::new(4);
        pool.scoped(|scope| scope.zoom(|scope2| scope2.execute(|| panic!())));
    }

    #[test]
    #[should_panic]
    fn test_recurse_scheduler_panic() {
        let pool = Pool::new(4);
        pool.scoped(|scope| scope.recurse(|_| panic!()));
    }

    #[test]
    #[should_panic]
    fn test_recurse_execute_panic() {
        let pool = Pool::new(4);
        pool.scoped(|scope| scope.recurse(|scope2| scope2.execute(|| panic!())));
    }

    // On drop, asserts that exactly `expected` `DropCounter`s have been dropped.
    struct Canary<'a> {
        drops: DropCounter<'a>,
        expected: usize,
    }

    // Increments the shared counter each time a copy is dropped.
    #[derive(Clone)]
    struct DropCounter<'a>(&'a AtomicUsize);

    impl<'a> Drop for DropCounter<'a> {
        fn drop(&mut self) {
            self.0.fetch_add(1, Ordering::SeqCst);
        }
    }

    impl<'a> Drop for Canary<'a> {
        fn drop(&mut self) {
            let drops = self.drops.0.load(Ordering::SeqCst);
            assert_eq!(drops, self.expected);
        }
    }

    #[test]
    #[should_panic]
    fn test_scoped_panic_waits_for_all_tasks() {
        let tasks = 50;
        let panicking_task_fraction = 10;
        let panicking_tasks = tasks / panicking_task_fraction;
        let expected_drops = tasks + panicking_tasks;

        let counter = Box::new(AtomicUsize::new(0));
        let drops = DropCounter(&*counter);

        // Checked while this frame unwinds after `pool.scoped` panics. By then
        // every task, including the slow ones, must have dropped its counter.
        let _canary = Canary {
            drops: drops.clone(),
            expected: expected_drops,
        };

        let pool = Pool::new(12);

        pool.scoped(|scope| {
            for task in 0..tasks {
                let drop_counter = drops.clone();

                scope.execute(move || {
                    sleep(Duration::from_millis(10));

                    drop::<DropCounter>(drop_counter);
                });

                if task % panicking_task_fraction == 0 {
                    let drop_counter = drops.clone();

                    scope.execute(move || {
                        // Dropped while this task unwinds.
                        let _drops = drop_counter;

                        panic!();
                    });
                }
            }
        });
    }

    #[test]
    #[should_panic]
    fn test_scheduler_panic_waits_for_tasks() {
        let tasks = 50;
        let counter = Box::new(AtomicUsize::new(0));
        let drops = DropCounter(&*counter);

        let _canary = Canary {
            drops: drops.clone(),
            expected: tasks,
        };

        let pool = Pool::new(12);

        pool.scoped(|scope| {
            for _ in 0..tasks {
                let drop_counter = drops.clone();

                scope.execute(move || {
                    sleep(Duration::from_millis(25));
                    drop::<DropCounter>(drop_counter);
                });
            }

            panic!();
        });
    }

    #[test]
    fn test_no_thread_config() {
        let pool = Pool::new(1);

        pool.scoped(|scope| {
            scope.execute(|| {
                assert!(::std::thread::current().name().is_none());
            });
        });
    }

    #[test]
    fn test_with_thread_config() {
        let config = ThreadConfig::new().prefix("pool-");

        let pool = Pool::with_thread_config(1, config);

        pool.scoped(|scope| {
            scope.execute(|| {
                assert_eq!(::std::thread::current().name().unwrap(), "pool-1");
            });
        });
    }
}