1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
//! Simple thread pool
//!
//! # Usage
//!
//! ```rust
//! use slave_pool::ThreadPool;
//! const SECOND: core::time::Duration = core::time::Duration::from_secs(1);
//!
//! static POOL: ThreadPool = ThreadPool::new();
//!
//! POOL.set_threads(8); //Tell how many threads you want
//!
//! let mut handles = Vec::new();
//! for idx in 0..8 {
//!     handles.push(POOL.spawn_handle(move || {
//!         std::thread::sleep(SECOND);
//!         idx
//!     }));
//! }
//!
//! POOL.set_threads(0); //Tells to shut down threads
//!
//! for (idx, handle) in handles.drain(..).enumerate() {
//!     assert_eq!(handle.wait().unwrap(), idx) //Even though we told it to shut down all threads, it is going to finish queued jobs first
//! }
//!
//! let handle = POOL.spawn_handle(|| {});
//! assert!(handle.wait_timeout(SECOND).is_err()); // All are shutdown now
//!
//! POOL.set_threads(1); //But let's add one more
//!
//! assert!(handle.wait().is_ok());
//!
//! let handle = POOL.spawn_handle(|| panic!("Oh no!")); // We can panic, if we want
//!
//! assert!(handle.wait().is_err()); // In that case we'll get error, but thread will be ok
//!
//! let handle = POOL.spawn_handle(|| {});
//!
//! POOL.set_threads(0);
//!
//! assert!(handle.wait().is_ok());
//! ```

#![warn(missing_docs)]
#![cfg_attr(feature = "cargo-clippy", allow(clippy::style))]

use std::{thread, io};

use core::{time, fmt};
use core::sync::atomic::{Ordering, AtomicUsize, AtomicU16};

mod utils;
mod spin;
mod oneshot;

#[derive(Debug)]
///Describes possible reasons for join to fail
pub enum JoinError {
    ///Job wasn't finished and aborted.
    ///
    ///NOTE(review): presumably reported when the sending side is dropped
    ///before producing a result (e.g. the job panicked) — confirm against
    ///the `oneshot` module.
    Disconnect,
    ///Timeout expired, job continues.
    Timeout,
    ///Job was already consumed.
    ///
    ///Only possible if handle successfully finished with `wait_timeout`
    ///or via reference future.
    AlreadyConsumed,
}

///Handle to the job, allowing to await for it to finish
///
///It provides methods to block current thread to wait for job to finish.
///Alternatively the handle implements `Future` allowing it to be used in async context.
///
///Note that it is undesirable for it to be awaited from multiple threads,
///therefore `Clone` is not implemented, even though it is possible
pub struct JobHandle<T> {
    //Receiving half of the one-shot channel the job's result arrives on.
    inner: oneshot::Receiver<T>
}

impl<T> fmt::Debug for JobHandle<T> {
    ///Always renders as the literal text `JobHandle`; the pending result
    ///itself is opaque and carries no `Debug` bound.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str("JobHandle")
    }
}

impl<T> JobHandle<T> {
    #[inline]
    ///Awaits for job to finish indefinitely.
    ///
    ///Consumes the handle; returns the job's result, or a `JoinError`
    ///if the job was aborted.
    pub fn wait(self) -> Result<T, JoinError> {
        self.inner.recv()
    }

    #[inline]
    ///Awaits for job to finish for limited time.
    ///
    ///Returns `JoinError::Timeout` when `timeout` elapses first; the job
    ///keeps running. On success the result is taken out of the handle, so
    ///a later wait yields `JoinError::AlreadyConsumed` (see `JoinError`).
    pub fn wait_timeout(&self, timeout: time::Duration) -> Result<T, JoinError> {
        self.inner.recv_timeout(timeout)
    }
}

impl<T> core::future::Future for JobHandle<T> {
    type Output = Result<T, JoinError>;

    #[inline]
    fn poll(self: core::pin::Pin<&mut Self>, cx: &mut core::task::Context<'_>) -> core::task::Poll<Self::Output> {
        //SAFETY: structural pin projection onto the only field; `inner` is
        //never moved out of `self` here, so its pinning invariant is upheld.
        let inner = unsafe {
            self.map_unchecked_mut(|this| &mut this.inner)
        };

        //Delegate directly to the receiver's own `Future` implementation.
        core::future::Future::poll(inner, cx)
    }
}

