zrx_executor/executor/strategy/worker/
stealing.rs

1// Copyright (c) Zensical LLC <https://zensical.org>
2
3// SPDX-License-Identifier: MIT
4// Third-party contributions licensed under CLA
5
6// Permission is hereby granted, free of charge, to any person obtaining a copy
7// of this software and associated documentation files (the "Software"), to
8// deal in the Software without restriction, including without limitation the
9// rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10// sell copies of the Software, and to permit persons to whom the Software is
11// furnished to do so, subject to the following conditions:
12
13// The above copyright notice and this permission notice shall be included in
14// all copies or substantial portions of the Software.
15
16// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18// FITNESS FOR A PARTICULAR PURPOSE AND NON-INFRINGEMENT. IN NO EVENT SHALL THE
19// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
21// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
22// IN THE SOFTWARE.
23
24// ----------------------------------------------------------------------------
25
26//! Work-stealing execution strategy.
27
28use crossbeam::deque::{Injector, Steal, Stealer, Worker};
29use std::iter::repeat_with;
30use std::sync::atomic::{AtomicUsize, Ordering};
31use std::sync::Arc;
32use std::thread::{self, Builder, JoinHandle};
33use std::{cmp, fmt, panic};
34
35use crate::executor::strategy::{Signal, Strategy};
36use crate::executor::task::Task;
37use crate::executor::Result;
38
39// ----------------------------------------------------------------------------
40// Structs
41// ----------------------------------------------------------------------------
42
43/// Work-stealing execution strategy.
44///
45/// This strategy implements work-stealing, where each worker thread has its own
46/// local queue. Workers can steal tasks from a central injector or from other
47/// workers, if their local queues are empty. This allows for more efficient
48/// execution if there's a large number of workers and tasks.
49///
50/// Work stealing enhances load balancing by allowing idle workers to take on
51/// tasks from busier peers, which helps to reduce idle time and can improve
52/// overall throughput. Unlike the simpler [`WorkSharing`][] strategy that uses
53/// a central queue where workers may become stuck waiting for new tasks, this
54/// method can yield better resource utilization. Additionally, work-stealing
55/// helps mitigate contention over shared resources (i.e. channels), which can
56/// become a bottleneck in central queueing systems, allowing each worker to
57/// primarily operate on its local queue. However, if task runtimes are short,
58/// the utilization can be lower than with a central queue, since stealing uses
59/// an optimistic strategy and is a best-effort operation. When task runtimes
60/// are long, work stealing can be more efficient than central queueing.
61///
62/// This reduced contention is particularly beneficial in dynamic environments
63/// with significantly fluctuating workloads, enabling faster task completion
64/// as workers can quickly adapt to take on shorter or less complex tasks as
65/// they become available.
66///
67/// [`WorkSharing`]: crate::executor::strategy::WorkSharing
68///
69/// # Examples
70///
71/// ```
72/// # use std::error::Error;
73/// # fn main() -> Result<(), Box<dyn Error>> {
74/// use zrx_executor::strategy::{Strategy, WorkStealing};
75///
76/// // Create strategy and submit task
77/// let strategy = WorkStealing::default();
78/// strategy.submit(Box::new(|| println!("Task")))?;
79/// # Ok(())
80/// # }
81/// ```
82pub struct WorkStealing {
83    /// Injector for task submission.
84    injector: Arc<Injector<Box<dyn Task>>>,
85    /// Signal for synchronization.
86    signal: Arc<Signal>,
87    /// Join handles of worker threads.
88    threads: Vec<JoinHandle<Result>>,
89    /// Counter for running tasks.
90    running: Arc<AtomicUsize>,
91    /// Counter for pending tasks.
92    pending: Arc<AtomicUsize>,
93}
94
95// ----------------------------------------------------------------------------
96// Implementations
97// ----------------------------------------------------------------------------
98
99impl WorkStealing {
100    /// Creates a work-stealing execution strategy.
101    ///
102    /// This method creates a strategy with the given number of worker threads,
103    /// which are spawned immediately before the method returns. Note that this
104    /// strategy uses an unbounded channel, so there're no capacity limits as
105    /// for the [`WorkSharing`][] execution strategy.
106    ///
107    /// [`WorkSharing`]: crate::executor::strategy::WorkSharing
108    ///
109    /// # Panics
110    ///
111    /// Panics if thread creation fails.
112    ///
113    /// # Examples
114    ///
115    /// ```
116    /// use zrx_executor::strategy::WorkStealing;
117    ///
118    /// // Create strategy
119    /// let strategy = WorkStealing::new(4);
120    /// ```
121    #[must_use]
122    pub fn new(num_workers: usize) -> Self {
123        let injector = Arc::new(Injector::new());
124        let signal = Arc::new(Signal::new());
125
126        // Create worker queues
127        let mut workers = Vec::with_capacity(num_workers);
128        for _ in 0..num_workers {
129            workers.push(Worker::new_fifo());
130        }
131
132        // Obtain stealers from worker queues - note that we collect stealers
133        // into a slice and not a vector, as we won't change the data after
134        // initializing it, so we can share the stealers among workers without
135        // the need for synchronization.
136        let stealers: Arc<[Stealer<Box<dyn Task>>]> =
137            Arc::from(workers.iter().map(Worker::stealer).collect::<Vec<_>>());
138
139        // Keep track of running and pending tasks
140        let running = Arc::new(AtomicUsize::new(0));
141        let pending = Arc::new(AtomicUsize::new(0));
142
143        // Initialize worker threads
144        let iter = workers.into_iter().enumerate().map(|(index, worker)| {
145            let injector = Arc::clone(&injector);
146            let stealers = Arc::clone(&stealers);
147            let signal = Arc::clone(&signal);
148
149            // Create worker thread and obtain references to injector and
150            // stealers, which we need to retrieve the next task
151            let running = Arc::clone(&running);
152            let pending = Arc::clone(&pending);
153            let h = move || {
154                let injector = injector.as_ref();
155                let stealers = stealers.as_ref();
156
157                // Try to fetch the next task, either from the local queue, or
158                // from the injector or another worker. Additionally, we keep
159                // track of the number of running tasks to provide a simple way
160                // to monitor the load of the thread pool.
161                loop {
162                    let Some(task) = get(&worker, injector, stealers) else {
163                        // No more tasks, so we wait for the executor to signal
164                        // if the worker should continue or terminate. This can
165                        // fail due to a poisoned lock, in which case we need
166                        // to terminate gracefully as well.
167                        if signal.should_terminate()? {
168                            break;
169                        }
170
171                        // Return to waiting for next task
172                        continue;
173                    };
174
175                    // Update number of pending and running tasks
176                    pending.fetch_sub(1, Ordering::Acquire);
177                    running.fetch_add(1, Ordering::Release);
178
179                    // Execute task, but ignore panics, since the executor has
180                    // no way of reporting them, and they're printed anyway
181                    let subtasks = panic::catch_unwind(|| task.execute())
182                        .unwrap_or_default();
183
184                    // Update number of running tasks
185                    running.fetch_sub(1, Ordering::Acquire);
186
187                    // In case the task returned further subtasks, we add them
188                    // to the local queue, so they are executed by the current
189                    // worker, or can be stolen by another worker in case the
190                    // current worker thread is busy
191                    if !subtasks.is_empty() {
192                        let added = subtasks
193                            .into_iter()
194                            .map(|subtask| worker.push(subtask))
195                            .count();
196
197                        // Update number of running and pending tasks, and wake
198                        // other workers threads to allow for stealing
199                        pending.fetch_add(added, Ordering::Release);
200                        signal.notify();
201                    }
202                }
203
204                // No errors occurred
205                Ok(())
206            };
207
208            // We deliberately use unwrap here, as the capability to spawn
209            // threads is a fundamental requirement of the executor
210            Builder::new()
211                .name(format!("zrx/executor/{}", index + 1))
212                .spawn(h)
213                .unwrap()
214        });
215
216        // Create worker threads and return strategy
217        let threads = iter.collect();
218        Self {
219            injector,
220            signal,
221            threads,
222            running,
223            pending,
224        }
225    }
226}
227
228// ----------------------------------------------------------------------------
229// Trait implementations
230// ----------------------------------------------------------------------------
231
232impl Strategy for WorkStealing {
233    /// Submits a task.
234    ///
235    /// This method submits a [`Task`], which is executed by one of the worker
236    /// threads as soon as possible. If a task computes a result, a [`Sender`][]
237    /// can be shared with the task, to send the result back to the caller,
238    /// which can then poll a [`Receiver`][].
239    ///
240    /// Note that tasks are intended to only run once, which is why they are
241    /// consumed. If a task needs to be run multiple times, it must be wrapped
242    /// in a closure that creates a new task each time. This allows for safe
243    /// sharing of state between tasks.
244    ///
245    /// [`Receiver`]: crossbeam::channel::Receiver
246    /// [`Sender`]: crossbeam::channel::Sender
247    ///
248    /// # Errors
249    ///
250    /// This method is infallible, and will always return [`Ok`].
251    ///
252    /// # Examples
253    ///
254    /// ```
255    /// # use std::error::Error;
256    /// # fn main() -> Result<(), Box<dyn Error>> {
257    /// use zrx_executor::strategy::{Strategy, WorkStealing};
258    ///
259    /// // Create strategy and submit task
260    /// let strategy = WorkStealing::default();
261    /// strategy.submit(Box::new(|| println!("Task")))?;
262    /// # Ok(())
263    /// # }
264    /// ```
265    fn submit(&self, task: Box<dyn Task>) -> Result {
266        // As workers can steal tasks from the injector, we must manually track
267        // the number of pending tasks. For this reason, we increment the count
268        // by one to signal a new task was added, hand the task to the injector,
269        // and then wake any waiting worker threads.
270        self.injector.push(task);
271        self.pending.fetch_add(1, Ordering::Release);
272        self.signal.notify();
273
274        // No errors occurred
275        Ok(())
276    }
277
278    /// Returns the number of workers.
279    ///
280    /// # Examples
281    ///
282    /// ```
283    /// use zrx_executor::strategy::{Strategy, WorkStealing};
284    ///
285    /// // Get number of workers
286    /// let strategy = WorkStealing::new(1);
287    /// assert_eq!(strategy.num_workers(), 1);
288    /// ```
289    #[inline]
290    fn num_workers(&self) -> usize {
291        self.threads.len()
292    }
293
294    /// Returns the number of running tasks.
295    ///
296    /// This method allows to monitor the worker load, as it returns how many
297    /// workers are currently actively executing tasks.
298    ///
299    /// # Examples
300    ///
301    /// ```
302    /// use zrx_executor::strategy::{Strategy, WorkStealing};
303    ///
304    /// // Get number of running tasks
305    /// let strategy = WorkStealing::default();
306    /// assert_eq!(strategy.num_tasks_running(), 0);
307    /// ```
308    #[inline]
309    fn num_tasks_running(&self) -> usize {
310        self.running.load(Ordering::Relaxed)
311    }
312
313    /// Returns the number of pending tasks.
314    ///
315    /// This method allows to throttle the submission of tasks, as it returns
316    /// how many tasks are currently waiting to be executed.
317    ///
318    /// # Examples
319    ///
320    /// ```
321    /// use zrx_executor::strategy::{Strategy, WorkStealing};
322    ///
323    /// // Get number of pending tasks
324    /// let strategy = WorkStealing::default();
325    /// assert_eq!(strategy.num_tasks_pending(), 0);
326    /// ```
327    #[inline]
328    fn num_tasks_pending(&self) -> usize {
329        self.pending.load(Ordering::Relaxed)
330    }
331
332    /// Returns the capacity, if bounded.
333    ///
334    /// The work-stealing execution strategy does not impose a hard limit on
335    /// the number of tasks. Thus, this strategy should only be used if tasks
336    /// are not produced faster than they can be executed, or the number of
337    /// tasks is limited by some other means.
338    ///
339    /// # Examples
340    ///
341    /// ```
342    /// use zrx_executor::strategy::{Strategy, WorkStealing};
343    ///
344    /// // Get capacity
345    /// let strategy = WorkStealing::default();
346    /// assert_eq!(strategy.capacity(), None);
347    /// ```
348    #[inline]
349    fn capacity(&self) -> Option<usize> {
350        None
351    }
352}
353
354// ----------------------------------------------------------------------------
355
356impl Default for WorkStealing {
357    /// Creates a work-stealing execution strategy using all CPUs - 1.
358    ///
359    /// The number of workers is determined by the number of logical CPUs minus
360    /// one, which reserves one core for the main thread for orchestration. If
361    /// the number of logical CPUs is fewer than 1, the strategy defaults to a
362    /// single worker thread.
363    ///
364    /// __Warning__: this method makes use of [`thread::available_parallelism`]
365    /// to determine the number of available cores, which has some limitations.
366    /// Please refer to the documentation of that function for more details, or
367    /// consider using [`num_cpus`][] as an alternative.
368    ///
369    /// [`num_cpus`]: https://crates.io/crates/num_cpus
370    ///
371    /// # Examples
372    ///
373    /// ```
374    /// use zrx_executor::strategy::WorkStealing;
375    ///
376    /// // Create strategy
377    /// let strategy = WorkStealing::default();
378    /// ```
379    #[inline]
380    fn default() -> Self {
381        Self::new(cmp::max(
382            thread::available_parallelism()
383                .map(|num| num.get().saturating_sub(1))
384                .unwrap_or(1),
385            1,
386        ))
387    }
388}
389
390impl Drop for WorkStealing {
391    /// Terminates and joins all worker threads.
392    ///
393    /// This method waits for all worker threads to finish executing currently
394    /// running tasks, while ignoring any pending tasks. All worker threads are
395    /// joined before the method returns. This is necessary to prevent worker
396    /// threads from running after the strategy has been dropped.
397    fn drop(&mut self) {
398        let _ = self.signal.terminate();
399
400        // Join all worker threads without panicking on errors
401        for handle in self.threads.drain(..) {
402            let _ = handle.join();
403        }
404    }
405}
406
407// ----------------------------------------------------------------------------
408
409impl fmt::Debug for WorkStealing {
410    /// Formats the execution strategy for debugging.
411    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
412        f.debug_struct("WorkStealing")
413            .field("workers", &self.num_workers())
414            .field("running", &self.num_tasks_running())
415            .field("pending", &self.num_tasks_pending())
416            .finish()
417    }
418}
419
420// ----------------------------------------------------------------------------
421// Functions
422// ----------------------------------------------------------------------------
423
424/// Attempts to get the next available task, either from the worker's own queue
425/// or by stealing from the injector or other stealers if needed. Note that this
426/// code was taken almost verbatim from the [`crossbeam`] docs, specifically
427/// from [`crossbeam::deque`](crossbeam::deque#examples), but cut smaller.
428fn get<T>(
429    worker: &Worker<T>, injector: &Injector<T>, stealers: &[Stealer<T>],
430) -> Option<T> {
431    worker
432        .pop()
433        .or_else(|| steal_or_retry(worker, injector, stealers))
434}
435
436/// Repeatedly attempts to steal a task from the injector or stealers until a
437/// non-retryable steal result is found, returning the successful task if any.
438fn steal_or_retry<T>(
439    worker: &Worker<T>, injector: &Injector<T>, stealers: &[Stealer<T>],
440) -> Option<T> {
441    repeat_with(|| steal(worker, injector, stealers))
442        .find(|steal| !steal.is_retry())
443        .and_then(Steal::success)
444}
445
446/// Tries to steal a task from the injector or, if unavailable, from each
447/// stealer in sequence, collecting any valid tasks or signaling retry.
448fn steal<T>(
449    worker: &Worker<T>, injector: &Injector<T>, stealers: &[Stealer<T>],
450) -> Steal<T> {
451    injector
452        .steal_batch_and_pop(worker)
453        .or_else(|| stealers.iter().map(Stealer::steal).collect())
454}