1use std::thread::{self, JoinHandle};
2
3use crossbeam::channel::{self as mpmc, Receiver, Sender};
4use once_cell::sync::Lazy;
5
6use crate::{
7 error::{Error, InvalidArgumentError},
8 sink::{OverflowPolicy, Task},
9 sync::*,
10 Result,
11};
12
13pub struct ThreadPool(ArcSwapOption<ThreadPoolInner>);
35
36struct ThreadPoolInner {
37 threads: Vec<Option<JoinHandle<()>>>,
38 sender: Option<Sender<Task>>,
39}
40
41type Callback = Arc<dyn Fn() + Send + Sync + 'static>;
42
43#[allow(missing_docs)]
44pub struct ThreadPoolBuilder {
45 capacity: usize,
46 threads: usize,
47 on_thread_spawn: Option<Callback>,
48 on_thread_finish: Option<Callback>,
49}
50
51struct Worker {
52 receiver: Receiver<Task>,
53}
54
55impl ThreadPool {
56 #[must_use]
68 pub fn builder() -> ThreadPoolBuilder {
69 ThreadPoolBuilder {
70 capacity: 8192,
71 threads: 1,
72 on_thread_spawn: None,
73 on_thread_finish: None,
74 }
75 }
76
77 pub fn new() -> Result<Self> {
80 Self::builder().build()
81 }
82
83 pub(super) fn assign_task(&self, task: Task, overflow_policy: OverflowPolicy) -> Result<()> {
84 let inner = self.0.load();
85 let sender = inner.as_ref().unwrap().sender.as_ref().unwrap();
86
87 match overflow_policy {
88 OverflowPolicy::Block => sender.send(task).map_err(Error::from_crossbeam_send),
89 OverflowPolicy::DropIncoming => sender
90 .try_send(task)
91 .map_err(Error::from_crossbeam_try_send),
92 }
93 }
94
95 pub(super) fn destroy(&self) {
96 if let Some(mut inner) = self.0.swap(None) {
97 let inner = Arc::get_mut(&mut inner).unwrap();
99
100 inner.sender.take();
103
104 for thread in &mut inner.threads {
105 if let Some(thread) = thread.take() {
106 thread.join().expect("failed to join a thread from pool");
107 }
108 }
109 }
110 }
111}
112
113impl Drop for ThreadPool {
114 fn drop(&mut self) {
115 self.destroy();
116 }
117}
118
119impl ThreadPoolBuilder {
120 #[must_use]
132 pub fn capacity(&mut self, capacity: usize) -> &mut Self {
133 self.capacity = capacity;
134 self
135 }
136
137 #[must_use]
143 #[allow(dead_code)]
144 fn threads(&mut self, threads: usize) -> &mut Self {
145 self.threads = threads;
146 self
147 }
148
149 #[must_use]
153 pub fn on_thread_spawn<F>(&mut self, f: F) -> &mut Self
154 where
155 F: Fn() + Send + Sync + 'static,
156 {
157 self.on_thread_spawn = Some(Arc::new(f));
158 self
159 }
160
161 #[must_use]
164 pub fn on_thread_finish<F>(&mut self, f: F) -> &mut Self
165 where
166 F: Fn() + Send + Sync + 'static,
167 {
168 self.on_thread_finish = Some(Arc::new(f));
169 self
170 }
171
172 pub fn build(&self) -> Result<ThreadPool> {
174 if self.capacity < 1 {
175 return Err(Error::InvalidArgument(
176 InvalidArgumentError::ThreadPoolCapacity("cannot be 0".to_string()),
177 ));
178 }
179
180 if self.threads < 1 {
181 panic!("threads of ThreadPool cannot be 0");
184 }
185
186 let (sender, receiver) = mpmc::bounded(self.capacity);
187
188 let mut threads = Vec::new();
189 threads.resize_with(self.threads, || {
190 let receiver = receiver.clone();
191 let on_thread_spawn = self.on_thread_spawn.clone();
192 let on_thread_finish = self.on_thread_finish.clone();
193
194 Some(thread::spawn(move || {
195 if let Some(f) = on_thread_spawn {
196 f();
197 }
198
199 Worker { receiver }.run();
200
201 if let Some(f) = on_thread_finish {
202 f();
203 }
204 }))
205 });
206
207 Ok(ThreadPool(ArcSwapOption::new(Some(Arc::new(
208 ThreadPoolInner {
209 threads,
210 sender: Some(sender),
211 },
212 )))))
213 }
214}
215
216impl Worker {
217 fn run(&self) {
218 while let Ok(task) = self.receiver.recv() {
219 task.exec();
220 }
221 }
222}
223
224#[must_use]
225pub(crate) fn default_thread_pool() -> Arc<ThreadPool> {
226 static POOL_WEAK: Lazy<Mutex<Weak<ThreadPool>>> = Lazy::new(|| Mutex::new(Weak::new()));
227
228 let mut pool_weak = POOL_WEAK.lock_expect();
229
230 match pool_weak.upgrade() {
231 Some(pool) => pool,
232 None => {
233 let pool = Arc::new(ThreadPool::builder().build().unwrap());
234 *pool_weak = Arc::downgrade(&pool);
235 pool
236 }
237 }
238}
239
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn error_capacity_0() {
        let result = ThreadPool::builder().capacity(0).build();
        assert!(matches!(
            result,
            Err(Error::InvalidArgument(
                InvalidArgumentError::ThreadPoolCapacity(_)
            ))
        ));
    }

    #[test]
    #[should_panic(expected = "threads of ThreadPool cannot be 0")]
    fn panic_thread_0() {
        let _ = ThreadPool::builder().threads(0).build();
    }

    #[test]
    fn hooks_run_once_per_thread_and_drop_joins_workers() {
        use std::sync::atomic::{AtomicUsize, Ordering};

        static SPAWNED: AtomicUsize = AtomicUsize::new(0);
        static FINISHED: AtomicUsize = AtomicUsize::new(0);

        let pool = ThreadPool::builder()
            .threads(3)
            .on_thread_spawn(|| {
                SPAWNED.fetch_add(1, Ordering::SeqCst);
            })
            .on_thread_finish(|| {
                FINISHED.fetch_add(1, Ordering::SeqCst);
            })
            .build();
        assert!(pool.is_ok());

        // Dropping the pool disconnects the channel and joins every worker, so all
        // hooks must have run by the time `drop` returns.
        drop(pool);

        assert_eq!(SPAWNED.load(Ordering::SeqCst), 3);
        assert_eq!(FINISHED.load(Ordering::SeqCst), 3);
    }
}