//Internal protocol between the pool and its worker threads.
enum Message {
    //Run the boxed job on whichever worker dequeues it.
    Execute(Box<dyn FnOnce() + Send + 'static>),
    //Make the worker that receives this exit its loop.
    Shutdown,
}

//Lazily-created channel shared by all workers.
//
//Both ends are stored so the channel stays connected for the pool's whole
//lifetime; each spawned worker gets its own clone of `recv`.
struct State {
    send: crossbeam_channel::Sender<Message>,
    recv: crossbeam_channel::Receiver<Message>,
}

//SAFETY: the only non-`Sync` field is `state` (a `Cell`). It is written
//exactly once, inside `init_lock.call_once` in `get_state`, which both
//serializes concurrent initializers and publishes the write; afterwards it
//is only ever read through a shared reference.
unsafe impl Sync for ThreadPool {}

///Thread pool that allows to change number of threads at runtime.
///
///On `Drop` it instructs threads to shutdown, but doesn't await for them to finish
///
///# Note
///
///The pool doesn't implement any sort of flow control.
///If workers are busy, message will remain in queue until any other thread can take it.
///
///# Clone
///
///Thread pool intentionally doesn't implement `Clone`
///If you want to share it, then share it by using global variable or on heap.
///It is thread safe, so concurrent access is allowed.
///
///# Panic
///
///Each thread wraps execution of job into `catch_unwind` to ensure that thread is not aborted
///on panic
pub struct ThreadPool {
    //Stack size for newly spawned workers; 0 means stdlib's default.
    stack_size: AtomicUsize,
    //Current target number of worker threads.
    thread_num: AtomicU16,
    //Serializes `set_threads` calls.
    thread_num_lock: spin::Lock,
    //Name given to every worker thread; "" leaves threads unnamed.
    name: &'static str,
    //Guards one-time initialization of `state`.
    init_lock: std::sync::Once,
    //Option is fine as extra size goes from padding, so it
    //doesn't increase overall size, but when changing layout
    //consider to switch to MaybeUninit
    state: core::cell::Cell<Option<State>>,
}

impl ThreadPool {
    ///Creates new thread pool with default params
    ///
    ///Equivalent to `with_defaults("", 0)`: unnamed threads with the
    ///standard library's default stack size. No threads are started yet.
    pub const fn new() -> Self {
        Self::with_defaults("", 0)
    }

    ///Creates new instance by specifying all params
    ///
    ///- `name` — name given to every worker thread; `""` leaves them unnamed.
    ///- `stack_size` — stack size for every worker; `0` uses stdlib's default.
    ///
    ///No threads are started until `set_threads` is called.
    pub const fn with_defaults(name: &'static str, stack_size: usize) -> Self {
        Self {
            stack_size: AtomicUsize::new(stack_size),
            thread_num: AtomicU16::new(0),
            thread_num_lock: spin::Lock::new(),
            name,
            init_lock: std::sync::Once::new(),
            state: core::cell::Cell::new(None),
        }
    }

    //Lazily initializes and returns the shared channel state.
    fn get_state(&self) -> &State {
        //`call_once` serializes concurrent initializers, so the `Cell` is
        //written at most once and the write is published to all threads.
        self.init_lock.call_once(|| {
            let (send, recv) = crossbeam_channel::unbounded();
            self.state.set(Some(State {
                send,
                recv,
            }))
        });

        //SAFETY: once `call_once` returns, `state` holds `Some` and is never
        //written again, so a shared reference to its contents is sound.
        match unsafe { &*self.state.as_ptr() } {
            Some(state) => state,
            None => unreach!(),
        }
    }

    #[inline]
    ///Sets stack size to use.
    ///
    ///By default it uses default value, used by Rust's stdlib.
    ///But setting this variable overrides it, allowing to customize it.
    ///
    ///This setting takes effect only when creating new threads
    pub fn set_stack_size(&self, stack_size: usize) -> usize {
        //Returns the previously configured value.
        self.stack_size.swap(stack_size, Ordering::AcqRel)
    }

