use std::{
    num::NonZeroUsize,
    thread::{self, JoinHandle},
};

use crossbeam::channel::{self as mpmc, Receiver, Sender};
use once_cell::sync::Lazy;

use crate::{
    error::Error,
    sink::{OverflowPolicy, Task},
    sync::*,
    Result,
};

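/// A thread pool that executes sink [`Task`]s on background worker threads.
///
/// Tasks are delivered to the workers through a bounded MPMC channel; see
/// [`ThreadPoolBuilder`] for the configurable parameters.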
pub struct ThreadPool(ArcSwapOption<ThreadPoolInner>);

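// Shared pool state: the worker join handles plus the sending half of the
// task channel. Both are wrapped in `Option` so they can be taken during
// shutdown.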
struct ThreadPoolInner {
    threads: Vec<Option<JoinHandle<()>>>,
    sender: Option<Sender<Task>>,
}

type Callback = Arc<dyn Fn() + Send + Sync + 'static>;

#[allow(missing_docs)]
pub struct ThreadPoolBuilder {
    capacity: NonZeroUsize,
    threads: NonZeroUsize,
    on_thread_spawn: Option<Callback>,
    on_thread_finish: Option<Callback>,
}

struct Worker {
    receiver: Receiver<Task>,
}

impl ThreadPool {
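    /// Creates a [`ThreadPoolBuilder`] with default parameters
    /// (a capacity of 8192 pending tasks and a single worker thread).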
    #[must_use]
    pub fn builder() -> ThreadPoolBuilder {
        ThreadPoolBuilder {
            capacity: NonZeroUsize::new(8192).unwrap(),
            threads: NonZeroUsize::new(1).unwrap(),
            on_thread_spawn: None,
            on_thread_finish: None,
        }
    }

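    /// Constructs a `ThreadPool` with the default builder parameters.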
    pub fn new() -> Result<Self> {
        Self::builder().build()
    }

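    // Sends a task to the worker threads. When the channel is full,
    // `OverflowPolicy::Block` blocks the calling thread until there is room,
    // while `OverflowPolicy::DropIncoming` drops the incoming task and
    // returns an error. Calling this after `destroy` panics, since the inner
    // state (and thus the sender) has already been taken.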
    pub(super) fn assign_task(&self, task: Task, overflow_policy: OverflowPolicy) -> Result<()> {
        let inner = self.0.load();
        let sender = inner.as_ref().unwrap().sender.as_ref().unwrap();

        match overflow_policy {
            OverflowPolicy::Block => sender.send(task).map_err(Error::from_crossbeam_send),
            OverflowPolicy::DropIncoming => sender
                .try_send(task)
                .map_err(Error::from_crossbeam_try_send),
        }
    }

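    // Shuts the pool down: dropping the sender closes the channel so that the
    // workers' receive loops end once the remaining tasks are drained, after
    // which every worker thread is joined.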
    pub(super) fn destroy(&self) {
        if let Some(mut inner) = self.0.swap(None) {
            // Take exclusive access to the inner state; this panics if other
            // references to it are still alive.
            let inner = Arc::get_mut(&mut inner).unwrap();

            // Drop the sender to close the channel, signalling the workers to
            // exit after processing any remaining tasks.
            inner.sender.take();

            for thread in &mut inner.threads {
                if let Some(thread) = thread.take() {
                    thread.join().expect("failed to join a thread from pool");
                }
            }
        }
    }
}

impl Drop for ThreadPool {
    fn drop(&mut self) {
        self.destroy();
    }
}

impl ThreadPoolBuilder {
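    /// Sets the capacity of the task channel, i.e. the maximum number of
    /// pending tasks; what happens when it is exceeded depends on the
    /// [`OverflowPolicy`].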
    #[must_use]
    pub fn capacity(&mut self, capacity: NonZeroUsize) -> &mut Self {
        self.capacity = capacity;
        self
    }

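    // Sets the number of worker threads. Currently not exposed publicly.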
    #[must_use]
    #[allow(dead_code)]
    fn threads(&mut self, threads: NonZeroUsize) -> &mut Self {
        self.threads = threads;
        self
    }

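    /// Registers a callback that each worker thread runs once when it is
    /// spawned, before it starts processing tasks.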
    #[must_use]
    pub fn on_thread_spawn<F>(&mut self, f: F) -> &mut Self
    where
        F: Fn() + Send + Sync + 'static,
    {
        self.on_thread_spawn = Some(Arc::new(f));
        self
    }

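    /// Registers a callback that each worker thread runs once just before it
    /// finishes, after its task loop has ended.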
    #[must_use]
    pub fn on_thread_finish<F>(&mut self, f: F) -> &mut Self
    where
        F: Fn() + Send + Sync + 'static,
    {
        self.on_thread_finish = Some(Arc::new(f));
        self
    }

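    /// Builds the [`ThreadPool`]: creates the bounded task channel and spawns
    /// the configured number of worker threads.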
    pub fn build(&self) -> Result<ThreadPool> {
        let (sender, receiver) = mpmc::bounded(self.capacity.get());

        let mut threads = Vec::new();
        threads.resize_with(self.threads.get(), || {
            let receiver = receiver.clone();
            let on_thread_spawn = self.on_thread_spawn.clone();
            let on_thread_finish = self.on_thread_finish.clone();

            Some(thread::spawn(move || {
                if let Some(f) = on_thread_spawn {
                    f();
                }

                Worker { receiver }.run();

                if let Some(f) = on_thread_finish {
                    f();
                }
            }))
        });

        Ok(ThreadPool(ArcSwapOption::new(Some(Arc::new(
            ThreadPoolInner {
                threads,
                sender: Some(sender),
            },
        )))))
    }
}

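// A worker keeps pulling tasks off the channel and executing them until the
// channel is closed (all senders dropped) and drained.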
impl Worker {
    fn run(&self) {
        while let Ok(task) = self.receiver.recv() {
            task.exec();
        }
    }
}

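// Returns the shared default thread pool, creating it lazily on first use.
// Only a `Weak` reference is cached, so the pool is dropped (and its threads
// joined) once the last `Arc` handed out here goes away; a fresh pool is then
// created on the next call.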
#[must_use]
pub(crate) fn default_thread_pool() -> Arc<ThreadPool> {
    static POOL_WEAK: Lazy<Mutex<Weak<ThreadPool>>> = Lazy::new(|| Mutex::new(Weak::new()));

    let mut pool_weak = POOL_WEAK.lock_expect();

    match pool_weak.upgrade() {
        Some(pool) => pool,
        None => {
            let pool = Arc::new(ThreadPool::builder().build().unwrap());
            *pool_weak = Arc::downgrade(&pool);
            pool
        }
    }
}