thread_pool/thread_pool.rs

use {Task, TaskBox};
use state::{AtomicState, Lifecycle, CAPACITY};
use two_lock_queue::{self as mpmc, SendError, SendTimeoutError, TrySendError, RecvTimeoutError};
use num_cpus;

use std::{fmt, thread, usize};
use std::sync::{Arc, Mutex, Condvar};
use std::sync::atomic::AtomicUsize;
use std::sync::atomic::Ordering::Relaxed;
use std::time::Duration;

/// Execute tasks on one of possibly several pooled threads.
///
/// For more details, see the [library level documentation](./index.html).
pub struct ThreadPool<T> {
    inner: Arc<Inner<T>>,
}

/// Thread pool configuration.
///
/// Provide detailed control over the properties and behavior of the thread
/// pool.
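///
/// # Examples
///
/// A minimal usage sketch. It assumes this crate is in scope as `thread_pool`
/// and that boxed closures (`Box<TaskBox>`) are used as the task type; adjust
/// the paths to match how the crate is actually imported.
///
/// ```ignore
/// use std::time::Duration;
/// use thread_pool::{Builder, TaskBox};
///
/// // Four core threads that may grow to eight, with surplus threads exiting
/// // after 30 seconds of idleness.
/// let (tx, pool) = Builder::new()
///     .core_pool_size(4)
///     .max_pool_size(8)
///     .keep_alive(Duration::from_secs(30))
///     .name_prefix("my-pool-")
///     .build::<Box<TaskBox>>();
///
/// // Dispatch work through the `Sender` handle.
/// tx.send_fn(|| println!("hello from the pool")).unwrap();
///
/// // Initiate an orderly shutdown and wait for queued work to finish.
/// pool.shutdown();
/// pool.await_termination();
/// ```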
#[derive(Debug)]
pub struct Builder {
    // Thread pool specific configuration values
    thread_pool: Config,

    // Max number of tasks that can be pending in the work queue
    work_queue_capacity: usize,
}

/// Thread pool specific configuration values
struct Config {
    core_pool_size: usize,
    max_pool_size: usize,
    keep_alive: Option<Duration>,
    allow_core_thread_timeout: bool,
    // Used to configure a worker thread
    name_prefix: Option<String>,
    stack_size: Option<usize>,
    after_start: Option<Arc<Fn() + Send + Sync>>,
    before_stop: Option<Arc<Fn() + Send + Sync>>,
}

/// A handle that allows dispatching work to a thread pool.
pub struct Sender<T> {
    tx: mpmc::Sender<T>,
    inner: Arc<Inner<T>>,
}

struct Inner<T> {
    // The main pool control state is an atomic integer packing two conceptual
    // fields
    //   worker_count: indicating the effective number of threads
    //   lifecycle:    indicating whether running, shutting down etc
    //
    // In order to pack them into a single atomic integer, we limit
    // `worker_count` to (2^29)-1 (about 500 million) threads rather than the
    // (2^31)-1 (2 billion) otherwise representable.
    //
    // The `worker_count` is the number of workers that have been permitted to
    // start and not permitted to stop. The value may be transiently different
    // from the actual number of live threads, for example when spawning a
    // thread fails, or when exiting threads are still performing bookkeeping
    // before terminating. The user-visible pool size is reported as the
    // current size of the workers set.
    //
    // The `lifecycle` provides the main lifecycle control, taking on values:
    //
    //   Running:    Accept new tasks and process queued tasks
    //   Shutdown:   Don't accept new tasks, but process queued tasks. This
    //               state is tracked by the work queue
    //   Stop:       Don't accept new tasks, don't process queued tasks, and
    //               interrupt in-progress tasks
    //   Tidying:    All tasks have terminated, worker_count is zero, the thread
    //               transitioning to state Tidying will run the terminated() hook
    //               method
    //   Terminated: terminated() has completed
    //
    // The numerical order among these values matters, to allow ordered
    // comparisons. The lifecycle monotonically increases over time, but need
    // not hit each state. The transitions are:
    //
    //   Running -> Shutdown
    //      On invocation of shutdown(), perhaps implicitly in finalize()
    //
    //   (Running or Shutdown) -> Stop
    //      On invocation of shutdown_now()
    //
    //   Shutdown -> Tidying
    //      When both queue and pool are empty
    //
    //   Stop -> Tidying
    //      When pool is empty
    //
    //   Tidying -> Terminated
    //      When the terminated() hook method has completed
    //
    // Threads waiting in await_termination() will return when the state reaches
    // Terminated.
    //
    // Detecting the transition from Shutdown to Tidying is less
    // straightforward than you'd like because the queue may become empty after
    // being non-empty and vice versa during the Shutdown state, but we can only
    // terminate if, after seeing that it is empty, we see that `worker_count` is
    // 0 (which sometimes entails a recheck -- see below).
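    //
    // For illustration only (the exact bit layout lives in `state.rs`; this is
    // an assumed sketch of the packing described above): the low 29 bits could
    // hold `worker_count` and the bits above it the `lifecycle` discriminant,
    // along the lines of
    //
    //   const CAPACITY: usize = (1 << 29) - 1;  // largest representable worker_count
    //   fn worker_count(state: usize) -> usize { state & CAPACITY }
    //   fn lifecycle_bits(state: usize) -> usize { state & !CAPACITY }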
    state: AtomicState,

    // Used to keep the work channel open even if there are no running threads.
    // This handle is cloned when spawning new workers
    rx: mpmc::Receiver<T>,

    // Acquired when waiting for the pool to shut down
    termination_mutex: Mutex<()>,

    // Signaled when the pool shuts down
    termination_signal: Condvar,

    // Used to name threads
    next_thread_id: AtomicUsize,

    // Configuration
    config: Config,
}

impl<T> Clone for ThreadPool<T> {
    fn clone(&self) -> Self {
        ThreadPool { inner: self.inner.clone() }
    }
}

impl<T> Clone for Sender<T> {
    fn clone(&self) -> Self {
        Sender {
            tx: self.tx.clone(),
            inner: self.inner.clone(),
        }
    }
}

impl fmt::Debug for Config {
    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
        const SOME: &'static &'static str = &"Some(_)";
        const NONE: &'static &'static str = &"None";

        fmt.debug_struct("ThreadPool")
           .field("core_pool_size", &self.core_pool_size)
           .field("max_pool_size", &self.max_pool_size)
           .field("keep_alive", &self.keep_alive)
           .field("allow_core_thread_timeout", &self.allow_core_thread_timeout)
           .field("name_prefix", &self.name_prefix)
           .field("stack_size", &self.stack_size)
           .field("after_start", if self.after_start.is_some() { SOME } else { NONE })
           .field("before_stop", if self.before_stop.is_some() { SOME } else { NONE })
           .finish()
    }
}

impl<T> fmt::Debug for ThreadPool<T> {
    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
        fmt.debug_struct("ThreadPool").finish()
    }
}

impl<T> fmt::Debug for Sender<T> {
    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
        fmt.debug_struct("Sender").finish()
    }
}

/// Tracks state associated with a worker thread
struct Worker<T> {
    // Work queue receive handle
    rx: mpmc::Receiver<T>,
    // Shared thread pool state
    inner: Arc<Inner<T>>,
}

// ===== impl Builder =====

impl Builder {
    /// Returns a builder with default values
    pub fn new() -> Builder {
        let num_cpus = num_cpus::get();

        Builder {
            thread_pool: Config {
                core_pool_size: num_cpus,
                max_pool_size: num_cpus,
                keep_alive: None,
                allow_core_thread_timeout: false,
                name_prefix: None,
                stack_size: None,
                after_start: None,
                before_stop: None,
            },
            work_queue_capacity: 64 * 1_024,
        }
    }

    /// Set the thread pool's core size.
    ///
    /// The number of threads to keep in the pool, even if they are idle.
    pub fn core_pool_size(mut self, val: usize) -> Self {
        self.thread_pool.core_pool_size = val;
        self
    }

    /// Set the thread pool's maximum size
    ///
    /// The maximum number of threads to allow in the pool.
    pub fn max_pool_size(mut self, val: usize) -> Self {
        self.thread_pool.max_pool_size = val;
        self
    }

    /// Set the thread keep alive duration
    ///
    /// When the number of threads is greater than the core pool size, or when
    /// core threads are allowed to time out, this is the maximum time that
    /// idle threads will wait for new tasks before terminating.
    pub fn keep_alive(mut self, val: Duration) -> Self {
        self.thread_pool.keep_alive = Some(val);
        self
    }

    /// Allow core threads to time out
    pub fn allow_core_thread_timeout(mut self) -> Self {
        self.thread_pool.allow_core_thread_timeout = true;
        self
    }

    /// Set the maximum number of tasks that can be pending in the work queue
    pub fn work_queue_capacity(mut self, val: usize) -> Self {
        self.work_queue_capacity = val;
        self
    }

    /// Set the name prefix of threads spawned by the pool
    ///
    /// The thread name prefix is used for generating thread names. For
    /// example, if the prefix is `my-pool-`, then threads in the pool will get
    /// names like `my-pool-1`, `my-pool-2`, etc.
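    ///
    /// A small sketch (assuming the crate is in scope as `thread_pool`):
    ///
    /// ```ignore
    /// let builder = thread_pool::Builder::new().name_prefix("my-pool-");
    /// // Worker threads spawned by this pool will be named
    /// // `my-pool-1`, `my-pool-2`, and so on.
    /// ```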
    pub fn name_prefix<S: Into<String>>(mut self, val: S) -> Self {
        self.thread_pool.name_prefix = Some(val.into());
        self
    }

    /// Set the stack size of threads spawned by the pool
    pub fn stack_size(mut self, val: usize) -> Self {
        self.thread_pool.stack_size = Some(val);
        self
    }

    /// Execute function `f` right after each thread is started but before
    /// running any tasks on it
    ///
    /// This is intended primarily for bookkeeping and monitoring purposes
    pub fn after_start<F>(mut self, f: F) -> Self
        where F: Fn() + Send + Sync + 'static
    {
        self.thread_pool.after_start = Some(Arc::new(f));
        self
    }

    /// Execute function `f` before each worker thread stops
    ///
    /// This is intended primarily for bookkeeping and monitoring purposes
    pub fn before_stop<F>(mut self, f: F) -> Self
        where F: Fn() + Send + Sync + 'static
    {
        self.thread_pool.before_stop = Some(Arc::new(f));
        self
    }

    /// Build and return the configured thread pool
    pub fn build<T: Task>(self) -> (Sender<T>, ThreadPool<T>) {
        assert!(self.thread_pool.core_pool_size >= 1, "at least one thread required");
        assert!(self.thread_pool.core_pool_size <= self.thread_pool.max_pool_size,
                "`core_pool_size` cannot be greater than `max_pool_size`");

        // Create the work queue
        let (tx, rx) = mpmc::channel(self.work_queue_capacity);

        let inner = Arc::new(Inner {
            // Thread pool starts in the running state
            state: AtomicState::new(Lifecycle::Running),
            rx: rx,
            termination_mutex: Mutex::new(()),
            termination_signal: Condvar::new(),
            next_thread_id: AtomicUsize::new(1),
            config: self.thread_pool,
        });

        let sender = Sender {
            tx: tx,
            inner: inner.clone(),
        };

        let pool = ThreadPool {
            inner: inner,
        };

        (sender, pool)
    }
}

impl<T: Task> ThreadPool<T> {
    /// Create a thread pool that reuses a fixed number of threads operating off
    /// a shared unbounded queue.
    ///
    /// At any point, at most `size` threads will be active processing tasks. If
    /// additional tasks are submitted when all threads are active, they will
    /// wait in the queue until a thread is available. If any thread terminates
    /// due to a failure during execution prior to the thread pool shutting
    /// down, a new one will take its place if needed to execute subsequent
    /// tasks. The threads in the pool will exist until the thread pool is
    /// explicitly shut down.
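    ///
    /// # Examples
    ///
    /// A minimal sketch, assuming boxed closures (`Box<TaskBox>`) as the task
    /// type and the crate in scope as `thread_pool`:
    ///
    /// ```ignore
    /// use thread_pool::{ThreadPool, TaskBox};
    ///
    /// // Four worker threads sharing one unbounded queue.
    /// let (tx, pool) = ThreadPool::<Box<TaskBox>>::fixed_size(4);
    ///
    /// for i in 0..16 {
    ///     tx.send_fn(move || println!("task {}", i)).unwrap();
    /// }
    ///
    /// pool.shutdown();
    /// pool.await_termination();
    /// ```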
    pub fn fixed_size(size: usize) -> (Sender<T>, ThreadPool<T>) {
        Builder::new()
            .core_pool_size(size)
            .max_pool_size(size)
            .work_queue_capacity(usize::MAX)
            .build()
    }

    /// Create a thread pool with a single worker thread operating off an
    /// unbounded queue.
    ///
    /// Note, however, that if this single thread terminates due to a failure
    /// during execution prior to the thread pool shutting down, a new one will
    /// take its place if needed to execute subsequent tasks. Tasks are
    /// guaranteed to execute sequentially, and no more than one task will be
    /// active at any given time.
    pub fn single_thread() -> (Sender<T>, ThreadPool<T>) {
        Builder::new()
            .core_pool_size(1)
            .max_pool_size(1)
            .work_queue_capacity(usize::MAX)
            .build()
    }

    /// Start a core thread, causing it to idly wait for work.
    ///
    /// This overrides the default policy of starting core threads only when new
    /// tasks are executed. This function will return `false` if all core
    /// threads have already been started.
    pub fn prestart_core_thread(&self) -> bool {
        let wc = self.inner.state.load().worker_count();

        if wc < self.inner.config.core_pool_size {
            self.inner.add_worker(None, &self.inner).is_ok()
        } else {
            false
        }
    }

    /// Start all core threads, causing them to idly wait for work.
    ///
    /// This overrides the default policy of starting core threads only when new
    /// tasks are executed.
    pub fn prestart_core_threads(&self) {
        while self.prestart_core_thread() {}
    }

    /// Initiate an orderly shutdown.
    ///
    /// Any previously submitted tasks are executed, but no new tasks will be
    /// accepted. Invocation has no additional effect if the thread pool has
    /// already been shut down.
    ///
    /// This function will not wait for previously submitted tasks to complete
    /// execution. Use `await_termination` to do that.
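    ///
    /// # Examples
    ///
    /// A sketch of an orderly shutdown, assuming a pool built with boxed
    /// closures (`Box<TaskBox>`) as the task type:
    ///
    /// ```ignore
    /// use thread_pool::{ThreadPool, TaskBox};
    ///
    /// let (tx, pool) = ThreadPool::<Box<TaskBox>>::single_thread();
    ///
    /// tx.send_fn(|| println!("queued before shutdown")).unwrap();
    ///
    /// // Stop accepting new tasks; work already queued still runs.
    /// pool.shutdown();
    ///
    /// // Block until every previously queued task has completed.
    /// pool.await_termination();
    /// ```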
    pub fn shutdown(&self) {
        self.inner.rx.close();
    }

    /// Shut down the thread pool as fast as possible.
    ///
    /// Worker threads will no longer receive tasks off of the work queue. This
    /// function will drain any remaining tasks before returning.
    ///
    /// There are no guarantees beyond best-effort attempts to actually shut
    /// down in a timely fashion. Threads will finish processing tasks that are
    /// currently running.
    pub fn shutdown_now(&self) {
        self.inner.rx.close();

        // Try transitioning the state to Stop
        if self.inner.state.try_transition_to_stop() {
            // Drain any tasks remaining in the work queue
            loop {
                match self.inner.rx.recv() {
                    Err(_) => return,
                    Ok(_) => {}
                }
            }
        }
    }

    /// Returns `true` if the thread pool is in the process of terminating but
    /// has not yet terminated.
    pub fn is_terminating(&self) -> bool {
        !self.inner.rx.is_open() && !self.is_terminated()
    }

    /// Returns `true` if the thread pool is currently terminated.
    pub fn is_terminated(&self) -> bool {
        self.inner.state.load().is_terminated()
    }

    /// Blocks the current thread until the thread pool has terminated
    pub fn await_termination(&self) {
        let mut lock = self.inner.termination_mutex.lock().unwrap();

        while !self.inner.state.load().is_terminated() {
            lock = self.inner.termination_signal.wait(lock).unwrap();
        }
    }

    /// Returns the current number of running threads
    pub fn size(&self) -> usize {
        self.inner.state.load().worker_count()
    }

    /// Returns the current number of pending tasks
    pub fn queued(&self) -> usize {
        self.inner.rx.len()
    }
}

impl<T: Task> Sender<T> {
    /// Send a task to the thread pool, blocking if necessary
    ///
    /// The function may result in spawning additional threads depending on the
    /// current state and configuration of the thread pool.
    pub fn send(&self, task: T) -> Result<(), SendError<T>> {
        match self.try_send(task) {
            Ok(_) => Ok(()),
            Err(TrySendError::Disconnected(task)) => Err(SendError(task)),
            Err(TrySendError::Full(task)) => {
                // At capacity with all threads spawned, so just block
                self.tx.send(task)
            }
        }
    }

    /// Send a task to the thread pool, blocking if necessary for up to `timeout`
    pub fn send_timeout(&self, task: T, timeout: Duration) -> Result<(), SendTimeoutError<T>> {
        match self.try_send(task) {
            Ok(_) => Ok(()),
            Err(TrySendError::Disconnected(task)) => {
                Err(SendTimeoutError::Disconnected(task))
            }
            Err(TrySendError::Full(task)) => {
                // At capacity with all threads spawned, so just block
                self.tx.send_timeout(task, timeout)
            }
        }
    }

    /// Send a task to the thread pool, returning immediately if at capacity.
    pub fn try_send(&self, task: T) -> Result<(), TrySendError<T>> {
        // First, try to push the task onto the work queue. If the queue is
        // full, attempt to grow the pool instead.
        match self.tx.try_send(task) {
            Ok(_) => {
                // Ensure that all the core threads are running
                let state = self.inner.state.load();

                if state.worker_count() < self.inner.config.core_pool_size {
                    let _ = self.inner.add_worker(None, &self.inner);
                }

                Ok(())
            }
            Err(TrySendError::Disconnected(task)) => {
                return Err(TrySendError::Disconnected(task));
            }
            Err(TrySendError::Full(task)) => {
                // Try to grow the pool size
                match self.inner.add_worker(Some(task), &self.inner) {
                    Ok(_) => return Ok(()),
                    Err(task) => return Err(TrySendError::Full(task.unwrap())),
                }
            }
        }
    }
}

impl Sender<Box<TaskBox>> {
    /// Send a fn to run on the thread pool, blocking if necessary
    ///
    /// The function may result in spawning additional threads depending on the
    /// current state and configuration of the thread pool.
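    ///
    /// # Examples
    ///
    /// A sketch of dispatching a closure, assuming the crate is in scope as
    /// `thread_pool` and the pool was built with `Box<TaskBox>` as its task
    /// type:
    ///
    /// ```ignore
    /// use thread_pool::{ThreadPool, TaskBox};
    ///
    /// let (tx, pool) = ThreadPool::<Box<TaskBox>>::fixed_size(2);
    ///
    /// tx.send_fn(|| {
    ///     // Runs on one of the pool's worker threads.
    ///     println!("hello from a worker");
    /// }).unwrap();
    /// ```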
    pub fn send_fn<F>(&self, task: F) -> Result<(), SendError<Box<TaskBox>>>
        where F: FnOnce() + Send + 'static
    {
        let task: Box<TaskBox> = Box::new(task);
        self.send(task)
    }

    /// Send a fn to run on the thread pool, blocking if necessary for up to
    /// `timeout`
    pub fn send_fn_timeout<F>(&self, task: F, timeout: Duration)
        -> Result<(), SendTimeoutError<Box<TaskBox>>>
        where F: FnOnce() + Send + 'static
    {
        let task: Box<TaskBox> = Box::new(task);
        self.send_timeout(task, timeout)
    }

    /// Send a fn to run on the thread pool, returning immediately if at
    /// capacity.
    pub fn try_send_fn<F>(&self, task: F)
        -> Result<(), TrySendError<Box<TaskBox>>>
        where F: FnOnce() + Send + 'static
    {
        let task: Box<TaskBox> = Box::new(task);
        self.try_send(task)
    }
}

// ===== impl Inner =====

impl<T: Task> Inner<T> {
    fn add_worker(&self, task: Option<T>, arc: &Arc<Inner<T>>)
            -> Result<(), Option<T>> {

        let core = task.is_none();
        let mut state = self.state.load();

        'retry: loop {
            let lifecycle = state.lifecycle();

            if lifecycle >= Lifecycle::Stop {
                // If the lifecycle is at or past `Lifecycle::Stop`, never add
                // a new worker
                return Err(task);
            }

            loop {
                let wc = state.worker_count();

                // The maximum number of workers allowed, depending on whether
                // this is a core or non-core worker
                let target = if core {
                    self.config.core_pool_size
                } else {
                    self.config.max_pool_size
                };

                if wc >= CAPACITY || wc >= target {
                    return Err(task);
                }

                state = match self.state.compare_and_inc_worker_count(state) {
                    Ok(_) => break 'retry,
                    Err(state) => state,
                };

                if state.lifecycle() != lifecycle {
                    continue 'retry;
                }

                // CAS failed due to a worker_count change; retry the inner loop
            }
        }

        // == Spawn the thread ==

        let worker = Worker {
            rx: self.rx.clone(),
            inner: arc.clone(),
        };

        worker.spawn(task);

        Ok(())
    }

    fn finalize_thread_pool(&self) {
        // Transition through Tidying to Terminated
        if self.state.try_transition_to_tidying() {
            self.state.transition_to_terminated();

            // Notify all threads blocked in `await_termination`
            self.termination_signal.notify_all();
        }
    }
}

// ===== impl Worker =====

impl<T: Task> Worker<T> {
    fn spawn(self, initial_task: Option<T>) {
        let mut b = thread::Builder::new();

        {
            let c = &self.inner.config;

            if let Some(stack_size) = c.stack_size {
                b = b.stack_size(stack_size);
            }

            if let Some(ref name_prefix) = c.name_prefix {
                let i = self.inner.next_thread_id.fetch_add(1, Relaxed);
                b = b.name(format!("{}{}", name_prefix, i));
            }
        }

        b.spawn(move || self.run(initial_task)).unwrap();
    }

    fn run(mut self, mut initial_task: Option<T>) {
        use std::panic::{self, AssertUnwindSafe};

        // Run the `after_start` hook
        self.inner.config.after_start.as_ref().map(|f| f());

        while let Some(task) = self.next_task(initial_task.take()) {
            // AssertUnwindSafe is used because `Task` is `Send + 'static`, which
            // is essentially unwind safe
            let _ = panic::catch_unwind(AssertUnwindSafe(move || task.run()));
        }
    }

    // Gets the next task, blocking if necessary. Returns `None` if the worker
    // should shut down.
    fn next_task(&mut self, mut task: Option<T>) -> Option<T> {
        // Did the last `recv_task` call time out?
        let mut timed_out = false;
        let allow_core_thread_timeout = self.inner.config.allow_core_thread_timeout;
        let core_pool_size = self.inner.config.core_pool_size;

        loop {
            // Reload the state on every iteration; it may have changed while
            // this worker was blocked on the queue or lost a CAS race below.
            let state = self.inner.state.load();

            if state.lifecycle() >= Lifecycle::Stop {
                // Run the `before_stop` hook
                self.inner.config.before_stop.as_ref().map(|f| f());

                // No more tasks should be removed from the queue, exit the
                // worker
                self.decrement_worker_count();

                // Nothing else to do
                return None;
            }


            if task.is_some() {
                break;
            }

            let wc = state.worker_count();

            // Determine if there is a timeout for receiving the next task
            let timeout = if wc > core_pool_size || allow_core_thread_timeout {
                self.inner.config.keep_alive
            } else {
                None
            };

            if wc > self.inner.config.max_pool_size || (timeout.is_some() && timed_out) {
                // Only allow the last worker to exit if the work queue is empty
                if wc > 1 || self.rx.len() == 0 {
                    if self.inner.state.compare_and_dec_worker_count(state) {
                        // Run the `before_stop` hook
                        self.inner.config.before_stop.as_ref().map(|f| f());

                        // This can never be a termination state since the
                        // lifecycle is not Stop or Terminate (checked above) and
                        // the queue has not been accessed, so it is unknown
                        // whether or not there is a pending task.
                        //
                        // This means that there is no need to call
                        // `finalize_thread_pool`
                        return None;
                    }

                    // CAS failed, restart loop
                    continue;
                }
            }

            match self.recv_task(timeout) {
                Ok(t) => {
                    // Grab the task, but the loop will restart in order to
                    // check the state again. If the state transitioned to Stop
                    // while the worker was blocked on the queue, the task
                    // should be discarded and the worker shut down.
                    task = Some(t);
                }
                Err(RecvTimeoutError::Disconnected) => {
                    // Run the `before_stop` hook
                    self.inner.config.before_stop.as_ref().map(|f| f());

                    // No more tasks should be removed from the queue, exit the
                    // worker
                    self.decrement_worker_count();

                    // Nothing else to do
                    return None;
                }
                Err(RecvTimeoutError::Timeout) => {
                    timed_out = true;
                }
            }
        }

        task
    }

    fn recv_task(&self, timeout: Option<Duration>) -> Result<T, RecvTimeoutError> {
        match timeout {
            Some(timeout) => self.rx.recv_timeout(timeout),
            None => self.rx.recv().map_err(|_| RecvTimeoutError::Disconnected),
        }
    }

    fn decrement_worker_count(&self) {
        // `fetch_dec_worker_count` returns the state prior to the decrement,
        // so a previous count of 1 means this was the last worker to exit.
        let state = self.inner.state.fetch_dec_worker_count();

        if state.worker_count() == 1 && !self.rx.is_open() {
            self.inner.finalize_thread_pool();
        }
    }
}