    ///Sets worker number, starting new threads if it is greater than previous
    ///
    ///In case if it is less, extra threads are shut down.
    ///Returns previous number of threads.
    ///
    ///By default when pool is created no threads are started.
    ///
    ///If any thread fails to start, function returns immediately with error.
    ///
    ///# Note
    ///
    ///Any calls to this method are serialized, which means under hood it locks out
    ///any attempt to change number of threads, until it is done
    pub fn set_threads(&self, thread_num: u16) -> io::Result<u16> {
        //Serialize concurrent resizes; the guard is held until return.
        let mut _guard = self.thread_num_lock.lock();
        let old_thread_num = self.thread_num.load(Ordering::Relaxed);
        self.thread_num.store(thread_num, Ordering::Relaxed);

        if old_thread_num > thread_num {
            let state = self.get_state();

            //Shrink: enqueue one `Shutdown` per extra worker. These share the
            //queue with jobs, so workers drain already-queued jobs first.
            let shutdown_num = old_thread_num - thread_num;
            for _ in 0..shutdown_num {
                if state.send.send(Message::Shutdown).is_err() {
                    break;
                }
            }

        } else if thread_num > old_thread_num {
            //Grow: spawn the missing workers, each with its own receiver clone.
            let create_num = thread_num - old_thread_num;
            let stack_size = self.stack_size.load(Ordering::Acquire);
            let state = self.get_state();

            for num in 0..create_num {
                let recv = state.recv.clone();

                let builder = match self.name {
                    "" => thread::Builder::new(),
                    name => thread::Builder::new().name(name.to_owned()),
                };

                let builder = match stack_size {
                    0 => builder,
                    stack_size => builder.stack_size(stack_size),
                };

                //Worker loop: run jobs until a `Shutdown` message arrives or
                //the channel disconnects.
                let result = builder.spawn(move || loop { match recv.recv() {
                    Ok(Message::Execute(job)) => {
                        //`Box<dyn FnOnce>` is not `UnwindSafe`, so assert it;
                        //the worker keeps no state a panic could corrupt.
                        let job = std::panic::AssertUnwindSafe(job);
                        let _ = std::panic::catch_unwind(|| (job.0)());
                    },
                    Ok(Message::Shutdown) | Err(_) => break,
                }});

                match result {
                    Ok(_) => (),
                    Err(error) => {
                        //Exactly `num` workers were spawned before this failure,
                        //so record the real worker count before bailing out.
                        self.thread_num.store(old_thread_num + num, Ordering::Relaxed);
                        return Err(error);
                    }
                }
            }
        }

        Ok(old_thread_num)
    }

    ///Schedules new execution, sending it over to one of the workers.
    pub fn spawn<F: FnOnce() + Send + 'static>(&self, job: F) {
        let state = self.get_state();
        //Send cannot realistically fail: the pool holds both channel ends.
        let _ = state.send.send(Message::Execute(Box::new(job)));
    }

    ///Schedules execution, that allows to await and receive it's result.
    pub fn spawn_handle<R: Send + 'static, F: FnOnce() -> R + Send + 'static>(&self, job: F) -> JobHandle<R> {
        let (send, recv) = oneshot::oneshot();
        //Wrap the job so its result is forwarded through the one-shot channel;
        //a dropped handle just makes this send a no-op.
        let job = move || {
            let _ = send.send(job());
        };
        let _ = self.get_state().send.send(Message::Execute(Box::new(job)));

        JobHandle {
            inner: recv
        }
    }
}

impl fmt::Debug for ThreadPool {
    ///Renders as `ThreadPool { threads: N }` with the current worker count.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        //Snapshot the counter once, then emit the exact same text `write!`
        //produced before (deliberately not `debug_struct`, which would alter
        //`{:#?}` output).
        let threads = self.thread_num.load(Ordering::Relaxed);
        f.write_fmt(format_args!("ThreadPool {{ threads: {} }}", threads))
    }
}