Skip to main content

zrx_executor/executor/strategy/worker/
stealing.rs

// Copyright (c) 2025-2026 Zensical and contributors

// SPDX-License-Identifier: MIT
// All contributions are certified under the DCO

// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to
// deal in the Software without restriction, including without limitation the
// rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
// sell copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:

// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.

// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NON-INFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
// IN THE SOFTWARE.

// ----------------------------------------------------------------------------

//! Work-stealing execution strategy.
27
28use crossbeam::deque::{Injector, Steal, Stealer, Worker};
29use std::fmt::{self, Debug};
30use std::iter::repeat_with;
31use std::sync::atomic::{AtomicUsize, Ordering};
32use std::sync::Arc;
33use std::thread::{self, Builder, JoinHandle};
34use std::{cmp, panic};
35
36use crate::executor::strategy::{Signal, Strategy};
37use crate::executor::task::Task;
38use crate::executor::Result;
39
// ----------------------------------------------------------------------------
// Structs
// ----------------------------------------------------------------------------
43
44/// Work-stealing execution strategy.
45///
46/// This strategy implements work-stealing, where each worker thread has its own
47/// local queue. Workers can steal tasks from a central injector or from other
48/// workers, if their local queues are empty. This allows for more efficient
49/// execution if there's a large number of workers and tasks.
50///
51/// Work stealing enhances load balancing by allowing idle workers to take on
52/// tasks from busier peers, which helps to reduce idle time and can improve
53/// overall throughput. Unlike the simpler [`WorkSharing`][] strategy that uses
54/// a central queue where workers may become stuck waiting for new tasks, this
55/// method can yield better resource utilization. Additionally, work-stealing
56/// helps mitigate contention over shared resources (i.e. channels), which can
57/// become a bottleneck in central queueing systems, allowing each worker to
58/// primarily operate on its local queue. However, if task runtimes are short,
59/// the utilization can be lower than with a central queue, since stealing uses
60/// an optimistic strategy and is a best-effort operation. When task runtimes
61/// are long, work stealing can be more efficient than central queueing.
62///
63/// This reduced contention is particularly beneficial in dynamic environments
64/// with significantly fluctuating workloads, enabling faster task completion
65/// as workers can quickly adapt to take on shorter or less complex tasks as
66/// they become available.
67///
68/// [`WorkSharing`]: crate::executor::strategy::WorkSharing
69///
70/// # Examples
71///
72/// ```
73/// # use std::error::Error;
74/// # fn main() -> Result<(), Box<dyn Error>> {
75/// use zrx_executor::strategy::{Strategy, WorkStealing};
76///
77/// // Create strategy and submit task
78/// let strategy = WorkStealing::default();
79/// strategy.submit(Box::new(|| println!("Task")))?;
80/// # Ok(())
81/// # }
82/// ```
83pub struct WorkStealing {
84    /// Injector for task submission.
85    injector: Arc<Injector<Box<dyn Task>>>,
86    /// Signal for synchronization.
87    signal: Arc<Signal>,
88    /// Join handles of worker threads.
89    threads: Vec<JoinHandle<Result>>,
90    /// Counter for running tasks.
91    running: Arc<AtomicUsize>,
92    /// Counter for pending tasks.
93    pending: Arc<AtomicUsize>,
94}
95
// ----------------------------------------------------------------------------
// Implementations
// ----------------------------------------------------------------------------
99
100impl WorkStealing {
101    /// Creates a work-stealing execution strategy.
102    ///
103    /// This method creates a strategy with the given number of worker threads,
104    /// which are spawned immediately before the method returns. Note that this
105    /// strategy uses an unbounded channel, so there're no capacity limits as
106    /// for the [`WorkSharing`][] execution strategy.
107    ///
108    /// [`WorkSharing`]: crate::executor::strategy::WorkSharing
109    ///
110    /// # Panics
111    ///
112    /// Panics if thread creation fails.
113    ///
114    /// # Examples
115    ///
116    /// ```
117    /// use zrx_executor::strategy::WorkStealing;
118    ///
119    /// // Create strategy
120    /// let strategy = WorkStealing::new(4);
121    /// ```
122    #[must_use]
123    pub fn new(num_workers: usize) -> Self {
124        let injector = Arc::new(Injector::new());
125        let signal = Arc::new(Signal::new());
126
127        // Create worker queues
128        let mut workers = Vec::with_capacity(num_workers);
129        for _ in 0..num_workers {
130            workers.push(Worker::new_fifo());
131        }
132
133        // Obtain stealers from worker queues - note that we collect stealers
134        // into a slice and not a vector, as we won't change the data after
135        // initializing it, so we can share the stealers among workers without
136        // the need for synchronization.
137        let stealers: Arc<[Stealer<Box<dyn Task>>]> =
138            Arc::from(workers.iter().map(Worker::stealer).collect::<Vec<_>>());
139
140        // Keep track of running and pending tasks
141        let running = Arc::new(AtomicUsize::new(0));
142        let pending = Arc::new(AtomicUsize::new(0));
143
144        // Initialize worker threads
145        let iter = workers.into_iter().enumerate().map(|(index, worker)| {
146            let injector = Arc::clone(&injector);
147            let stealers = Arc::clone(&stealers);
148            let signal = Arc::clone(&signal);
149
150            // Create worker thread and obtain references to injector and
151            // stealers, which we need to retrieve the next task
152            let running = Arc::clone(&running);
153            let pending = Arc::clone(&pending);
154            let h = move || {
155                let injector = injector.as_ref();
156                let stealers = stealers.as_ref();
157
158                // Try to fetch the next task, either from the local queue, or
159                // from the injector or another worker. Additionally, we keep
160                // track of the number of running tasks to provide a simple way
161                // to monitor the load of the thread pool.
162                loop {
163                    let Some(task) = get(&worker, injector, stealers) else {
164                        // No more tasks, so we wait for the executor to signal
165                        // if the worker should continue or terminate. This can
166                        // fail due to a poisoned lock, in which case we need
167                        // to terminate gracefully as well.
168                        if signal.should_terminate()? {
169                            break;
170                        }
171
172                        // Return to waiting for next task
173                        continue;
174                    };
175
176                    // Update number of pending and running tasks
177                    pending.fetch_sub(1, Ordering::Acquire);
178                    running.fetch_add(1, Ordering::Release);
179
180                    // Execute task, but ignore panics, since the executor has
181                    // no way of reporting them, and they're printed anyway
182                    let subtasks = panic::catch_unwind(|| task.execute())
183                        .unwrap_or_default();
184
185                    // Update number of running tasks
186                    running.fetch_sub(1, Ordering::Acquire);
187
188                    // In case the task returned further subtasks, we add them
189                    // to the local queue, so they are executed by the current
190                    // worker, or can be stolen by another worker in case the
191                    // current worker thread is busy
192                    if !subtasks.is_empty() {
193                        let added = subtasks
194                            .into_iter()
195                            .map(|subtask| worker.push(subtask))
196                            .count();
197
198                        // Update number of running and pending tasks, and wake
199                        // other workers threads to allow for stealing
200                        pending.fetch_add(added, Ordering::Release);
201                        signal.notify();
202                    }
203                }
204
205                // No errors occurred
206                Ok(())
207            };
208
209            // We deliberately use unwrap here, as the capability to spawn
210            // threads is a fundamental requirement of the executor
211            Builder::new()
212                .name(format!("zrx/executor/{}", index + 1))
213                .spawn(h)
214                .unwrap()
215        });
216
217        // Create worker threads and return strategy
218        let threads = iter.collect();
219        Self {
220            injector,
221            signal,
222            threads,
223            running,
224            pending,
225        }
226    }
227}
228
// ----------------------------------------------------------------------------
// Trait implementations
// ----------------------------------------------------------------------------
232
233impl Strategy for WorkStealing {
234    /// Submits a task.
235    ///
236    /// This method submits a [`Task`], which is executed by one of the worker
237    /// threads as soon as possible. If a task computes a result, a [`Sender`][]
238    /// can be shared with the task, to send the result back to the caller,
239    /// which can then poll a [`Receiver`][].
240    ///
241    /// Note that tasks are intended to only run once, which is why they are
242    /// consumed. If a task needs to be run multiple times, it must be wrapped
243    /// in a closure that creates a new task each time. This allows for safe
244    /// sharing of state between tasks.
245    ///
246    /// [`Receiver`]: crossbeam::channel::Receiver
247    /// [`Sender`]: crossbeam::channel::Sender
248    ///
249    /// # Errors
250    ///
251    /// This method is infallible, and will always return [`Ok`].
252    ///
253    /// # Examples
254    ///
255    /// ```
256    /// # use std::error::Error;
257    /// # fn main() -> Result<(), Box<dyn Error>> {
258    /// use zrx_executor::strategy::{Strategy, WorkStealing};
259    ///
260    /// // Create strategy and submit task
261    /// let strategy = WorkStealing::default();
262    /// strategy.submit(Box::new(|| println!("Task")))?;
263    /// # Ok(())
264    /// # }
265    /// ```
266    fn submit(&self, task: Box<dyn Task>) -> Result {
267        // As workers can steal tasks from the injector, we must manually track
268        // the number of pending tasks. Note that we must increment the counter
269        // before pushing the task to the injector, to ensure that the count is
270        // accurate at all times, even if the task is stolen immediately.
271        self.pending.fetch_add(1, Ordering::Release);
272
273        // Submit the task to the injector and wake up waiting worker threads,
274        // so they can steal the task and execute it as soon as possible
275        self.injector.push(task);
276        self.signal.notify();
277
278        // No errors occurred
279        Ok(())
280    }
281
282    /// Returns the number of workers.
283    ///
284    /// # Examples
285    ///
286    /// ```
287    /// use zrx_executor::strategy::{Strategy, WorkStealing};
288    ///
289    /// // Get number of workers
290    /// let strategy = WorkStealing::new(1);
291    /// assert_eq!(strategy.num_workers(), 1);
292    /// ```
293    #[inline]
294    fn num_workers(&self) -> usize {
295        self.threads.len()
296    }
297
298    /// Returns the number of running tasks.
299    ///
300    /// This method allows to monitor the worker load, as it returns how many
301    /// workers are currently actively executing tasks.
302    ///
303    /// # Examples
304    ///
305    /// ```
306    /// use zrx_executor::strategy::{Strategy, WorkStealing};
307    ///
308    /// // Get number of running tasks
309    /// let strategy = WorkStealing::default();
310    /// assert_eq!(strategy.num_tasks_running(), 0);
311    /// ```
312    #[inline]
313    fn num_tasks_running(&self) -> usize {
314        self.running.load(Ordering::Relaxed)
315    }
316
317    /// Returns the number of pending tasks.
318    ///
319    /// This method allows to throttle the submission of tasks, as it returns
320    /// how many tasks are currently waiting to be executed.
321    ///
322    /// # Examples
323    ///
324    /// ```
325    /// use zrx_executor::strategy::{Strategy, WorkStealing};
326    ///
327    /// // Get number of pending tasks
328    /// let strategy = WorkStealing::default();
329    /// assert_eq!(strategy.num_tasks_pending(), 0);
330    /// ```
331    #[inline]
332    fn num_tasks_pending(&self) -> usize {
333        self.pending.load(Ordering::Relaxed)
334    }
335
336    /// Returns the capacity, if bounded.
337    ///
338    /// The work-stealing execution strategy does not impose a hard limit on
339    /// the number of tasks. Thus, this strategy should only be used if tasks
340    /// are not produced faster than they can be executed, or the number of
341    /// tasks is limited by some other means.
342    ///
343    /// # Examples
344    ///
345    /// ```
346    /// use zrx_executor::strategy::{Strategy, WorkStealing};
347    ///
348    /// // Get capacity
349    /// let strategy = WorkStealing::default();
350    /// assert_eq!(strategy.capacity(), None);
351    /// ```
352    #[inline]
353    fn capacity(&self) -> Option<usize> {
354        None
355    }
356}
357
// ----------------------------------------------------------------------------
359
360impl Default for WorkStealing {
361    /// Creates a work-stealing execution strategy using all CPUs - 1.
362    ///
363    /// The number of workers is determined by the number of logical CPUs minus
364    /// one, which reserves one core for the main thread for orchestration. If
365    /// the number of logical CPUs is fewer than 1, the strategy defaults to a
366    /// single worker thread.
367    ///
368    /// __Warning__: this method makes use of [`thread::available_parallelism`]
369    /// to determine the number of available cores, which has some limitations.
370    /// Please refer to the documentation of that function for more details, or
371    /// consider using [`num_cpus`][] as an alternative.
372    ///
373    /// [`num_cpus`]: https://crates.io/crates/num_cpus
374    ///
375    /// # Examples
376    ///
377    /// ```
378    /// use zrx_executor::strategy::WorkStealing;
379    ///
380    /// // Create strategy
381    /// let strategy = WorkStealing::default();
382    /// ```
383    #[inline]
384    fn default() -> Self {
385        Self::new(cmp::max(
386            thread::available_parallelism()
387                .map(|num| num.get().saturating_sub(1))
388                .unwrap_or(1),
389            1,
390        ))
391    }
392}
393
394impl Drop for WorkStealing {
395    /// Terminates and joins all worker threads.
396    ///
397    /// This method waits for all worker threads to finish executing currently
398    /// running tasks, while ignoring any pending tasks. All worker threads are
399    /// joined before the method returns. This is necessary to prevent worker
400    /// threads from running after the strategy has been dropped.
401    fn drop(&mut self) {
402        let _ = self.signal.terminate();
403
404        // Join all worker threads without panicking on errors
405        for handle in self.threads.drain(..) {
406            let _ = handle.join();
407        }
408    }
409}
410
// ----------------------------------------------------------------------------
412
413impl Debug for WorkStealing {
414    /// Formats the execution strategy for debugging.
415    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
416        f.debug_struct("WorkStealing")
417            .field("workers", &self.num_workers())
418            .field("running", &self.num_tasks_running())
419            .field("pending", &self.num_tasks_pending())
420            .finish()
421    }
422}
423
// ----------------------------------------------------------------------------
// Functions
// ----------------------------------------------------------------------------
427
428/// Attempts to get the next available task, either from the worker's own queue
429/// or by stealing from the injector or other stealers if needed. Note that this
430/// code was taken almost verbatim from the [`crossbeam`] docs, specifically
431/// from [`crossbeam::deque`](crossbeam::deque#examples), but cut smaller.
432fn get<T>(
433    worker: &Worker<T>, injector: &Injector<T>, stealers: &[Stealer<T>],
434) -> Option<T> {
435    worker
436        .pop()
437        .or_else(|| steal_or_retry(worker, injector, stealers))
438}
439
440/// Repeatedly attempts to steal a task from the injector or stealers until a
441/// non-retryable steal result is found, returning the successful task if any.
442fn steal_or_retry<T>(
443    worker: &Worker<T>, injector: &Injector<T>, stealers: &[Stealer<T>],
444) -> Option<T> {
445    repeat_with(|| steal(worker, injector, stealers))
446        .find(|steal| !steal.is_retry())
447        .and_then(Steal::success)
448}
449
450/// Tries to steal a task from the injector or, if unavailable, from each
451/// stealer in sequence, collecting any valid tasks or signaling retry.
452fn steal<T>(
453    worker: &Worker<T>, injector: &Injector<T>, stealers: &[Stealer<T>],
454) -> Steal<T> {
455    injector
456        .steal_batch_and_pop(worker)
457        .or_else(|| stealers.iter().map(Stealer::steal).collect())
458}