worker_pool/lib.rs
1/*!
2This crate provides the [`WorkerPool`] struct, which lets you manage a set of threads that need to communicate with the parent thread.
3Throughout this documentation, the thread owning [`WorkerPool`] is called the "Manager",
4whereas the threads created and handled by the [`WorkerPool`] instance are called "Workers".
5
Communication to and from the workers is done using [`std::sync::mpsc`] queues.
7When the manager communicates to the workers, the messages are said to go "down",
8and when the workers communicate to the manager, the messages are said to go "up".
9
Communication from the workers to the manager uses a [`SyncSender`][std::sync::mpsc::SyncSender] wrapped inside of [`WorkerSender`], so as not to overwhelm the manager thread and cause a memory overflow.
11[`WorkerSender::send`] can thus block and [`WorkerSender::try_send`] can return [`Err(TrySendError::Full)`][std::sync::mpsc::TrySendError::Full].
12
13Because of the guarantees of [`WorkerSender`], locking or waiting for the queue to become available will *not* cause a deadlock when trying to join the threads,
14as all the joining methods of [`WorkerPool`] ([`WorkerPool::stop`] and [`WorkerPool::stop_and_join`]) will first empty the message queue before
15calling `join()`.
16
This guarantee of absence of deadlocks comes at the cost of restricting what a worker may or may not do:
18- if a worker is blocking or looping indefinitely, then it must be able to receive a [`Stop`][DownMsg::Stop] message at any time
19- once the `Stop` message is received, execution must stop shortly: a worker may only block on the message queue of their `WorkerSender`
20- downward message queues aren't bounded, as that might otherwise introduce a deadlock when trying to send the `Stop` message
21
22Additionally, the livelock problem of requests queuing up in the upward channel while the manager thread tries to catch up with them is solved by [`WorkerPool::recv_burst`]:
23- any message sent before `recv_burst` was called will be yielded by its returned iterator ([`RecvBurstIterator`]) if it reaches the manager thread in time (otherwise, it'll sit in the queue until the next call to `recv_burst`)
24- any message sent after `recv_burst` was called will cause that iterator to stop and put the message in a temporary buffer for the next call to `recv_burst`
25- [`RecvBurstIterator`] is non-blocking and holds a mutable reference to its `WorkerPool`
26*/
27#![feature(negative_impls)]
28
29use std::thread::{JoinHandle};
30use std::marker::PhantomData;
31use std::time::Instant;
32
33use std::sync::mpsc::{channel, sync_channel};
34
35// TODO: support for robust causality as a feature
36// TODO: regular Sender
37
38#[cfg(test)]
39mod test;
40
41pub mod iterator;
42use iterator::*;
43
44mod msg;
45pub use msg::*;
46
47mod crate_macros;
48
49// As per this test: https://github.com/rust-lang/rust/blob/1.3.0/src/libstd/sync/mpsc/mod.rs#L1581-L1592
50// it looks like std::sync::mpsc channels are meant to preserve the order sent within a single thread
51
/**
The main struct, representing a pool of workers.
The owner of this struct is the "Manager", while the threads handled by this struct are the "Workers".

# Example

```
use worker_pool::WorkerPool;

let mut pool: WorkerPool<String, ()> = WorkerPool::new(100);

pool.execute(|tx, _rx| {
    tx.send(String::from("Hello"));
    tx.send(String::from("world!"));
});

assert_eq!(pool.stop().collect::<Vec<_>>().join(" "), "Hello world!");
```
*/
pub struct WorkerPool<Up, Down>
where
    Up: Send + 'static,
    Down: Send + 'static,
{
    /// Bounded upward channel: the sender half is cloned for every spawned worker,
    /// the receiver half is read by the manager.
    channel: (SyncSender<UpMsg<Up>>, Receiver<UpMsg<Up>>),
    /// Capacity given to `sync_channel`; kept so that the channel can be re-created when stopping.
    buffer_length: usize,
    /// One-slot buffer holding a message that was set aside; `recv` hands it out before
    /// reading from the channel, and `recv_burst` lends it to its iterator.
    buffer_prev: Option<UpMsg<Up>>,

    /// One entry per spawned worker: its join handle and the sender of its downward channel.
    workers: Vec<(JoinHandle<()>, Sender<DownMsg<Down>>)>,
    /// Index of the worker targeted by the latest `broadcast_one` call; reset to 0 when stopping.
    worker_index: usize,

    /// Marker for the `Up` and `Down` type parameters.
    phantoms: PhantomData<(Up, Down)>
}
85
86impl<Up, Down> WorkerPool<Up, Down>
87where
88 Up: Send + 'static,
89 Down: Send + 'static,
90{
91 /**
92 Creates a new WorkerPool instance, with a given maximum buffer length.
93
94 The higher the buffer length, the higher the message throughput,
95 but the higher the memory cost.
96 See [SyncSender](https://doc.rust-lang.org/std/sync/mpsc/struct.SyncSender.html) for more information.
97
98 # Example
99
100 ```
101 use std::time::Duration;
102 use worker_pool::{WorkerPool, DownMsg};
103
104 let mut pool: WorkerPool<usize, String> = WorkerPool::new(3);
105
106 pool.execute(|tx, rx| {
107 loop {
108 let msg = worker_pool::recv_break!(rx);
109 tx.send(msg.len()).unwrap();
110 }
111 });
112
113 pool.broadcast(DownMsg::Other(String::from("Betelgeuse")));
114 pool.broadcast(DownMsg::Other(String::from("Alpha Centauri")));
115 pool.broadcast(DownMsg::Other(String::from("Sirius")));
116
117 // When the worker will send the result for this message, tx.send will block, but the message
118 // will still count as being sent before recv_burst(). Whether or not it will appear in the
119 // iterator depends on the speed at which the items in the iterator are read.
120 pool.broadcast(DownMsg::Other(String::from("Procyon")));
121
122 pool.broadcast(DownMsg::Other(String::from("Sun")));
123
124 std::thread::sleep(Duration::new(0, 100_000_000));
125
126 assert_eq!(
127 pool.recv_burst().take(3).collect::<Vec<_>>(),
128 vec![10, 14, 6]
129 );
130 ```
131 */
132 #[inline]
133 pub fn new(buffer_length: usize) -> Self {
134 Self {
135 channel: sync_channel(buffer_length),
136 buffer_length,
137 buffer_prev: None,
138 workers: Vec::new(),
139 worker_index: 0,
140 phantoms: PhantomData,
141 }
142 }
143
144 /**
145 Spawns one worker thread with the given callback:
146
147 ```
148 # use worker_pool::*;
149 # let mut pool: WorkerPool<(), ()> = WorkerPool::new(100);
150 // Spawns 1 worker thread
151 pool.execute(|tx, rx| {
152 // Send messages in tx
153 // Receive messages in rx
154 });
155 # pool.stop_and_join();
156 ```
157
158 To prevent any deadlocks, the worker thread *must* stop shortly after receiving the `Stop` message:
159 - if it is in an infinite loop, then that loop must be broken (`recv_break!` and `try_recv_break!` will handle that for you)
160 - after the `Stop` message is received, it may only wait for space in the buffer of `tx`
161 - make sure that no lock will prevent a `Stop` message from being received
162 - the worker thread may `panic!`, in which case its exception will be propagated up by `stop()` and `stop_and_join()`
163 */
164 #[inline]
165 pub fn execute<F>(&mut self, callback: F)
166 where
167 F: (FnOnce(WorkerSender<Up>, Receiver<DownMsg<Down>>)),
168 F: Send + 'static
169 {
170 let (down_tx, down_rx) = channel();
171 let up_tx = self.channel.0.clone();
172 self.workers.push((
173 std::thread::spawn(move || {
174 (callback)(WorkerSender::new(up_tx), down_rx);
175 }),
176 down_tx
177 ));
178 }
179
180 /**
181 Spawns `n` worker threads with the given callback.
182 The callback must implement `Clone`.
183
184 ```
185 # use worker_pool::*;
186 # let mut pool: WorkerPool<(), ()> = WorkerPool::new(100);
187 // Spawns 16 worker thread
188 pool.execute_many(16, |tx, rx| {
189 // Send messages in tx
190 // Receive messages in rx
191 });
192 # pool.stop_and_join();
193 ```
194
195 To prevent any deadlocks, the worker threads *must* stop shortly after receiving the `Stop` message.
196 See [`execute`](#execute) for more information.
197 */
198 #[inline]
199 pub fn execute_many<F>(&mut self, n_workers: usize, callback: F)
200 where
201 F: (FnOnce(WorkerSender<Up>, Receiver<DownMsg<Down>>)),
202 F: Clone + Send + 'static
203 {
204 for _n in 0..n_workers {
205 self.execute(callback.clone());
206 }
207 }
208
 /// Returns the maximum length of the upward message queue,
 /// i.e. the `buffer_length` value that was given to [`WorkerPool::new`].
 #[inline]
 pub fn buffer_length(&self) -> usize {
     self.buffer_length
 }
214
215 /// Receives a single message from a worker; this is a blocking operation.
216 /// If you need to call this function repeatedly, then consider iterating over the result of `recv_burst` instead.
217 pub fn recv(&mut self) -> Result<Up, RecvError> {
218 if self.buffer_prev.is_some() {
219 return Ok(std::mem::replace(&mut self.buffer_prev, None).unwrap().get());
220 }
221
222 self.channel.1.recv().map(|x| x.get())
223 }
224
225 /// Returns an iterator that will yield a "burst" of messages.
226 /// This iterator will respect causality, meaning that it will not yield any message that were sent after it was created.
227 /// You can thus safely iterate over all of the elements of this iterator without risking a livelock.
228 pub fn recv_burst<'b>(&'b mut self) -> RecvBurstIterator<'b, Up> {
229 let start = Instant::now();
230
231 RecvBurstIterator::new(
232 &self.channel.1,
233 &mut self.buffer_prev,
234 start
235 )
236 }
237
238 /// Stops the execution of all threads, returning an iterator that will yield and join all of the
239 /// messages from the workers. As soon as this function returns, the WorkerPool will be back to its starting state,
240 /// allowing you to execute more tasks immediately.
241 ///
242 /// The returned iterator will read all of the remaining messages one by one.
243 /// Once the last message is received, it will join all threads.
244 pub fn stop(&mut self) -> RecvAllIterator<Up> {
245 let channel = std::mem::replace(&mut self.channel, sync_channel(self.buffer_length));
246 let buffer_prev = std::mem::replace(&mut self.buffer_prev, None);
247 let workers_len = self.workers.len();
248 let workers = std::mem::replace(&mut self.workers, Vec::with_capacity(workers_len));
249 self.worker_index = 0;
250
251 let workers = workers.into_iter().map(|worker| {
252 // Note: the only instance where this can fail is if the receiver was dropped,
253 // in which case we can only hope that the thread will eventually join
254 let _ = worker.1.send(DownMsg::Stop);
255 worker.0
256 }).collect::<Vec<_>>();
257
258 RecvAllIterator::new(
259 channel.1,
260 buffer_prev,
261 workers
262 )
263 }
264
265 /// Stops the execution of all threads and joins them. Returns a Vec containing all of the remaining yielded values.
266 /// Note that the returned Vec will ignore the `buffer_length` limitation.
267 #[inline]
268 pub fn stop_and_join(&mut self) -> Vec<Up> {
269 let (sender, receiver) = std::mem::replace(&mut self.channel, sync_channel(self.buffer_length));
270 std::mem::drop(sender); // Prevent deadlock
271 let buffer_prev = std::mem::replace(&mut self.buffer_prev, None);
272 let workers_len = self.workers.len();
273 let workers = std::mem::replace(&mut self.workers, Vec::with_capacity(workers_len));
274 self.worker_index = 0;
275
276 for worker in workers.iter() {
277 // Note: the only instance where this can fail is if the receiver was dropped,
278 // in which case we can only hope that the thread will eventually join
279 let _ = worker.1.send(DownMsg::Stop);
280 }
281
282 let mut res = Vec::new();
283
284 if let Some(buffer_prev) = buffer_prev {
285 res.push(buffer_prev.get());
286 }
287
288 while let Ok(msg) = receiver.recv() {
289 res.push(msg.get());
290 }
291
292 for worker in workers {
293 match worker.0.join() {
294 Ok(_) => {},
295 Err(e) => std::panic::resume_unwind(e),
296 }
297 }
298
299 res
300 }
301
302 /// Sends `msg` to every worker.
303 /// If a worker has dropped their [`Receiver`][std::sync::mpsc::Receiver], then it will be skipped.
304 pub fn broadcast(&self, msg: DownMsg<Down>) where Down: Clone {
305 for (_join, tx) in self.workers.iter() {
306 // This will fail iff the thread has dropped its receiver, in which case we
307 // don't want for it to affect the other threads
308 let _ = tx.send(msg.clone());
309 }
310 }
311
312 /// Sends `msg` to a single worker, in a round-robin fashion.
313 /// Returns `Err` if there is no worker or if the worker has dropped its receiver.
314 pub fn broadcast_one(&mut self, msg: DownMsg<Down>) -> Result<(), SendError<DownMsg<Down>>> {
315 if self.workers.len() == 0 {
316 return Err(std::sync::mpsc::SendError(msg))
317 }
318
319 self.worker_index = (self.worker_index + 1) % self.workers.len();
320 self.workers[self.worker_index].1.send(msg)
321 }
322
323 pub fn get(&self, index: usize) -> Option<(&JoinHandle<()>, Sender<DownMsg<Down>>)> {
324 match self.workers.get(index) {
325 Some(x) => Some((&x.0, x.1.clone())),
326 None => None
327 }
328 }
329}