worker_pool/lib.rs

/*!
This crate provides the [`WorkerPool`] struct, which lets you manage a set of threads that need to communicate with the parent thread.
Throughout this documentation, the thread owning the [`WorkerPool`] is called the "Manager",
whereas the threads created and handled by the [`WorkerPool`] instance are called "Workers".

Communication to and from the workers is done using [`std::sync::mpsc`] queues.
When the manager communicates with the workers, the messages are said to go "down",
and when the workers communicate with the manager, the messages are said to go "up".

Communication from the workers to the manager uses a [`SyncSender`][std::sync::mpsc::SyncSender] wrapped inside a [`WorkerSender`], so as not to overwhelm the manager thread and run out of memory.
[`WorkerSender::send`] can thus block, and [`WorkerSender::try_send`] can return [`Err(TrySendError::Full)`][std::sync::mpsc::TrySendError::Full].

Because of the guarantees of [`WorkerSender`], locking or waiting for the queue to become available will *not* cause a deadlock when trying to join the threads,
as all of the joining methods of [`WorkerPool`] ([`WorkerPool::stop`] and [`WorkerPool::stop_and_join`]) first empty the message queue before
calling `join()`.

This absence of deadlocks comes at the cost of restricting what a worker may or may not do (see the sketch after this list):
- if a worker blocks or loops indefinitely, then it must be able to receive a [`Stop`][DownMsg::Stop] message at any time
- once the `Stop` message is received, execution must stop shortly afterwards: a worker may then only block on the message queue of its `WorkerSender`
- downward message queues aren't bounded, as a bound might otherwise introduce a deadlock when trying to send the `Stop` message
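
For instance, a worker stuck in an infinite loop can satisfy these rules by receiving its messages through `recv_break!`, which breaks out of the loop once `Stop` arrives. A minimal sketch (the message values are illustrative):

```
use worker_pool::{WorkerPool, DownMsg};

let mut pool: WorkerPool<u32, u32> = WorkerPool::new(16);

pool.execute(|tx, rx| {
    loop {
        // `recv_break!` breaks out of the loop when `Stop` is received
        let n = worker_pool::recv_break!(rx);
        tx.send(n * 2).unwrap();
    }
});

// Per-channel ordering guarantees that the worker sees this message before `Stop`
pool.broadcast(DownMsg::Other(21));
assert_eq!(pool.stop_and_join(), vec![42]);
```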

Additionally, the livelock problem of requests queuing up in the upward channel faster than the manager thread can process them is solved by [`WorkerPool::recv_burst`] (see the sketch below):
- any message sent before `recv_burst` was called will be yielded by its returned iterator ([`RecvBurstIterator`]) if it reaches the manager thread in time (otherwise, it'll sit in the queue until the next call to `recv_burst`)
- any message sent after `recv_burst` was called will cause that iterator to stop, and the message will be put in a temporary buffer for the next call to `recv_burst`
- [`RecvBurstIterator`] is non-blocking and holds a mutable reference to its `WorkerPool`
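
A short sketch of draining one burst; the sleep merely gives the worker time to send its messages before the burst is read:

```
use std::time::Duration;
use worker_pool::WorkerPool;

let mut pool: WorkerPool<usize, ()> = WorkerPool::new(8);

pool.execute(|tx, _rx| {
    for i in 0..4 {
        tx.send(i).unwrap();
    }
});

// Give the worker time to send its messages
std::thread::sleep(Duration::from_millis(100));

// Yields only the messages already in flight, then stops: no livelock
assert_eq!(pool.recv_burst().collect::<Vec<_>>(), vec![0, 1, 2, 3]);

pool.stop_and_join();
```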
*/
#![feature(negative_impls)]

use std::thread::JoinHandle;
use std::marker::PhantomData;
use std::time::Instant;

use std::sync::mpsc::{channel, sync_channel, Receiver, RecvError, SendError, Sender, SyncSender};

// TODO: support for robust causality as a feature
// TODO: regular Sender

#[cfg(test)]
mod test;

pub mod iterator;
use iterator::*;

mod msg;
pub use msg::*;

mod crate_macros;

// As per this test: https://github.com/rust-lang/rust/blob/1.3.0/src/libstd/sync/mpsc/mod.rs#L1581-L1592
// it looks like std::sync::mpsc channels are meant to preserve the order sent within a single thread

/**
The main struct; it represents a pool of workers.
The owner of this struct is the "Manager", while the threads handled by this struct are the "Workers".

# Example

```
use worker_pool::WorkerPool;

let mut pool: WorkerPool<String, ()> = WorkerPool::new(100);

pool.execute(|tx, _rx| {
    tx.send(String::from("Hello")).unwrap();
    tx.send(String::from("world!")).unwrap();
});

assert_eq!(pool.stop().collect::<Vec<_>>().join(" "), "Hello world!");
```
*/
pub struct WorkerPool<Up, Down>
where
    Up: Send + 'static,
    Down: Send + 'static,
{
    channel: (SyncSender<UpMsg<Up>>, Receiver<UpMsg<Up>>),
    buffer_length: usize,
    buffer_prev: Option<UpMsg<Up>>,

    workers: Vec<(JoinHandle<()>, Sender<DownMsg<Down>>)>,
    worker_index: usize,

    phantoms: PhantomData<(Up, Down)>
}

impl<Up, Down> WorkerPool<Up, Down>
where
    Up: Send + 'static,
    Down: Send + 'static,
{
    /**
    Creates a new WorkerPool instance with a given maximum buffer length.

    The higher the buffer length, the higher the message throughput,
    but the higher the memory cost.
    See [`SyncSender`][std::sync::mpsc::SyncSender] for more information.

    # Example

    ```
    use std::time::Duration;
    use worker_pool::{WorkerPool, DownMsg};

    let mut pool: WorkerPool<usize, String> = WorkerPool::new(3);

    pool.execute(|tx, rx| {
        loop {
            let msg = worker_pool::recv_break!(rx);
            tx.send(msg.len()).unwrap();
        }
    });

    pool.broadcast(DownMsg::Other(String::from("Betelgeuse")));
    pool.broadcast(DownMsg::Other(String::from("Alpha Centauri")));
    pool.broadcast(DownMsg::Other(String::from("Sirius")));

    // When the worker sends the result for this message, tx.send will block, but the message
    // will still count as having been sent before recv_burst(). Whether or not it appears in the
    // iterator depends on how quickly the items in the iterator are read.
    pool.broadcast(DownMsg::Other(String::from("Procyon")));

    pool.broadcast(DownMsg::Other(String::from("Sun")));

    std::thread::sleep(Duration::new(0, 100_000_000));

    assert_eq!(
        pool.recv_burst().take(3).collect::<Vec<_>>(),
        vec![10, 14, 6]
    );
    ```
    */
    #[inline]
    pub fn new(buffer_length: usize) -> Self {
        Self {
            channel: sync_channel(buffer_length),
            buffer_length,
            buffer_prev: None,
            workers: Vec::new(),
            worker_index: 0,
            phantoms: PhantomData,
        }
    }

    /**
    Spawns one worker thread with the given callback:

    ```
    # use worker_pool::*;
    # let mut pool: WorkerPool<(), ()> = WorkerPool::new(100);
    // Spawns 1 worker thread
    pool.execute(|tx, rx| {
        // Send messages through tx
        // Receive messages from rx
    });
    # pool.stop_and_join();
    ```

    To prevent any deadlocks, the worker thread *must* stop shortly after receiving the `Stop` message:
    - if it is in an infinite loop, then that loop must be broken (`recv_break!` and `try_recv_break!` will handle that for you)
    - after the `Stop` message is received, it may only wait for space in the buffer of `tx`
    - make sure that no lock will prevent a `Stop` message from being received
    - the worker thread may `panic!`, in which case its panic will be propagated by `stop()` and `stop_and_join()`
    */
    #[inline]
    pub fn execute<F>(&mut self, callback: F)
    where
        F: FnOnce(WorkerSender<Up>, Receiver<DownMsg<Down>>),
        F: Send + 'static
    {
        let (down_tx, down_rx) = channel();
        let up_tx = self.channel.0.clone();
        self.workers.push((
            std::thread::spawn(move || {
                callback(WorkerSender::new(up_tx), down_rx);
            }),
            down_tx
        ));
    }

    /**
    Spawns `n_workers` worker threads with the given callback.
    The callback must implement `Clone`.

    ```
    # use worker_pool::*;
    # let mut pool: WorkerPool<(), ()> = WorkerPool::new(100);
    // Spawns 16 worker threads
    pool.execute_many(16, |tx, rx| {
        // Send messages through tx
        // Receive messages from rx
    });
    # pool.stop_and_join();
    ```

    To prevent any deadlocks, the worker threads *must* stop shortly after receiving the `Stop` message.
    See [`execute`][Self::execute] for more information.
    */
    #[inline]
    pub fn execute_many<F>(&mut self, n_workers: usize, callback: F)
    where
        F: FnOnce(WorkerSender<Up>, Receiver<DownMsg<Down>>),
        F: Clone + Send + 'static
    {
        for _ in 0..n_workers {
            self.execute(callback.clone());
        }
    }

    /// Returns the maximum length of the message queue
    #[inline]
    pub fn buffer_length(&self) -> usize {
        self.buffer_length
    }

    /// Receives a single message from a worker; this is a blocking operation.
    /// If you need to call this function repeatedly, then consider iterating over the result of `recv_burst` instead.
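    ///
    /// A minimal blocking-receive sketch:
    ///
    /// ```
    /// use worker_pool::WorkerPool;
    ///
    /// let mut pool: WorkerPool<i32, ()> = WorkerPool::new(4);
    /// pool.execute(|tx, _rx| {
    ///     tx.send(7).unwrap();
    /// });
    ///
    /// // Blocks until the worker's message arrives
    /// assert_eq!(pool.recv(), Ok(7));
    /// pool.stop_and_join();
    /// ```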
    pub fn recv(&mut self) -> Result<Up, RecvError> {
        if let Some(msg) = self.buffer_prev.take() {
            return Ok(msg.get());
        }

        self.channel.1.recv().map(|x| x.get())
    }

    /// Returns an iterator that will yield a "burst" of messages.
    /// This iterator respects causality, meaning that it will not yield any messages that were sent after it was created.
    /// You can thus safely iterate over all of the elements of this iterator without risking a livelock.
    pub fn recv_burst<'b>(&'b mut self) -> RecvBurstIterator<'b, Up> {
        let start = Instant::now();

        RecvBurstIterator::new(
            &self.channel.1,
            &mut self.buffer_prev,
            start
        )
    }

    /// Stops the execution of all threads, returning an iterator that yields the remaining messages
    /// from the workers and then joins the threads. As soon as this function returns, the WorkerPool is back in its starting state,
    /// allowing you to execute more tasks immediately.
    ///
    /// The returned iterator reads all of the remaining messages one by one.
    /// Once the last message is received, it joins all threads.
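    ///
    /// A short sketch:
    ///
    /// ```
    /// use worker_pool::WorkerPool;
    ///
    /// let mut pool: WorkerPool<String, ()> = WorkerPool::new(4);
    /// pool.execute(|tx, _rx| {
    ///     tx.send(String::from("done")).unwrap();
    /// });
    ///
    /// // Drains the remaining messages, then joins the worker
    /// let messages: Vec<String> = pool.stop().collect();
    /// assert_eq!(messages, vec![String::from("done")]);
    /// ```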
    pub fn stop(&mut self) -> RecvAllIterator<Up> {
        let channel = std::mem::replace(&mut self.channel, sync_channel(self.buffer_length));
        let buffer_prev = self.buffer_prev.take();
        let workers_len = self.workers.len();
        let workers = std::mem::replace(&mut self.workers, Vec::with_capacity(workers_len));
        self.worker_index = 0;

        let workers = workers.into_iter().map(|worker| {
            // Note: the only instance where this can fail is if the receiver was dropped,
            // in which case we can only hope that the thread will eventually join
            let _ = worker.1.send(DownMsg::Stop);
            worker.0
        }).collect::<Vec<_>>();

        RecvAllIterator::new(
            channel.1,
            buffer_prev,
            workers
        )
    }

    /// Stops the execution of all threads and joins them. Returns a `Vec` containing all of the remaining yielded values.
    /// Note that the returned `Vec` ignores the `buffer_length` limitation.
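    ///
    /// A short sketch; three messages are collected even though `buffer_length` is only 2:
    ///
    /// ```
    /// use worker_pool::WorkerPool;
    ///
    /// let mut pool: WorkerPool<u8, ()> = WorkerPool::new(2);
    /// pool.execute_many(3, |tx, _rx| {
    ///     tx.send(1).unwrap();
    /// });
    ///
    /// // Joins every worker and gathers everything still in the queue
    /// assert_eq!(pool.stop_and_join().len(), 3);
    /// ```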
    #[inline]
    pub fn stop_and_join(&mut self) -> Vec<Up> {
        let (sender, receiver) = std::mem::replace(&mut self.channel, sync_channel(self.buffer_length));
        std::mem::drop(sender); // Prevent deadlock
        let buffer_prev = self.buffer_prev.take();
        let workers_len = self.workers.len();
        let workers = std::mem::replace(&mut self.workers, Vec::with_capacity(workers_len));
        self.worker_index = 0;

        for worker in workers.iter() {
            // Note: the only instance where this can fail is if the receiver was dropped,
            // in which case we can only hope that the thread will eventually join
            let _ = worker.1.send(DownMsg::Stop);
        }

        let mut res = Vec::new();

        if let Some(buffer_prev) = buffer_prev {
            res.push(buffer_prev.get());
        }

        while let Ok(msg) = receiver.recv() {
            res.push(msg.get());
        }

        for worker in workers {
            match worker.0.join() {
                Ok(_) => {},
                Err(e) => std::panic::resume_unwind(e),
            }
        }

        res
    }

    /// Sends `msg` to every worker.
    /// If a worker has dropped its [`Receiver`][std::sync::mpsc::Receiver], then it will be skipped.
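    ///
    /// A short sketch; every worker receives its own copy of the message:
    ///
    /// ```
    /// use worker_pool::{WorkerPool, DownMsg};
    ///
    /// let mut pool: WorkerPool<usize, usize> = WorkerPool::new(8);
    /// pool.execute_many(4, |tx, rx| {
    ///     loop {
    ///         let n = worker_pool::recv_break!(rx);
    ///         tx.send(n + 1).unwrap();
    ///     }
    /// });
    ///
    /// pool.broadcast(DownMsg::Other(9));
    /// assert_eq!(pool.stop_and_join(), vec![10, 10, 10, 10]);
    /// ```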
    pub fn broadcast(&self, msg: DownMsg<Down>) where Down: Clone {
        for (_join, tx) in self.workers.iter() {
            // This will fail iff the thread has dropped its receiver, in which case we
            // don't want it to affect the other threads
            let _ = tx.send(msg.clone());
        }
    }

    /// Sends `msg` to a single worker, in a round-robin fashion.
    /// Returns `Err` if there are no workers or if the chosen worker has dropped its receiver.
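    ///
    /// A short sketch; each call targets the next worker in round-robin order:
    ///
    /// ```
    /// use worker_pool::{WorkerPool, DownMsg};
    ///
    /// let mut pool: WorkerPool<usize, usize> = WorkerPool::new(8);
    /// pool.execute_many(2, |tx, rx| {
    ///     loop {
    ///         let n = worker_pool::recv_break!(rx);
    ///         tx.send(n * n).unwrap();
    ///     }
    /// });
    ///
    /// pool.broadcast_one(DownMsg::Other(3)).unwrap();
    /// pool.broadcast_one(DownMsg::Other(4)).unwrap();
    ///
    /// let mut results = pool.stop_and_join();
    /// results.sort();
    /// assert_eq!(results, vec![9, 16]);
    /// ```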
    pub fn broadcast_one(&mut self, msg: DownMsg<Down>) -> Result<(), SendError<DownMsg<Down>>> {
        if self.workers.is_empty() {
            return Err(std::sync::mpsc::SendError(msg));
        }

        self.worker_index = (self.worker_index + 1) % self.workers.len();
        self.workers[self.worker_index].1.send(msg)
    }

    /// Returns a reference to the [`JoinHandle`] of the worker at `index`, together with a clone
    /// of the [`Sender`][std::sync::mpsc::Sender] used to send messages down to it, or `None` if `index` is out of bounds.
    pub fn get(&self, index: usize) -> Option<(&JoinHandle<()>, Sender<DownMsg<Down>>)> {
        self.workers.get(index).map(|x| (&x.0, x.1.clone()))
    }
}
329}