slave_pool/
lib.rs

1//! Simple thread pool
2//!
3//! # Usage
4//!
5//! ```rust
6//! use slave_pool::ThreadPool;
7//! const SECOND: core::time::Duration = core::time::Duration::from_secs(1);
8//!
9//! static POOL: ThreadPool = ThreadPool::new();
10//!
11//! POOL.set_threads(8); //Tell how many threads you want
12//!
13//! let mut handles = Vec::new();
14//! for idx in 0..8 {
15//!     handles.push(POOL.spawn_handle(move || {
16//!         std::thread::sleep(SECOND);
17//!         idx
18//!     }));
19//! }
20//!
21//! POOL.set_threads(0); //Tell it to shut down all threads
22//!
23//! for (idx, handle) in handles.drain(..).enumerate() {
24//!     assert_eq!(handle.wait().unwrap(), idx) //Even though we told it to shut down all threads, it still finishes the queued jobs first
25//! }
26//!
27//! let handle = POOL.spawn_handle(|| {});
28//! assert!(handle.wait_timeout(SECOND).is_err()); // All threads are shut down now
29//!
30//! POOL.set_threads(1); //But let's add one more
31//!
32//! assert!(handle.wait().is_ok());
33//!
34//! let handle = POOL.spawn_handle(|| panic!("Oh no!")); // We can panic, if we want
35//!
36//! assert!(handle.wait().is_err()); // In that case we get an error, but the worker thread keeps running
37//!
38//! let handle = POOL.spawn_handle(|| {});
39//!
40//! POOL.set_threads(0);
41//!
42//! assert!(handle.wait().is_ok());
43//! ```
44
45#![warn(missing_docs)]
46#![cfg_attr(feature = "cargo-clippy", allow(clippy::style))]
47
48use std::{thread, io};
49
50use core::{time, fmt};
51use core::sync::atomic::{Ordering, AtomicUsize, AtomicU16};
52
53mod utils;
54mod spin;
55mod oneshot;
56
///Describes possible reasons for join to fail.
///
///Implements `std::error::Error`, so it can be boxed into `Box<dyn Error>` or
///converted by `?` like any other standard error.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum JoinError {
    ///Job wasn't finished and was aborted.
    ///
    ///The result can never arrive, for example because the job panicked before
    ///producing a value.
    Disconnect,
    ///Timeout expired, job continues.
    Timeout,
    ///Job was already consumed.
    ///
    ///Only possible if handle successfully finished with `wait_timeout`
    ///or via reference future.
    AlreadyConsumed,
}

impl fmt::Display for JoinError {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str(match *self {
            JoinError::Disconnect => "job was aborted before it finished",
            JoinError::Timeout => "timed out waiting for job to finish",
            JoinError::AlreadyConsumed => "job result was already consumed",
        })
    }
}

