threadpool_executor/
threadpool.rs

1use super::error::*;
2use super::*;
3use std::{
4    collections::HashMap,
5    panic::UnwindSafe,
6    sync::{
7        atomic::{AtomicUsize, Ordering},
8        Arc, Mutex,
9    },
10    thread,
11    time::Duration,
12};
13
14pub(super) type Job = Box<dyn FnOnce() + Send + 'static>;
15
///
/// The policy applied when a task is submitted while the number of working tasks
/// has already reached the `maximum_pool_size` of the `ThreadPool`.
///
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ExceedLimitPolicy {
    ///
    /// The task is queued and runs as soon as a worker becomes idle.
    ///
    Wait,
    ///
    /// The task is rejected: `execute` returns an `ExecutorError` whose kind is
    /// `TaskRejected`.
    ///
    /// ```
    /// let pool = threadpool_executor::threadpool::Builder::new()
    ///     .core_pool_size(1)
    ///     .maximum_pool_size(1)
    ///     .exeed_limit_policy(threadpool_executor::threadpool::ExceedLimitPolicy::Reject)
    ///     .build();
    /// let res = pool.execute(|| {
    ///     std::thread::sleep(std::time::Duration::from_secs(1));
    /// });
    /// assert!(res.is_ok());
    /// let res = pool.execute(|| "a");
    /// assert!(res.is_err());
    /// if let Err(err) = res {
    ///     assert!(matches!(err.kind(), threadpool_executor::error::ErrorKind::TaskRejected));
    /// }
    /// ```
    ///
    Reject,
    ///
    /// The task runs immediately on the caller's thread, inside `execute`.
    ///
    CallerRuns,
}
51
52pub struct Builder {
53    core_pool_size: Option<usize>,
54    maximum_pool_size: Option<usize>,
55    exeed_limit_policy: Option<ExceedLimitPolicy>,
56    keep_alive_time: Option<Duration>,
57}
58
59impl Builder {
60    const DEFALUT_KEEP_ALIVE_SEC: u64 = 300;
61
62    ///
63    /// A builder use to build a `ThreadPool`
64    ///
65    /// # Example
66    ///
67    /// ```
68    /// let pool = threadpool_executor::threadpool::Builder::new()
69    /// .core_pool_size(1)
70    /// .maximum_pool_size(3)
71    /// .keep_alive_time(std::time::Duration::from_secs(300)) // None-core-thread keep_alive_time, default value is 5 minutes.
72    /// .exeed_limit_policy(threadpool_executor::threadpool::ExceedLimitPolicy::Reject) // Default value is Wait.
73    /// .build();
74    /// ```
75    ///
76    pub fn new() -> Builder {
77        Builder {
78            core_pool_size: None,
79            maximum_pool_size: None,
80            exeed_limit_policy: Some(ExceedLimitPolicy::Wait),
81            keep_alive_time: Some(Duration::from_secs(Builder::DEFALUT_KEEP_ALIVE_SEC)),
82        }
83    }
84
85    ///
86    /// Core threads will run until the threadpool dropped.
87    ///
88    /// # Example
89    ///
90    /// ```
91    /// let pool = threadpool_executor::threadpool::Builder::new()
92    /// .core_pool_size(1)
93    /// .build();
94    /// ```
95    ///
96    pub fn core_pool_size(mut self, size: usize) -> Builder {
97        self.core_pool_size = Some(size);
98        self
99    }
100
101    ///
102    /// Maximum threads that run in this threadpool, include the core threads,
103    /// the size of the none-core threads = `maximum_pool_size - core_pool_size`.
104    ///
105    /// None core threads will live with a given `keep_alive_time`. If the `keep_alive_time`
106    /// is not set, it will default to 5 minutes.
107    ///
108    /// # Example
109    ///
110    /// ```
111    /// let pool = threadpool_executor::threadpool::Builder::new()
112    /// .core_pool_size(1)
113    /// .maximum_pool_size(3)
114    /// .build();
115    /// ```
116    ///
117    pub fn maximum_pool_size(mut self, size: usize) -> Builder {
118        assert!(size > 0);
119        self.maximum_pool_size = Some(size);
120        self
121    }
122
123    ///
124    /// When the threads are all working, the new tasks coming will follow the given policy
125    ///
126    /// # Example
127    ///
128    /// ```
129    /// let pool = threadpool_executor::threadpool::Builder::new()
130    /// .core_pool_size(1)
131    /// .maximum_pool_size(1)
132    /// .exeed_limit_policy(threadpool_executor::threadpool::ExceedLimitPolicy::Reject)
133    /// .build();
134    /// let res = pool.execute(|| {
135    ///     std::thread::sleep(std::time::Duration::from_secs(3));
136    /// });
137    /// assert!(res.is_ok());
138    /// let res = pool.execute(|| "a");
139    /// assert!(res.is_err());
140    /// if let Err(err) = res {
141    ///     matches!(err.kind(), threadpool_executor::error::ErrorKind::TaskRejected);
142    /// }
143    /// ```
144    ///
145    pub fn exeed_limit_policy(mut self, policy: ExceedLimitPolicy) -> Builder {
146        self.exeed_limit_policy = Some(policy);
147        self
148    }
149
150    ///
151    /// None core threads will live with a given `keep_alive_time`. If the `keep_alive_time`
152    /// is not set, it will default to 5 minutes.
153    ///
154    /// # Example
155    ///
156    /// ```
157    /// let pool = threadpool_executor::threadpool::Builder::new()
158    /// .core_pool_size(1)
159    /// .maximum_pool_size(3)
160    /// .keep_alive_time(std::time::Duration::from_secs(60))
161    /// .build();
162    /// ```
163    ///
164    pub fn keep_alive_time(mut self, keep_alive_time: Duration) -> Builder {
165        assert!(!keep_alive_time.is_zero());
166        self.keep_alive_time = Some(keep_alive_time);
167        self
168    }
169
170    pub fn build(self) -> ThreadPool {
171        let init_size = match self.core_pool_size {
172            Some(size) => size,
173            None => 0,
174        };
175        let max_size = match self.maximum_pool_size {
176            Some(size) => size,
177            None => usize::MAX,
178        };
179        let policy = match self.exeed_limit_policy {
180            Some(policy) => policy,
181            None => ExceedLimitPolicy::Wait,
182        };
183        ThreadPool::create(init_size, max_size, policy, self.keep_alive_time)
184    }
185}
186
187impl ThreadPool {
188    ///
189    /// Create a fix size thread pool with the Polic `ExceedLimitPolicy::Wait`
190    ///
191    /// # Example
192    ///
193    /// ```
194    /// use threadpool_executor::ThreadPool;
195    ///
196    /// let pool = ThreadPool::new(1);
197    /// pool.execute(|| {println!("hello, world!");});
198    /// ```
199    ///
200    pub fn new(size: usize) -> ThreadPool {
201        ThreadPool::create(size, size, ExceedLimitPolicy::Wait, None)
202    }
203
204    fn create(
205        core_size: usize,
206        max_size: usize,
207        policy: ExceedLimitPolicy,
208        keep_alive_time: Option<Duration>,
209    ) -> ThreadPool {
210        assert!(max_size > 0);
211        assert!(max_size >= core_size);
212
213        let (task_sender, task_receiver) = crossbeam_channel::unbounded();
214
215        let (task_status_sender, task_status_receiver) = crossbeam_channel::unbounded();
216
217        let mut workers = HashMap::new();
218        for id in 0..core_size {
219            workers.insert(
220                id,
221                Worker::new(id, task_receiver.clone(), None, task_status_sender.clone()),
222            );
223        }
224
225        let worker_count = Arc::new(AtomicUsize::new(core_size));
226        let working_count = Arc::new(AtomicUsize::new(0));
227
228        let workers = Arc::new(Mutex::new(workers));
229        let ws = Arc::clone(&workers);
230
231        let wkc = Arc::clone(&worker_count);
232        let wkingc = Arc::clone(&working_count);
233
234        let m_thread = thread::Builder::new()
235            .name("thead-pool-cleaner".to_string())
236            .spawn(move || loop {
237                match task_status_receiver.recv() {
238                    Ok(id) => {
239                        log::debug!("receive task[#{:?}] status: {:?}", id.0, id.1);
240                        match id.1 {
241                            WorkerStatus::ThreadExit => {
242                                drop(ws.lock().unwrap().remove(&id.0));
243                                wkc.fetch_sub(1, Ordering::Relaxed);
244                            }
245                            WorkerStatus::JobDone => {
246                                wkingc.fetch_sub(1, Ordering::Relaxed);
247                            }
248                        }
249                    }
250                    Err(_) => {
251                        log::debug!("All sender is close, exit this thread.");
252                        break;
253                    }
254                }
255            })
256            .unwrap();
257
258        ThreadPool {
259            current_id: AtomicUsize::new(core_size),
260            workers,
261            worker_count,
262            working_count,
263            task_sender: Some(task_sender),
264            task_receiver,
265            worker_status_sender: Some(task_status_sender),
266            m_thread: Some(m_thread),
267            max_size,
268            policy,
269            keep_alive_time,
270        }
271    }
272
273    ///
274    /// Execute a closure in the threadpool, return a `Result` indicating whether the `submit` operation succeeded or not.
275    ///
276    /// `Submit` operation will fail when the pool reach to `the maximum_pool_size` and the `exeed_limit_policy` is set to `Reject`.
277    ///
278    /// You can get a `Expectation<T>` when `Result` is `Ok`, `T` here is the return type of your closure.
279    ///
280    /// You can use `get_result` or `get_result_timeout` method in the `Expectation` object to get the result of your closure. The
281    /// two method above will block when the result is returned or timeout.
282    ///
283    /// `Expectation::get_result` and `Expectation::get_result_timeout` return a `Result` which will return the return value of your
284    /// closure when `Ok`, and `Err` will be returned when your closure `panic`.
285    ///
286    /// # Example
287    ///
288    /// ```
289    /// let pool = threadpool_executor::ThreadPool::new(1);
290    /// let exp = pool.execute(|| 1 + 2);
291    /// assert_eq!(exp.unwrap().get_result().unwrap(), 3);
292    /// ```
293    ///
294    /// When `panic`:
295    ///
296    /// ```
297    /// let pool = threadpool_executor::ThreadPool::new(1);
298    /// let exp = pool.execute(|| {
299    ///     panic!("panic!!!");
300    /// });
301    /// let res = exp.unwrap().get_result();
302    /// assert!(res.is_err());
303    /// if let Err(err) = res {
304    ///     matches!(err.kind(), threadpool_executor::error::ErrorKind::Panic);
305    /// }
306    /// ```
307    ///
308    pub fn execute<F, T>(&self, f: F) -> Result<Expectation<T>, ExecutorError>
309    where
310        F: FnOnce() -> T + Send + UnwindSafe + 'static,
311        T: Send + 'static,
312    {
313        let (result_sender, result_receiver) = crossbeam_channel::unbounded();
314
315        let task_cancelled = Arc::new(AtomicBool::new(false));
316        let task_started = Arc::new(AtomicBool::new(false));
317        let task_done = Arc::new(AtomicBool::new(false));
318
319        let job_cancelled = task_cancelled.clone();
320        let job_started = task_started.clone();
321        let job_done = task_done.clone();
322        let job = move || {
323            if job_cancelled.load(Ordering::Relaxed) {
324                log::debug!("Job is cancelled!");
325                return;
326            }
327            job_started.store(true, Ordering::Relaxed);
328            if let Err(_) = result_sender.send(std::panic::catch_unwind(f)) {
329                log::debug!("Cannot send res to receiver, receiver may close. ");
330            }
331            job_done.store(true, Ordering::Relaxed);
332        };
333
334        let worker_count = self.worker_count.load(Ordering::Relaxed);
335        let working_count = self.working_count.load(Ordering::Relaxed);
336        log::debug!(
337            "workers {}, working {}, max: {}",
338            worker_count,
339            working_count,
340            self.max_size
341        );
342        if working_count >= self.max_size {
343            log::debug!(
344                "Working tasks reach the max size. use policy {:?}",
345                self.policy
346            );
347            match self.policy {
348                ExceedLimitPolicy::Wait => {}
349                ExceedLimitPolicy::Reject => {
350                    return Err(ExecutorError::new(
351                        ErrorKind::TaskRejected,
352                        "Working tasks reaches to the limit.".to_string(),
353                    ));
354                }
355                ExceedLimitPolicy::CallerRuns => {
356                    log::debug!("Run the task at the caller's thread. run now.");
357                    job();
358                    return Ok(Expectation {
359                        task_cancelled,
360                        task_started,
361                        task_done,
362                        result_receiver: Some(result_receiver),
363                    });
364                }
365            };
366        }
367        if working_count >= worker_count && working_count < self.max_size {
368            if let Ok(mut workers) = self.workers.lock() {
369                let id = self.current_id.fetch_add(1, Ordering::Relaxed);
370                let task_status_sender = match self.worker_status_sender.clone() {
371                    Some(sender) => sender,
372                    None => {
373                        return Err(ExecutorError::new(
374                            ErrorKind::PoolEnded,
375                            "This threadpool is already dropped.".to_string(),
376                        ));
377                    }
378                };
379                self.worker_count.fetch_add(1, Ordering::Relaxed);
380                workers.insert(
381                    id,
382                    Worker::new(
383                        id,
384                        self.task_receiver.clone(),
385                        self.keep_alive_time.clone(),
386                        task_status_sender,
387                    ),
388                );
389            }
390        }
391        self.working_count.fetch_add(1, Ordering::Relaxed);
392
393        if let Ok(_) = self.task_sender.as_ref().unwrap().send(Box::new(job)) {
394            Ok(Expectation {
395                task_cancelled,
396                task_started,
397                task_done,
398                result_receiver: Some(result_receiver),
399            })
400        } else {
401            Err(ExecutorError::new(
402                ErrorKind::PoolEnded,
403                "Cannot send message to worker thread, This threadpool is already dropped."
404                    .to_string(),
405            ))
406        }
407    }
408
409    ///
410    /// Returen the current workers' size.
411    ///
412    pub fn size(&self) -> usize {
413        self.workers.lock().unwrap().len()
414    }
415}
416
417impl Drop for ThreadPool {
418    fn drop(&mut self) {
419        log::debug!("Dropping thread pool...");
420        // drop the sender, so the receiver in workers will receiv error and then break the loop and join the thread.
421        drop(self.task_sender.take());
422        // drop the original task sender. After all senders dropped, the cleanr thread will receive error, and then break and joint.
423        drop(self.worker_status_sender.take());
424        if let Some(thread) = self.m_thread.take() {
425            if let Err(_) = thread.join() {}
426        }
427    }
428}
429
430pub(super) struct Worker {
431    id: usize,
432    thread: Option<thread::JoinHandle<()>>,
433}
434
435#[derive(Debug)]
436pub(super) enum WorkerStatus {
437    JobDone,
438    ThreadExit,
439}
440
441impl Worker {
442    fn run_in_thread(
443        id: usize,
444        task_receiver: crossbeam_channel::Receiver<Job>,
445        wait_time_out: Option<Duration>,
446        task_status_sender: crossbeam_channel::Sender<(usize, WorkerStatus)>,
447    ) {
448        loop {
449            let job = if let Some(timeout) = wait_time_out {
450                log::debug!("Worker[#{:?}] will wait for time {:?}", id, timeout);
451                match task_receiver.recv_timeout(timeout) {
452                    Ok(job) => job,
453                    Err(_) => {
454                        log::debug!(
455                            "Worker[#{:?}] receive message timeout, release this worker.",
456                            id
457                        );
458                        break;
459                    }
460                }
461            } else {
462                match task_receiver.recv() {
463                    Ok(job) => job,
464                    Err(_) => {
465                        log::debug!("Worker[#{:?}] Chanel sender may disconnect, receive job, error. exit this loop .", id);
466                        break;
467                    }
468                }
469            };
470
471            log::debug!("Worker[#{:?}] gets the job now, run the job.", id);
472
473            job();
474
475            if let Err(_) = task_status_sender.send((id, WorkerStatus::JobDone)) {
476                log::debug!(
477                    "Worker[#{:?}] Send worder staus error, receiver may close.",
478                    id
479                );
480            };
481        }
482        if let Err(_) = task_status_sender.send((id, WorkerStatus::ThreadExit)) {
483            log::debug!(
484                "Worker[#{:?}] Send worder staus error, receiver may close.",
485                id
486            );
487        }
488    }
489
490    fn new(
491        id: usize,
492        task_receiver: crossbeam_channel::Receiver<Job>,
493        wait_time_out: Option<Duration>,
494        task_status_sender: crossbeam_channel::Sender<(usize, WorkerStatus)>,
495    ) -> Worker {
496        let thread = thread::Builder::new()
497            .name("thread-pool-worker-".to_string() + id.to_string().as_str())
498            .spawn(move || {
499                Worker::run_in_thread(id, task_receiver, wait_time_out, task_status_sender)
500            })
501            .unwrap();
502        Worker {
503            id,
504            thread: Some(thread),
505        }
506    }
507}
508
509impl Drop for Worker {
510    fn drop(&mut self) {
511        if let Some(t) = self.thread.take() {
512            log::debug!("Dropping worker {:?}", self.id);
513            if let Err(err) = t.join() {
514                log::debug!("Drop worker... close thread ...error, {:?}", err)
515            } else {
516                log::debug!("Drop worker {:?} successfully!", self.id);
517            }
518        }
519        log::debug!("Worker {:?} is dropped.", self.id)
520    }
521}
522
523impl<T> Expectation<T>
524where
525    T: Send + 'static,
526{
527    ///
528    /// Show whether the task is cancelled.
529    ///
530    pub fn is_cancelled(&self) -> bool {
531        self.task_cancelled.load(Ordering::Relaxed)
532    }
533
534    ///
535    /// Show whether the task is done.
536    ///
537    pub fn is_done(&self) -> bool {
538        self.task_done.load(Ordering::Relaxed)
539    }
540
541    ///
542    /// Cancel the task when the task is still waiting in line.
543    ///
544    /// If the task is started running, `Err` will be return.
545    ///
546    /// If the task is already cancel, `Err` will be return.
547    ///
548    pub fn cancel(&mut self) -> Result<(), ExecutorError> {
549        if self.task_done.load(Ordering::Relaxed) {
550            return Err(ExecutorError::new(
551                ErrorKind::TaskRunning,
552                "Task has been done. ".to_string(),
553            ));
554        }
555
556        if self.task_started.load(Ordering::Relaxed) {
557            return Err(ExecutorError::new(
558                ErrorKind::TaskRunning,
559                "Task is already running, cannot stop now. ".to_string(),
560            ));
561        }
562
563        if self.task_cancelled.load(Ordering::Relaxed) {
564            return Err(ExecutorError::new(
565                ErrorKind::TaskCancelled,
566                "Task has been cancelled. ".to_string(),
567            ));
568        }
569        self.task_cancelled.store(true, Ordering::Relaxed);
570
571        match self.result_receiver.take() {
572            Some(receiver) => drop(receiver),
573            None => {
574                return Err(ExecutorError::new(
575                    ErrorKind::ResultAlreadyTaken,
576                    "Result is already taken.".to_string(),
577                ));
578            }
579        }
580        Ok(())
581    }
582
583    /// This method returns a `Result` which will return the return value of your
584    /// closure when `Ok`, and `Err` will be returned when your closure `panic`.
585    ///
586    /// # Example
587    ///
588    /// ```
589    /// let pool = threadpool_executor::ThreadPool::new(1);
590    /// let exp = pool.execute(|| 1 + 2);
591    /// assert_eq!(exp.unwrap().get_result().unwrap(), 3);
592    /// ```
593    ///
594    /// When `panic`:
595    ///
596    /// ```
597    /// let pool = threadpool_executor::ThreadPool::new(1);
598    /// let exp = pool.execute(|| {
599    ///     panic!("panic!!!");
600    /// });
601    /// let res = exp.unwrap().get_result();
602    /// assert!(res.is_err());
603    /// if let Err(err) = res {
604    ///     matches!(err.kind(), threadpool_executor::error::ErrorKind::Panic);
605    /// }
606    /// ```
607    ///
608    pub fn get_result(&mut self) -> Result<T, ExecutorError> {
609        if self.task_cancelled.load(Ordering::Relaxed) {
610            return Err(ExecutorError::new(
611                ErrorKind::TaskCancelled,
612                "Task has been cancelled. ".to_string(),
613            ));
614        }
615        if let Some(receiver) = self.result_receiver.take() {
616            match receiver.recv() {
617                Ok(res) => match res {
618                    Ok(res) => Ok(res),
619                    Err(cause) => Err(ExecutorError::with_cause(
620                        ErrorKind::Panic,
621                        "Function panic!".to_string(),
622                        cause,
623                    )),
624                },
625                Err(_) => Err(ExecutorError::new(
626                    ErrorKind::PoolEnded,
627                    "Cannot receive message from the  worker thread, This threadpool is already dropped."
628                        .to_string(),
629                )),
630            }
631        } else {
632            log::debug!("Receive result error! Result may be taken!");
633            Err(ExecutorError::new(
634                ErrorKind::ResultAlreadyTaken,
635                "Result is already taken.".to_string(),
636            ))
637        }
638    }
639
640    /// This method returns a `Result` which will return the return value of your
641    /// closure when `Ok`, and `Err` will be returned when your closure `panic` or
642    /// `timeout`.
643    ///
644    /// # Example
645    ///
646    /// ```
647    /// let pool = threadpool_executor::ThreadPool::new(1);
648    /// let exp = pool.execute(|| 1 + 2);
649    /// assert_eq!(exp.unwrap().get_result_timeout(std::time::Duration::from_secs(1)).unwrap(), 3);
650    /// ```
651    ///
652    /// When `timeout`:
653    ///
654    /// ```
655    /// use std::time::Duration;
656    /// let pool = threadpool_executor::ThreadPool::new(1);
657    /// let exp = pool.execute(|| {
658    ///     std::thread::sleep(Duration::from_secs(3));
659    /// });
660    /// let res = exp.unwrap().get_result_timeout(Duration::from_secs(1));
661    /// assert!(res.is_err());
662    /// if let Err(err) = res {
663    ///     matches!(err.kind(), threadpool_executor::error::ErrorKind::TimeOut);
664    /// }
665    /// ```
666    ///
667    pub fn get_result_timeout(&mut self, timeout: Duration) -> Result<T, ExecutorError> {
668        if self.task_cancelled.load(Ordering::Relaxed) {
669            return Err(ExecutorError::new(
670                ErrorKind::TaskCancelled,
671                "Task has been cancelled. ".to_string(),
672            ));
673        }
674        if let Some(receiver) = self.result_receiver.take() {
675            match receiver.recv_timeout(timeout) {
676                Ok(res) => match res {
677                    Ok(res) => Ok(res),
678                    Err(cause) => Err(ExecutorError::with_cause(
679                        ErrorKind::Panic,
680                        "Function panic!".to_string(),
681                        cause,
682                    )),
683                },
684                Err(_) => Err(ExecutorError::new(
685                    ErrorKind::TimeOut,
686                    "Receive result timeout.".to_string(),
687                )),
688            }
689        } else {
690            Err(ExecutorError::new(
691                ErrorKind::ResultAlreadyTaken,
692                "Result is already taken.".to_string(),
693            ))
694        }
695    }
696}
697
698impl<T> Drop for Expectation<T> {
699    fn drop(&mut self) {
700        drop(self.result_receiver.take());
701    }
702}