zrx_executor/executor/strategy/worker/stealing.rs
1// Copyright (c) 2025-2026 Zensical and contributors
2
3// SPDX-License-Identifier: MIT
4// All contributions are certified under the DCO
5
6// Permission is hereby granted, free of charge, to any person obtaining a copy
7// of this software and associated documentation files (the "Software"), to
8// deal in the Software without restriction, including without limitation the
9// rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10// sell copies of the Software, and to permit persons to whom the Software is
11// furnished to do so, subject to the following conditions:
12
13// The above copyright notice and this permission notice shall be included in
14// all copies or substantial portions of the Software.
15
16// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18// FITNESS FOR A PARTICULAR PURPOSE AND NON-INFRINGEMENT. IN NO EVENT SHALL THE
19// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
21// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
22// IN THE SOFTWARE.
23
24// ----------------------------------------------------------------------------
25
26//! Work-stealing execution strategy.
27
28use crossbeam::deque::{Injector, Steal, Stealer, Worker};
29use std::fmt::{self, Debug};
30use std::iter::repeat_with;
31use std::sync::atomic::{AtomicUsize, Ordering};
32use std::sync::Arc;
33use std::thread::{self, Builder, JoinHandle};
34use std::{cmp, panic};
35
36use crate::executor::strategy::{Signal, Strategy};
37use crate::executor::task::Task;
38use crate::executor::{Error, Result};
39
40// ----------------------------------------------------------------------------
41// Structs
42// ----------------------------------------------------------------------------
43
44/// Work-stealing execution strategy.
45///
46/// This strategy implements work-stealing, where each worker thread has its own
47/// local queue. Workers can steal tasks from a central injector or from other
48/// workers, if their local queues are empty. This allows for more efficient
49/// execution if there's a large number of workers and tasks.
50///
51/// Work stealing enhances load balancing by allowing idle workers to take on
52/// tasks from busier peers, which helps to reduce idle time and can improve
53/// overall throughput. Unlike the simpler [`WorkSharing`][] strategy that uses
54/// a central queue where workers may become stuck waiting for new tasks, this
55/// method can yield better resource utilization. Additionally, work-stealing
56/// helps mitigate contention over shared resources (i.e. channels), which can
57/// become a bottleneck in central queueing systems, allowing each worker to
58/// primarily operate on its local queue. However, if task runtimes are short,
59/// the utilization can be lower than with a central queue, since stealing uses
60/// an optimistic strategy and is a best-effort operation. When task runtimes
61/// are long, work stealing can be more efficient than central queueing.
62///
63/// This reduced contention is particularly beneficial in dynamic environments
64/// with significantly fluctuating workloads, enabling faster task completion
65/// as workers can quickly adapt to take on shorter or less complex tasks as
66/// they become available.
67///
68/// [`WorkSharing`]: crate::executor::strategy::WorkSharing
69///
70/// # Examples
71///
72/// ```
73/// # use std::error::Error;
74/// # fn main() -> Result<(), Box<dyn Error>> {
75/// use zrx_executor::strategy::{Strategy, WorkStealing};
76///
77/// // Create strategy and submit task
78/// let strategy = WorkStealing::default();
79/// strategy.submit(Box::new(|| println!("Task")))?;
80/// # Ok(())
81/// # }
82/// ```
pub struct WorkStealing {
    /// Injector for task submission.
    ///
    /// Tasks handed to `submit` are pushed onto this shared queue, from which
    /// worker threads fetch tasks once their local queue is empty.
    injector: Arc<Injector<Box<dyn Task>>>,
    /// Signal for synchronization.
    ///
    /// Notified whenever new tasks are queued, consulted by workers that find
    /// no task to run, and used to request termination when dropping.
    signal: Arc<Signal>,
    /// Join handles of worker threads.
    ///
    /// All handles are joined when the strategy is dropped. The number of
    /// handles is reported as the number of workers.
    threads: Vec<JoinHandle<Result>>,
    /// Counter for running tasks.
    ///
    /// Incremented by a worker right before it executes a task, decremented
    /// right after the task returned (or panicked).
    running: Arc<AtomicUsize>,
    /// Counter for pending tasks.
    ///
    /// Incremented when a task is submitted or when subtasks are queued by a
    /// worker, decremented when a worker picks up a task for execution.
    pending: Arc<AtomicUsize>,
    /// Maximum number of pending tasks.
    ///
    /// Checked on submission - tasks are rejected once the limit is reached.
    capacity: usize,
}
97
98// ----------------------------------------------------------------------------
99// Implementations
100// ----------------------------------------------------------------------------
101
102impl WorkStealing {
103 /// Creates a work-stealing execution strategy.
104 ///
105 /// This method creates a strategy with the given number of worker threads,
106 /// which are spawned immediately before the method returns. Internally, a
107 /// default limit of 8 tasks per worker is used, so for 4 workers, the
108 /// executor will have a capacity of 32 tasks.
109 ///
110 /// Use [`WorkStealing::with_capacity`] to set a custom capacity.
111 ///
112 /// # Panics
113 ///
114 /// Panics if thread creation fails.
115 ///
116 /// # Examples
117 ///
118 /// ```
119 /// use zrx_executor::strategy::WorkStealing;
120 ///
121 /// // Create strategy
122 /// let strategy = WorkStealing::new(4);
123 /// ```
124 #[must_use]
125 pub fn new(num_workers: usize) -> Self {
126 Self::with_capacity(num_workers, 8 * num_workers)
127 }
128
129 /// Creates a work-stealing execution strategy with the given capacity.
130 ///
131 /// This method creates a strategy with the given number of worker threads,
132 /// which are spawned immediately before the method returns.
133 ///
134 /// While this strategy uses unbounded channels due to how the [`Injector`]
135 /// and [`Worker`] concepts are implemented, it still provides a capacity
136 /// limit to ensure the executor doesn't end being overwhelmed. The given
137 /// capacity sets the number of tasks the executor accepts before starting
138 /// to reject them, which can be used to apply backpressure. Note that the
139 /// capacity is not a per-worker, but a global per-executor limit.
140 ///
141 /// # Panics
142 ///
143 /// Panics if thread creation fails.
144 ///
145 /// # Examples
146 ///
147 /// ```
148 /// use zrx_executor::strategy::WorkStealing;
149 ///
150 /// // Create strategy
151 /// let strategy = WorkStealing::with_capacity(4, 64);
152 /// ```
153 #[must_use]
154 pub fn with_capacity(num_workers: usize, capacity: usize) -> Self {
155 let injector = Arc::new(Injector::new());
156 let signal = Arc::new(Signal::new());
157
158 // Create worker queues
159 let mut workers = Vec::with_capacity(num_workers);
160 for _ in 0..num_workers {
161 workers.push(Worker::new_fifo());
162 }
163
164 // Obtain stealers from worker queues - note that we collect stealers
165 // into a slice and not a vector, as we won't change the data after
166 // initializing it, so we can share the stealers among workers without
167 // the need for synchronization.
168 let stealers: Arc<[Stealer<Box<dyn Task>>]> =
169 Arc::from(workers.iter().map(Worker::stealer).collect::<Vec<_>>());
170
171 // Keep track of running and pending tasks
172 let running = Arc::new(AtomicUsize::new(0));
173 let pending = Arc::new(AtomicUsize::new(0));
174
175 // Initialize worker threads
176 let iter = workers.into_iter().enumerate().map(|(index, worker)| {
177 let injector = Arc::clone(&injector);
178 let stealers = Arc::clone(&stealers);
179 let signal = Arc::clone(&signal);
180
181 // Create worker thread and obtain references to injector and
182 // stealers, which we need to retrieve the next task
183 let running = Arc::clone(&running);
184 let pending = Arc::clone(&pending);
185 let h = move || {
186 let injector = injector.as_ref();
187 let stealers = stealers.as_ref();
188
189 // Try to fetch the next task, either from the local queue, or
190 // from the injector or another worker. Additionally, we keep
191 // track of the number of running tasks to provide a simple way
192 // to monitor the load of the thread pool.
193 loop {
194 let Some(task) = get(&worker, injector, stealers) else {
195 // No more tasks, so we wait for the executor to signal
196 // if the worker should continue or terminate. This can
197 // fail due to a poisoned lock, in which case we need
198 // to terminate gracefully as well.
199 if signal.should_terminate()? {
200 break;
201 }
202
203 // Return to waiting for next task
204 continue;
205 };
206
207 // Update number of pending and running tasks
208 pending.fetch_sub(1, Ordering::Acquire);
209 running.fetch_add(1, Ordering::Release);
210
211 // Execute task, but ignore panics, since the executor has
212 // no way of reporting them, and they're printed anyway
213 let subtasks = panic::catch_unwind(|| task.execute())
214 .unwrap_or_default();
215
216 // Update number of running tasks
217 running.fetch_sub(1, Ordering::Acquire);
218
219 // In case the task returned further subtasks, we add them
220 // to the local queue, so they are executed by the current
221 // worker, or can be stolen by another worker in case the
222 // current worker thread is busy
223 if !subtasks.is_empty() {
224 let added = subtasks
225 .into_iter()
226 .map(|subtask| worker.push(subtask))
227 .count();
228
229 // Update number of running and pending tasks, and wake
230 // other workers threads to allow for stealing
231 pending.fetch_add(added, Ordering::Release);
232 signal.notify();
233 }
234 }
235
236 // No errors occurred
237 Ok(())
238 };
239
240 // We deliberately use unwrap here, as the capability to spawn
241 // threads is a fundamental requirement of the executor
242 Builder::new()
243 .name(format!("zrx/executor/{}", index + 1))
244 .spawn(h)
245 .unwrap()
246 });
247
248 // Create worker threads and return strategy
249 let threads = iter.collect();
250 Self {
251 injector,
252 signal,
253 threads,
254 running,
255 pending,
256 capacity,
257 }
258 }
259}
260
261// ----------------------------------------------------------------------------
262// Trait implementations
263// ----------------------------------------------------------------------------
264
265impl Strategy for WorkStealing {
266 /// Submits a task.
267 ///
268 /// This method submits a [`Task`], which is executed by one of the worker
269 /// threads as soon as possible. If a task computes a result, a [`Sender`][]
270 /// can be shared with the task, to send the result back to the caller,
271 /// which can then poll a [`Receiver`][].
272 ///
273 /// Note that tasks are intended to only run once, which is why they are
274 /// consumed. If a task needs to be run multiple times, it must be wrapped
275 /// in a closure that creates a new task each time. This allows for safe
276 /// sharing of state between tasks.
277 ///
278 /// [`Receiver`]: crossbeam::channel::Receiver
279 /// [`Sender`]: crossbeam::channel::Sender
280 ///
281 /// # Errors
282 ///
283 /// If the task cannot be submitted, [`Error::Submit`][] is returned, which
284 /// can only happen if the channel is disconnected or at capacity.
285 ///
286 /// # Examples
287 ///
288 /// ```
289 /// # use std::error::Error;
290 /// # fn main() -> Result<(), Box<dyn Error>> {
291 /// use zrx_executor::strategy::{Strategy, WorkStealing};
292 ///
293 /// // Create strategy and submit task
294 /// let strategy = WorkStealing::default();
295 /// strategy.submit(Box::new(|| println!("Task")))?;
296 /// # Ok(())
297 /// # }
298 /// ```
299 fn submit(&self, task: Box<dyn Task>) -> Result {
300 // As workers can steal tasks from the injector, we must manually track
301 // the number of pending tasks. Note that we must increment the counter
302 // before pushing the task to the injector, to ensure that the count is
303 // accurate at all times, even if the task is stolen immediately.
304 let pending = self.pending.fetch_add(1, Ordering::Release);
305 if pending == self.capacity {
306 // We hit the capacity limit, so we need to back off and reject the
307 // task submission, decrement the counter again and return the task
308 // as part of the error. We deliberately use an optimistic strategy
309 // here, so we don't need to check the counter before incrementing.
310 self.pending.fetch_sub(1, Ordering::Release);
311 return Err(Error::Submit(task));
312 }
313
314 // Submit the task to the injector and wake up waiting worker threads,
315 // so they can steal the task and execute it as soon as possible
316 self.injector.push(task);
317 self.signal.notify();
318
319 // No errors occurred
320 Ok(())
321 }
322
323 /// Returns the number of workers.
324 ///
325 /// # Examples
326 ///
327 /// ```
328 /// use zrx_executor::strategy::{Strategy, WorkStealing};
329 ///
330 /// // Get number of workers
331 /// let strategy = WorkStealing::new(1);
332 /// assert_eq!(strategy.num_workers(), 1);
333 /// ```
334 #[inline]
335 fn num_workers(&self) -> usize {
336 self.threads.len()
337 }
338
339 /// Returns the number of running tasks.
340 ///
341 /// This method allows to monitor the worker load, as it returns how many
342 /// workers are currently actively executing tasks.
343 ///
344 /// # Examples
345 ///
346 /// ```
347 /// use zrx_executor::strategy::{Strategy, WorkStealing};
348 ///
349 /// // Get number of running tasks
350 /// let strategy = WorkStealing::default();
351 /// assert_eq!(strategy.num_tasks_running(), 0);
352 /// ```
353 #[inline]
354 fn num_tasks_running(&self) -> usize {
355 self.running.load(Ordering::Relaxed)
356 }
357
358 /// Returns the number of pending tasks.
359 ///
360 /// This method allows to throttle the submission of tasks, as it returns
361 /// how many tasks are currently waiting to be executed.
362 ///
363 /// # Examples
364 ///
365 /// ```
366 /// use zrx_executor::strategy::{Strategy, WorkStealing};
367 ///
368 /// // Get number of pending tasks
369 /// let strategy = WorkStealing::default();
370 /// assert_eq!(strategy.num_tasks_pending(), 0);
371 /// ```
372 #[inline]
373 fn num_tasks_pending(&self) -> usize {
374 self.pending.load(Ordering::Relaxed)
375 }
376
377 /// Returns the capacity.
378 ///
379 /// The work-stealing execution strategy does not impose a hard limit on
380 /// the number of tasks. Thus, this strategy should only be used if tasks
381 /// are not produced faster than they can be executed, or the number of
382 /// tasks is limited by some other means.
383 ///
384 /// # Examples
385 ///
386 /// ```
387 /// use zrx_executor::strategy::{Strategy, WorkStealing};
388 ///
389 /// // Get capacity
390 /// let strategy = WorkStealing::default();
391 /// assert!(strategy.capacity() >= strategy.num_workers());
392 /// ```
393 #[inline]
394 fn capacity(&self) -> usize {
395 self.capacity
396 }
397}
398
399// ----------------------------------------------------------------------------
400
401impl Default for WorkStealing {
402 /// Creates a work-stealing execution strategy using all CPUs - 1.
403 ///
404 /// The number of workers is determined by the number of logical CPUs minus
405 /// one, which reserves one core for the main thread for orchestration. If
406 /// the number of logical CPUs is fewer than 1, the strategy defaults to a
407 /// single worker thread.
408 ///
409 /// __Warning__: this method makes use of [`thread::available_parallelism`]
410 /// to determine the number of available cores, which has some limitations.
411 /// Please refer to the documentation of that function for more details, or
412 /// consider using [`num_cpus`][] as an alternative.
413 ///
414 /// [`num_cpus`]: https://crates.io/crates/num_cpus
415 ///
416 /// # Examples
417 ///
418 /// ```
419 /// use zrx_executor::strategy::WorkStealing;
420 ///
421 /// // Create strategy
422 /// let strategy = WorkStealing::default();
423 /// ```
424 #[inline]
425 fn default() -> Self {
426 Self::new(cmp::max(
427 thread::available_parallelism()
428 .map(|num| num.get().saturating_sub(1))
429 .unwrap_or(1),
430 1,
431 ))
432 }
433}
434
435impl Drop for WorkStealing {
436 /// Terminates and joins all worker threads.
437 ///
438 /// This method waits for all worker threads to finish executing currently
439 /// running tasks, while ignoring any pending tasks. All worker threads are
440 /// joined before the method returns. This is necessary to prevent worker
441 /// threads from running after the strategy has been dropped.
442 fn drop(&mut self) {
443 let _ = self.signal.terminate();
444
445 // Join all worker threads without panicking on errors
446 for handle in self.threads.drain(..) {
447 let _ = handle.join();
448 }
449 }
450}
451
452// ----------------------------------------------------------------------------
453
454impl Debug for WorkStealing {
455 /// Formats the execution strategy for debugging.
456 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
457 f.debug_struct("WorkStealing")
458 .field("workers", &self.num_workers())
459 .field("running", &self.num_tasks_running())
460 .field("pending", &self.num_tasks_pending())
461 .finish()
462 }
463}
464
465// ----------------------------------------------------------------------------
466// Functions
467// ----------------------------------------------------------------------------
468
469/// Attempts to get the next available task, either from the worker's own queue
470/// or by stealing from the injector or other stealers if needed. Note that this
471/// code was taken almost verbatim from the [`crossbeam`] docs, specifically
472/// from [`crossbeam::deque`](crossbeam::deque#examples), but cut smaller.
473fn get<T>(
474 worker: &Worker<T>, injector: &Injector<T>, stealers: &[Stealer<T>],
475) -> Option<T> {
476 worker
477 .pop()
478 .or_else(|| steal_or_retry(worker, injector, stealers))
479}
480
481/// Repeatedly attempts to steal a task from the injector or stealers until a
482/// non-retryable steal result is found, returning the successful task if any.
483fn steal_or_retry<T>(
484 worker: &Worker<T>, injector: &Injector<T>, stealers: &[Stealer<T>],
485) -> Option<T> {
486 repeat_with(|| steal(worker, injector, stealers))
487 .find(|steal| !steal.is_retry())
488 .and_then(Steal::success)
489}
490
491/// Tries to steal a task from the injector or, if unavailable, from each
492/// stealer in sequence, collecting any valid tasks or signaling retry.
493fn steal<T>(
494 worker: &Worker<T>, injector: &Injector<T>, stealers: &[Stealer<T>],
495) -> Steal<T> {
496 injector
497 .steal_batch_and_pop(worker)
498 .or_else(|| stealers.iter().map(Stealer::steal).collect())
499}