impl std::error::Error for JoinError {}
70
71///Handle to the job, allowing to await for it to finish
72///
73///It provides methods to block current thread to wait for job to finish.
74///Alternatively the handle implements `Future` allowing it to be used in async context.
75///
76///Note that it is undesirable for it to be awaited from multiple threads,
77///therefore `Clone` is not implemented, even though it is possible
78pub struct JobHandle<T> {
79    inner: oneshot::Receiver<T>
80}
81
82impl<T> fmt::Debug for JobHandle<T> {
83    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
84        write!(f, "JobHandle")
85    }
86}
87
88impl<T> JobHandle<T> {
89    #[inline]
90    ///Awaits for job to finish indefinitely.
91    pub fn wait(self) -> Result<T, JoinError> {
92        self.inner.recv()
93    }
94
95    #[inline]
96    ///Awaits for job to finish for limited time.
97    pub fn wait_timeout(&self, timeout: time::Duration) -> Result<T, JoinError> {
98        self.inner.recv_timeout(timeout)
99    }
100}
101
102impl<T> core::future::Future for JobHandle<T> {
103    type Output = Result<T, JoinError>;
104
105    #[inline]
106    fn poll(self: core::pin::Pin<&mut Self>, cx: &mut core::task::Context<'_>) -> core::task::Poll<Self::Output> {
107        let inner = unsafe {
108            self.map_unchecked_mut(|this| &mut this.inner)
109        };
110
111        core::future::Future::poll(inner, cx)
112    }
113}
114
///Instruction sent from the pool to its worker threads over the shared queue.
enum Message {
    ///Run the boxed job on whichever worker receives this message.
    Execute(Box<dyn FnOnce() + Send + 'static>),
    ///Ask the receiving worker to leave its loop and exit.
    Shutdown,
}
119
120struct State {
121    send: crossbeam_channel::Sender<Message>,
122    recv: crossbeam_channel::Receiver<Message>,
123}
124
125unsafe impl Sync for ThreadPool {}
126
127///Thread pool that allows to change number of threads at runtime.
128///
129///On `Drop` it instructs threads to shutdown, but doesn't await for them to finish
130///
131///# Note
132///
133///The pool doesn't implement any sort of flow control.
134///If workers are busy, message will remain in queue until any other thread can take it.
135///
136///# Clone
137///
138///Thread pool intentionally doesn't implement `Clone`
139///If you want to share it, then share it by using global variable or on heap.
140///It is thread safe, so concurrent access is allowed.
141///
142///# Panic
143///
144///Each thread wraps execution of job into `catch_unwind` to ensure that thread is not aborted
145///on panic
146pub struct ThreadPool {
147    stack_size: AtomicUsize,
148    thread_num: AtomicU16,
149    thread_num_lock: spin::Lock,
150    name: &'static str,
151    init_lock: std::sync::Once,
152    //Option is fine as extra size goes from padding, so it
153    //doesn't increase overall size, but when changing layout
154    //consider to switch to MaybeUninit
155    state: core::cell::Cell<Option<State>>,
156}
157
158impl ThreadPool {
159    ///Creates new thread pool with default params
160    pub const fn new() -> Self {
161        Self::with_defaults("", 0)
162    }
163
164    ///Creates new instance by specifying all params
165    pub const fn with_defaults(name: &'static str, stack_size: usize) -> Self {
166        Self {
167            stack_size: AtomicUsize::new(stack_size),
168            thread_num: AtomicU16::new(0),
169            thread_num_lock: spin::Lock::new(),
170            name,
171            init_lock: std::sync::Once::new(),
172            state: core::cell::Cell::new(None),
173        }
174    }
175
176    fn get_state(&self) -> &State {
177        self.init_lock.call_once(|| {
178            let (send, recv) = crossbeam_channel::unbounded();
179            self.state.set(Some(State {
180                send,
181                recv,
182            }))
183        });
184
185        match unsafe { &*self.state.as_ptr() } {
186            Some(state) => state,
187            None => unreach!(),
188        }
189    }
190
191    #[inline]
192    ///Sets stack size to use.
193    ///
194    ///By default it uses default value, used by Rust's stdlib.
195    ///But setting this variable overrides it, allowing to customize it.
196    ///
197    ///This setting takes effect only when creating new threads
198    pub fn set_stack_size(&self, stack_size: usize) -> usize {
199        self.stack_size.swap(stack_size, Ordering::AcqRel)
200    }
201
202    ///Sets worker number, starting new threads if it is greater than previous
203    ///
204    ///In case if it is less, extra threads are shut down.
205    ///Returns previous number of threads.
206    ///
207    ///By default when pool is created no threads are started.
208    ///
209    ///If any thread fails to start, function returns immediately with error.
210    ///
211    ///# Note
212    ///
213    ///Any calls to this method are serialized, which means under hood it locks out
214    ///any attempt to change number of threads, until it is done
215    pub fn set_threads(&self, thread_num: u16) -> io::Result<u16> {
216        let mut _guard = self.thread_num_lock.lock();
217        let old_thread_num = self.thread_num.load(Ordering::Relaxed);
218        self.thread_num.store(thread_num, Ordering::Relaxed);
219
220        if old_thread_num > thread_num {
221            let state = self.get_state();
222
223            let shutdown_num = old_thread_num - thread_num;
224            for _ in 0..shutdown_num {
225                if state.send.send(Message::Shutdown).is_err() {
226                    break;
227                }
228            }
229
230        } else if thread_num > old_thread_num {
231            let create_num = thread_num - old_thread_num;
232            let stack_size = self.stack_size.load(Ordering::Acquire);
233            let state = self.get_state();
234
235            for num in 0..create_num {
236                let recv = state.recv.clone();
237
238                let builder = match self.name {
239                    "" => thread::Builder::new(),
240                    name => thread::Builder::new().name(name.to_owned()),
241                };
242
243                let builder = match stack_size {
244                    0 => builder,
245                    stack_size => builder.stack_size(stack_size),
246                };
247
248                let result = builder.spawn(move || loop { match recv.recv() {
249                    Ok(Message::Execute(job)) => {
250                        //TODO: for some reason closures has no impl, wonder why?
251                        let job = std::panic::AssertUnwindSafe(job);
252                        let _ = std::panic::catch_unwind(|| (job.0)());
253                    },
254                    Ok(Message::Shutdown) | Err(_) => break,
255                }});
256
257                match result {
258                    Ok(_) => (),
259                    Err(error) => {
260                        self.thread_num.store(old_thread_num + num, Ordering::Relaxed);
261                        return Err(error);
262                    }
263                }
264            }
265        }
266
267        Ok(old_thread_num)
268    }
269
270    ///Schedules new execution, sending it over to one of the workers.
271    pub fn spawn<F: FnOnce() + Send + 'static>(&self, job: F) {
272        let state = self.get_state();
273        let _ = state.send.send(Message::Execute(Box::new(job)));
274    }
275
276    ///Schedules execution, that allows to await and receive it's result.
277    pub fn spawn_handle<R: Send + 'static, F: FnOnce() -> R + Send + 'static>(&self, job: F) -> JobHandle<R> {
278        let (send, recv) = oneshot::oneshot();
279        let job = move || {
280            let _ = send.send(job());
281        };
282        let _ = self.get_state().send.send(Message::Execute(Box::new(job)));
283
284        JobHandle {
285            inner: recv
286        }
287    }
288}
289
290impl fmt::Debug for ThreadPool {
291    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
292        write!(f, "ThreadPool {{ threads: {} }}", self.thread_num.load(Ordering::Relaxed))
293    }
294}