slave_pool/
lib.rs

1//! Simple thread pool
2//!
3//! # Usage
4//!
5//! ```rust,no_run
6//! use slave_pool::ThreadPool;
7//! const SECOND: core::time::Duration = core::time::Duration::from_secs(1);
8//!
9//! static POOL: ThreadPool = ThreadPool::new();
10//!
11//! POOL.set_threads(8); //Tell how many threads you want
12//!
13//! let mut handles = Vec::new();
14//! for idx in 0..8 {
15//!     handles.push(POOL.spawn_handle(move || {
16//!         std::thread::sleep(SECOND);
17//!         idx
18//!     }));
19//! }
20//!
//! POOL.set_threads(0); //Tell it to shut down all threads
22//!
23//! for (idx, handle) in handles.drain(..).enumerate() {
//!     assert_eq!(handle.wait().unwrap(), idx) //Even though we told it to shut down all threads, it finishes queued jobs first
25//! }
26//!
27//! let handle = POOL.spawn_handle(|| {});
//! assert!(handle.wait_timeout(SECOND).is_err()); // All threads are shut down now, so nothing runs the job
29//!
30//! POOL.set_threads(1); //But let's add one more
31//!
32//! assert!(handle.wait().is_ok());
33//!
34//! let handle = POOL.spawn_handle(|| panic!("Oh no!")); // We can panic, if we want
35//!
//! assert!(handle.wait().is_err()); // In that case we get an error, but the pool keeps working (the worker is replaced)
37//!
38//! let handle = POOL.spawn_handle(|| {});
39//!
40//! POOL.set_threads(0);
41//!
42//! assert!(handle.wait().is_ok());
43//! std::thread::sleep(SECOND);
44//! ```
45
46#![warn(missing_docs)]
47#![allow(clippy::style)]
48
49use std::{thread, io, sync};
50use core::{time, fmt, ops, future, pin, task};
51use core::sync::atomic::{Ordering, AtomicUsize, AtomicU16};
52
53mod utils;
54mod spin;
55pub mod oneshot;
56
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
///Describes possible reasons for join to fail
pub enum JoinError {
    ///Job wasn't finished and aborted.
    Disconnect,
    ///Job was already consumed.
    ///
    ///Only possible if the result was already taken by one of the `wait` methods or by awaiting
    ///the handle through a reference.
    AlreadyConsumed,
}

impl fmt::Display for JoinError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(match self {
            JoinError::Disconnect => "job was aborted before it finished",
            JoinError::AlreadyConsumed => "job result was already consumed",
        })
    }
}

//Allows `JoinError` to be propagated with `?` into `Box<dyn Error>` and error-handling crates.
impl std::error::Error for JoinError {}
67
68#[repr(transparent)]
69///Handle to the job, allowing to await for it to finish
70///
71///It provides methods to block current thread to wait for job to finish.
72///Alternatively the handle implements `Future` allowing it to be used in async context.
73///
74///It is impossible to await this handle from multiple threads at the same time as it would require
75///locking, hence `Clone` is not implemented even though under the hood it is shared pointer.
76pub struct JobHandle<T> {
77    inner: oneshot::Receiver<T>
78}
79
80impl<T> fmt::Debug for JobHandle<T> {
81    #[inline(always)]
82    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
83        write!(f, "JobHandle")
84    }
85}
86
87impl<T> JobHandle<T> {
88    #[inline(always)]
89    ///Checks if the associated job has finished running.
90    pub fn is_finished(&self) -> bool {
91        self.inner.is_ready()
92    }
93
94    #[inline]
95    ///Attempts to check of job is ready
96    pub fn try_wait(&self) -> Result<Option<T>, JoinError> {
97        self.inner.try_recv()
98    }
99
100    #[inline]
101    ///Awaits for job to finish indefinitely.
102    pub fn wait(self) -> Result<T, JoinError> {
103        self.inner.recv()
104    }
105
106    #[inline]
107    ///Awaits for job to finish for limited time.
108    pub fn wait_timeout(&self, timeout: time::Duration) -> Result<Option<T>, JoinError> {
109        self.inner.recv_timeout(timeout)
110    }
111}
112
113impl<T> future::Future for JobHandle<T> {
114    type Output = Result<T, JoinError>;
115
116    #[inline]
117    fn poll(self: pin::Pin<&mut Self>, cx: &mut task::Context<'_>) -> task::Poll<Self::Output> {
118        let inner = pin::Pin::new(&mut self.get_mut().inner);
119
120        future::Future::poll(inner, cx)
121    }
122}
123
124enum Message {
125    Execute(Box<dyn FnOnce() + Send + 'static>),
126    Shutdown(Option<oneshot::Sender<()>>),
127}
128
129//Ensure size remains that of Box
130const _: () = {
131    assert!(core::mem::size_of::<Box<dyn FnOnce() + Send + 'static>>() == 16);
132    assert!(core::mem::size_of::<oneshot::Sender<()>>() == 8);
133    assert!(core::mem::size_of::<Message>() == 16);
134};
135
///Wrapper around [`sync::mpsc::Receiver`] that can be shared between worker threads.
///
///`std` declares its receiver `!Sync`, although since Rust 1.67 `mpsc` is built on top of an
///internal multi-producer multi-consumer channel. That internal type is neither exposed nor
///layout-compatible (so the receiver cannot be transmuted into it to obtain `Clone`), hence the
///wrapper is shared through `Arc` instead.
struct Receiver<T>(pub sync::mpsc::Receiver<T>);

// SAFETY: relies on a `std` implementation detail: since Rust 1.67 `mpsc::Receiver` is backed by
// an MPMC channel whose receive operations take `&self` and presumably tolerate concurrent
// receivers, so calling `recv` from several threads at once is not a data race. `std` does not
// promise this in its public API. NOTE(review): re-check on toolchain upgrades.
unsafe impl<T: Send> Sync for Receiver<T> {}

//No `unsafe impl Send` is needed: `Send` is auto-derived with the same `T: Send` bound because
//`mpsc::Receiver<T>` is `Send` for `T: Send`.

impl<T> ops::Deref for Receiver<T> {
    type Target = sync::mpsc::Receiver<T>;

    #[inline(always)]
    fn deref(&self) -> &Self::Target {
        &self.0
    }
}
153
154struct State {
155    send: sync::mpsc::Sender<Message>,
156    recv: sync::Arc<Receiver<Message>>,
157}
158
159#[derive(Clone)]
160struct ThreadBuilder {
161    idx: u16,
162    name: &'static str,
163    stack_size: usize,
164    receiver: sync::Arc<Receiver<Message>>,
165}
166
167impl ThreadBuilder {
168    pub fn spawn(self) -> Result<thread::JoinHandle<()>, io::Error> {
169        let mut result = thread::Builder::new();
170        if !self.name.is_empty() {
171            result = result.name(format!("{}-{}", self.name, self.idx))
172        }
173        if self.stack_size != 0 {
174            result = result.stack_size(self.stack_size)
175        }
176        let recv = self.receiver.clone();
177
178        //Builder should be taken out to prevent re-spawn
179        let mut guard = ThreadGuard(Some(self));
180        let worker_fn = move || loop {
181            match recv.recv() {
182                Ok(Message::Execute(job)) => {
183                    job();
184                },
185                Ok(Message::Shutdown(Some(notifier))) => {
186                    guard.0.take();
187                    let _ = notifier.send(());
188                    break;
189                }
190                Ok(Message::Shutdown(None)) | Err(_) => {
191                    guard.0.take();
192                    break;
193                },
194            }
195        };
196
197        result.spawn(worker_fn)
198    }
199}
200
201#[repr(transparent)]
202struct ThreadGuard(Option<ThreadBuilder>);
203
204impl Drop for ThreadGuard {
205    fn drop(&mut self) {
206        if thread::panicking() {
207            if let Some(builder) = self.0.take() {
208                //At this point if we cannot respawn thread there is something utterly wrong, so do not double panic
209                let _ = builder.spawn();
210            }
211        }
212    }
213}
214
215///Thread pool that allows to change number of threads at runtime.
216///
217///On `Drop` it instructs threads to shutdown, but doesn't await for them to finish
218///
219///# Note
220///
221///The pool doesn't implement any sort of flow control.
222///If workers are busy, message will remain in queue until any other thread can take it.
223///
224///# Clone
225///
226///Thread pool intentionally doesn't implement `Clone`
227///If you want to share it, then share it by using global variable or on heap.
228///It is thread safe, so concurrent access is allowed.
229///
230///# Panic
231///
232///Each thread wraps execution of job into `catch_unwind` to ensure that thread is not aborted
233///on panic
234pub struct ThreadPool {
235    stack_size: AtomicUsize,
236    thread_num: AtomicU16,
237    thread_num_lock: spin::Lock,
238    name: &'static str,
239    once_state: std::sync::OnceLock<State>,
240}
241
242impl ThreadPool {
243    #[inline(always)]
244    ///Creates new thread pool with default params
245    pub const fn new() -> Self {
246        Self::with_defaults("", 0)
247    }
248
249    #[inline(always)]
250    ///Creates new instance by specifying all params
251    pub const fn with_defaults(name: &'static str, stack_size: usize) -> Self {
252        Self {
253            stack_size: AtomicUsize::new(stack_size),
254            thread_num: AtomicU16::new(0),
255            thread_num_lock: spin::Lock::new(),
256            name,
257            once_state: std::sync::OnceLock::new(),
258        }
259    }
260
261    fn get_state(&self) -> &State {
262        self.once_state.get_or_init(|| {
263            let (send, recv) = sync::mpsc::channel();
264            State {
265                send,
266                recv: sync::Arc::new(Receiver(recv)),
267            }
268        })
269    }
270
271    #[inline]
272    ///Sets stack size to use.
273    ///
274    ///By default it uses default value, used by Rust's stdlib.
275    ///But setting this variable overrides it, allowing to customize it.
276    ///
277    ///This setting takes effect only when creating new threads
278    pub fn set_stack_size(&self, stack_size: usize) -> usize {
279        self.stack_size.swap(stack_size, Ordering::AcqRel)
280    }
281
282    ///Sets worker number, starting new threads if it is greater than previous
283    ///
284    ///In case if it is less, extra threads are shut down.
285    ///Returns previous number of threads.
286    ///
287    ///By default when pool is created no threads are started.
288    ///
289    ///If any thread fails to start, function returns immediately with error.
290    ///
291    ///# Note
292    ///
293    ///Any calls to this method are serialized, which means under hood it locks out
294    ///any attempt to change number of threads, until it is done
295    pub fn set_threads(&self, thread_num: u16) -> io::Result<u16> {
296        let _guard = self.thread_num_lock.lock();
297        let old_thread_num = self.thread_num.swap(thread_num, Ordering::Relaxed);
298
299        if old_thread_num > thread_num {
300            let state = self.get_state();
301
302            let shutdown_num = old_thread_num.saturating_sub(thread_num);
303            for _ in 0..shutdown_num {
304                if state.send.send(Message::Shutdown(None)).is_err() {
305                    break;
306                }
307            }
308
309        } else if thread_num > old_thread_num {
310            let create_num = thread_num.saturating_sub(old_thread_num);
311            let state = self.get_state();
312
313            for num in 0..create_num {
314                let builder = ThreadBuilder {
315                    idx: num,
316                    stack_size: self.stack_size.load(Ordering::Relaxed),
317                    name: self.name,
318                    receiver: state.recv.clone(),
319                };
320
321                match builder.spawn() {
322                    Ok(_) => (),
323                    Err(error) => {
324                        self.thread_num.store(old_thread_num.saturating_add(num), Ordering::Relaxed);
325                        return Err(error);
326                    }
327                }
328            }
329        }
330
331        Ok(old_thread_num)
332    }
333
334    ///Terminates all threads and clears internal state
335    ///
336    ///Mutable access guarantees that only one writer can clear state without need of internal lock
337    pub fn shutdown(&mut self) {
338        let _guard = self.thread_num_lock.lock();
339        let old_thread_num = self.thread_num.swap(0, Ordering::Relaxed);
340
341        {
342            let state = self.get_state();
343
344            for _ in 0..old_thread_num {
345                if state.send.send(Message::Shutdown(None)).is_err() {
346                    break;
347                }
348            }
349        }
350
351        //Take state and drop it
352        let _ = self.once_state.take();
353    }
354
355    ///Terminates all threads, awaiting their completion and clears internal state
356    ///
357    ///Mutable access guarantees that only one writer can clear state without need of internal lock
358    pub fn shutdown_and_join(&mut self) {
359        let _guard = self.thread_num_lock.lock();
360        let old_thread_num = self.thread_num.swap(0, Ordering::Relaxed);
361
362        let mut joiners = Vec::new();
363        {
364            let state = self.get_state();
365
366            for _ in 0..old_thread_num {
367                let (sender, receiver) = oneshot::oneshot();
368                if state.send.send(Message::Shutdown(Some(sender))).is_err() {
369                    break;
370                }
371                joiners.push(receiver);
372            }
373        }
374
375        for receiver in joiners {
376            let _ = receiver.recv();
377        }
378        //Take state and drop it
379        let _ = self.once_state.take();
380    }
381
382    ///Schedules new execution, sending it over to one of the workers.
383    pub fn spawn<F: FnOnce() + Send + 'static>(&self, job: F) {
384        let state = self.get_state();
385        let _ = state.send.send(Message::Execute(Box::new(job)));
386    }
387
388    ///Schedules execution, that allows to await and receive it's result.
389    pub fn spawn_handle<R: Send + 'static, F: FnOnce() -> R + Send + 'static>(&self, job: F) -> JobHandle<R> {
390        let (send, recv) = oneshot::oneshot();
391        let job = move || {
392            let _ = send.send((job)());
393        };
394        let _ = self.get_state().send.send(Message::Execute(Box::new(job)));
395
396        JobHandle {
397            inner: recv
398        }
399    }
400}
401
402impl Drop for ThreadPool {
403    #[inline(always)]
404    fn drop(&mut self) {
405        self.shutdown();
406    }
407}
408
409impl fmt::Debug for ThreadPool {
410    #[inline(always)]
411    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
412        fmt.write_fmt(format_args!("ThreadPool {{ threads: {} }}", self.thread_num.load(Ordering::Relaxed)))
413    }
414}
415
// SAFETY: shared state is synchronized internally. The counters are atomics and `name` is
// immutable. The channel state is initialized through `OnceLock::get_or_init` and only taken via
// `&mut self`. Worker-count changes are serialized by `thread_num_lock`.
// NOTE(review): the explicit impl is presumably needed because `mpsc::Sender` was `!Sync` before
// Rust 1.72 and/or `spin::Lock` is not auto-`Sync`. Confirm, and consider removing it once the
// MSRV allows.
unsafe impl Sync for ThreadPool {}
417