work_queue/lib.rs

1//! A concurrent work-stealing queue for building schedulers.
2//!
3//! # Examples
4//!
5//! Distribute some tasks in a thread pool:
6//!
7//! ```
8//! use work_queue::{Queue, LocalQueue};
9//!
10//! struct Task(Box<dyn Fn(&mut LocalQueue<Task>) + Send>);
11//!
12//! let threads = 4;
13//!
14//! let queue: Queue<Task> = Queue::new(threads, 128);
15//!
16//! // Push some tasks to the queue.
17//! for _ in 0..500 {
18//!     queue.push(Task(Box::new(|local| {
19//!         do_work();
20//!
21//!         local.push(Task(Box::new(|_| do_work())));
22//!         local.push(Task(Box::new(|_| do_work())));
23//!     })));
24//! }
25//!
26//! // Spawn threads to complete the tasks.
27//! let handles: Vec<_> = queue
28//!     .local_queues()
29//!     .map(|mut local_queue| {
30//!         std::thread::spawn(move || {
31//!             while let Some(task) = local_queue.pop() {
32//!                 task.0(&mut local_queue);
33//!             }
34//!         })
35//!     })
36//!     .collect();
37//!
38//! for handle in handles {
39//!     handle.join().unwrap();
40//! }
41//! # fn do_work() {}
42//! ```
43//!
44//! # Comparison with crossbeam-deque
45//!
46//! This crate is similar in purpose to [`crossbeam-deque`](https://docs.rs/crossbeam-deque), which
47//! also provides concurrent work-stealing queues. However, there are a few notable differences:
48//!
49//! - This crate is higher-level - work stealing is done automatically when calling `pop`,
50//!   instead of you having to call it manually.
51//! - As such, we do not support as much customization as `crossbeam-deque` - but the algorithm
52//!   itself can be better optimized.
53//! - Each queue supports a fixed number of local queues, and this number cannot grow.
54//! - Each local queue has a fixed capacity, unlike `crossbeam-deque`, which supports local queue
55//!   growth. This makes our local queues faster.
56//!
57//! # Implementation
58//!
59//! This crate's queue implementation is based on [Tokio's current scheduler]. The idea is that
60//! each thread holds a fixed-capacity local queue, and there is also an unbounded global queue
61//! accessible by all threads. In the common case each worker thread only interacts with its
62//! local queue, avoiding most synchronization - but if one worker thread ends up with much
63//! less work than another, the work is spread out evenly by work stealing.
64//!
65//! Additionally, each local queue stores a [non-stealable LIFO slot] to optimize for message
66//! passing patterns, so that if one task creates another, that created task will be polled
67//! immediately, instead of only much later when it reaches the front of the local queue.
68//!
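//! For example, here is a minimal illustration of the LIFO slot (using a single local queue):
//! the most recently pushed item comes back first.
//!
//! ```
//! use work_queue::Queue;
//!
//! let queue: Queue<i32> = Queue::new(1, 4);
//! let mut local = queue.local_queues().next().unwrap();
//!
//! local.push(1);
//! local.push(2);
//!
//! // The LIFO slot hands back the most recently pushed item first.
//! assert_eq!(local.pop(), Some(2));
//! assert_eq!(local.pop(), Some(1));
//! assert_eq!(local.pop(), None);
//! ```
//!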
69//! [Tokio's current scheduler]: https://tokio.rs/blog/2019-10-scheduler
70//! [non-stealable LIFO slot]: https://tokio.rs/blog/2019-10-scheduler#optimizing-for-message-passing-patterns
71#![warn(missing_debug_implementations, rust_2018_idioms, missing_docs)]
72
73use std::cmp;
74use std::collections::hash_map::DefaultHasher;
75use std::fmt::{self, Debug, Formatter};
76use std::hash::Hasher;
77use std::iter::FusedIterator;
78use std::mem::{self, MaybeUninit};
79use std::ops::Deref;
80use std::ptr::{self, NonNull};
81#[cfg(not(loom))]
82use std::{collections::hash_map::RandomState, hash::BuildHasher};
83
84#[cfg_attr(loom, path = "loom.rs")]
85#[cfg_attr(not(loom), path = "std.rs")]
86mod facade;
87use facade::atomic::{self, AtomicBool, AtomicU16, AtomicU32, AtomicUsize};
88use facade::Arc;
89use facade::GlobalQueue;
90use facade::UnsafeCell;
91
92/// A work queue.
93///
94/// This implements [`Clone`] and so multiple handles to the queue can be easily created and
95/// shared.
96#[derive(Debug)]
97pub struct Queue<T>(Arc<Shared<T>>);
98
99impl<T> Queue<T> {
100    /// Create a new work queue.
101    ///
102    /// `local_queues` is the number of [`LocalQueue`]s yielded by [`Self::local_queues`]. Typically
103    /// you will have one local queue for each thread in a thread pool.
104    ///
105    /// `local_queue_size` is the number of items that can be stored in each local queue before it
106    /// overflows into the global one. You should fine-tune this to your needs.
107    ///
108    /// # Panics
109    ///
110    /// This will panic if the local queue size is not a power of two.
111    ///
112    /// # Examples
113    ///
114    /// ```
115    /// use work_queue::Queue;
116    ///
117    /// let threads = 4;
118    /// let queue: Queue<i32> = Queue::new(threads, 512);
119    /// ```
120    pub fn new(local_queues: usize, local_queue_size: u16) -> Self {
121        assert_eq!(
122            local_queue_size.count_ones(),
123            1,
124            "Queue size is not a power of two"
125        );
126        let mask = local_queue_size - 1;
127
128        Self(Arc::new(Shared {
129            local_queues: (0..local_queues)
130                .map(|_| LocalQueueInner {
131                    heads: AtomicU32::new(0),
132                    tail: AtomicU16::new(0),
133                    mask,
134                    items: (0..local_queue_size)
135                        .map(|_| UnsafeCell::new(MaybeUninit::uninit()))
136                        .collect(),
137                })
138                .collect(),
139            global_queue: GlobalQueue::new(),
140            stealing_global: AtomicBool::new(false),
141            taken_local_queues: AtomicBool::new(false),
142            searchers: AtomicUsize::new(0),
143        }))
144    }
145
146    /// Push an item to the global queue. When one of the local queues empties, it can pick this
147    /// item up.
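    ///
    /// # Examples
    ///
    /// A minimal example of an item pushed to the global queue being picked up by a local queue:
    ///
    /// ```
    /// use work_queue::Queue;
    ///
    /// let queue: Queue<i32> = Queue::new(1, 64);
    /// queue.push(42);
    ///
    /// let mut local = queue.local_queues().next().unwrap();
    /// assert_eq!(local.pop(), Some(42));
    /// ```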
148    pub fn push(&self, item: T) {
149        let _ = self.0.global_queue.push(item);
150    }
151
152    /// Iterate over the local queues of this queue.
153    ///
154    /// # Panics
155    ///
156    /// This will panic if called more than once.
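    ///
    /// # Examples
    ///
    /// ```
    /// use work_queue::Queue;
    ///
    /// let queue: Queue<i32> = Queue::new(4, 128);
    ///
    /// // The iterator yields exactly as many local queues as were requested in `Queue::new`.
    /// let locals: Vec<_> = queue.local_queues().collect();
    /// assert_eq!(locals.len(), 4);
    /// ```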
157    pub fn local_queues(&self) -> LocalQueues<'_, T> {
158        assert!(!self
159            .0
160            .taken_local_queues
161            .swap(true, atomic::Ordering::Relaxed));
162
163        LocalQueues {
164            shared: self,
165            index: 0,
166            #[cfg(not(loom))]
167            hasher: RandomState::new().build_hasher(),
168            #[cfg(loom)]
169            hasher: DefaultHasher::new(),
170        }
171    }
172
173    /// Get the number of threads that are currently searching for work inside [`pop`](LocalQueue::pop).
174    ///
175    /// If this number is too high, you may wish to avoid calling [`pop`](LocalQueue::pop) to
176    /// reduce contention.
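    ///
    /// # Examples
    ///
    /// ```
    /// use work_queue::Queue;
    ///
    /// let queue: Queue<i32> = Queue::new(2, 64);
    ///
    /// // No local queue is currently popping, so nothing is searching for work.
    /// assert_eq!(queue.searchers(), 0);
    /// ```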
177    #[must_use]
178    pub fn searchers(&self) -> usize {
179        self.0.searchers.load(atomic::Ordering::Relaxed)
180    }
181}
182
183impl<T> Clone for Queue<T> {
184    fn clone(&self) -> Self {
185        Self(Arc::clone(&self.0))
186    }
187}
188
189#[derive(Debug)]
190struct Shared<T> {
191    local_queues: Box<[LocalQueueInner<T>]>,
192    global_queue: GlobalQueue<T>,
193    /// Whether a thread is currently stealing from the global queue. When `true`, threads
194    /// should avoid trying to pop from it to reduce contention.
195    stealing_global: AtomicBool,
196    /// Whether the local queues have already been yielded to the user and so shouldn't be yielded
197    /// again.
198    taken_local_queues: AtomicBool,
199    /// The number of queues searching for work.
200    searchers: AtomicUsize,
201}
202
203/// The fixed-capacity SP2C queue owned by each local queue.
204struct LocalQueueInner<T> {
205    /// The two heads (fronts) of the queue, packed into one atomic by `pack_heads` and
206    /// `unpack_heads`.
207    ///
208    /// The first head, the "stealer" head, always lags behind the second head, the "real" head.
209    /// Items are popped starting from the real head, but the space between the two heads still
210    /// cannot be overwritten by the tail, as it's being read by a stealer.
211    heads: AtomicU32,
212
213    /// The back of the queue. Only incremented by the associated queue.
214    tail: AtomicU16,
215
216    /// Bitmask applied to the head and tail to obtain the actual indices, so that the atomics can
217    /// be incremented and freely overflow outside of the range of the queue itself.
218    mask: u16,
219
220    /// The actual items in the queue.
221    items: Box<[UnsafeCell<MaybeUninit<T>>]>,
222}
223
224unsafe impl<T: Send> Sync for LocalQueueInner<T> {}
225
226impl<T> Debug for LocalQueueInner<T> {
227    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
228        let (protected_head, head) = unpack_heads(self.heads.load(atomic::Ordering::Acquire));
229
230        f.debug_struct("LocalQueueInner")
231            .field("protected_head", &protected_head)
232            .field("head", &head)
233            .field("tail", &self.tail)
234            .field("mask", &format_args!("{:#b}", self.mask))
235            .finish()
236    }
237}
238
239/// Unpack the `heads` value in a `LocalQueueInner`. Returns a tuple of the stealer head and the
240/// real head.
241fn unpack_heads(heads: u32) -> (u16, u16) {
242    ((heads >> 16) as u16, heads as u16)
243}
244/// Pack the `heads` value in a `LocalQueueInner` from its stealer head and real head.
245fn pack_heads(stealer: u16, real: u16) -> u32 {
246    (stealer as u32) << 16 | real as u32
247}
248
249/// One of the local queues in a [`Queue`].
250///
251/// You can create this using [`Queue::local_queues`].
252#[derive(Debug)]
253pub struct LocalQueue<T> {
254    /// Special slot that is always popped from first, to optimize for message passing where one
255    /// task is blocked on another.
256    lifo_slot: Option<T>,
257    local: ValidPtr<LocalQueueInner<T>>,
258    shared: Queue<T>,
259    /// Random number generator used to find which queue to start work stealing from.
260    rng: Rng,
261}
262
263impl<T> LocalQueue<T> {
264    /// Load the tail of the local queue.
265    fn local_tail(&mut self) -> u16 {
266        // SAFETY: The tail can be loaded without synchronization because only `self` can write to
267        // it, and we have an `&mut self`.
268        unsafe { facade::atomic_u16_unsync_load(&self.local.tail) }
269    }
270
271    /// Push an item to the local queue. If the local queue is full, it will move half of its items
272    /// to the global queue.
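    ///
    /// # Examples
    ///
    /// A small sketch of the overflow behaviour: pushing more items than the local queue can hold
    /// spills the excess into the global queue, so nothing is lost (though the pop order is
    /// unspecified).
    ///
    /// ```
    /// use work_queue::Queue;
    ///
    /// let queue: Queue<i32> = Queue::new(1, 2);
    /// let mut local = queue.local_queues().next().unwrap();
    ///
    /// for i in 0..8 {
    ///     local.push(i);
    /// }
    ///
    /// // Items may come back in any order, so sort before comparing.
    /// let mut popped: Vec<i32> = std::iter::from_fn(|| local.pop()).collect();
    /// popped.sort();
    /// assert_eq!(popped, (0..8).collect::<Vec<i32>>());
    /// ```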
273    pub fn push(&mut self, item: T) {
274        if let Some(previous) = self.lifo_slot.replace(item) {
275            self.push_yield(previous);
276        }
277    }
278
279    /// Push an item to the local queue, skipping the LIFO slot. This can be used to give other
280    /// tasks a chance to run. Otherwise, there's a risk that one task will completely take over a
281    /// thread in a push-pop cycle due to the LIFO slot.
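    ///
    /// # Examples
    ///
    /// ```
    /// use work_queue::Queue;
    ///
    /// let queue: Queue<i32> = Queue::new(1, 4);
    /// let mut local = queue.local_queues().next().unwrap();
    ///
    /// // `push_yield` skips the LIFO slot, so these items come back in FIFO order.
    /// local.push_yield(1);
    /// local.push_yield(2);
    /// assert_eq!(local.pop(), Some(1));
    /// assert_eq!(local.pop(), Some(2));
    /// ```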
282    pub fn push_yield(&mut self, item: T) {
283        let tail = self.local_tail();
284
285        // We have to use Acquire to make sure that we don't write into memory that is
286        // currently being read by work stealers.
287        let mut heads = self.local.heads.load(atomic::Ordering::Acquire);
288
289        loop {
290            let (steal_head, head) = unpack_heads(heads);
291
292            // If the local queue is not full, we can simply push to it.
293            if tail.wrapping_sub(steal_head) < self.local.items.len() as u16 {
294                let i = tail & self.local.mask;
295
296                self.local.items[usize::from(i)]
297                    .with_mut(|slot| unsafe { slot.write(MaybeUninit::new(item)) });
298
299                // Release is necessary to make sure the above write of the item is visible to
300                // any thread that reads the new tail.
301                self.local
302                    .tail
303                    .store(tail.wrapping_add(1), atomic::Ordering::Release);
304
305                return;
306            }
307
308            // If no threads are currently stealing, our overflowing local queue will not be
309            // drained, so we should push half of it to the global queue.
310            //
311            // Otherwise (when threads are stealing) we don't want to wait for them to finish, so
312            // we just push this single item to the global queue (but we don't need to push any
313            // more since we're about to become less full).
314            if steal_head == head {
315                let half = self.local.items.len() as u16 / 2;
316                // TODO: We could use compare_exchange_weak here, which may potentially improve
317                // performance.
318                let res = self.local.heads.compare_exchange(
319                    heads,
320                    pack_heads(head.wrapping_add(half), head.wrapping_add(half)),
321                    // Release is necessary to ensure that any previous writes to `tail` become
322                    // visible after this point. If not made visible, we could get into a situation
323                    // where a thread reads an older value for `tail` than `heads`, and if `heads`
324                    // has advanced beyond the old `tail` by that point it causes all sorts of
325                    // issues.
326                    atomic::Ordering::AcqRel,
327                    // Acquire is necessary because on failure we use the new value to update the
328                    // head (see the Acquire ordering above).
329                    atomic::Ordering::Acquire,
330                );
331
332                // Moving the head failed because another thread has just stolen some items. This
333                // means the queue is less full, so we can retry pushing to the local queue.
334                if let Err(new_heads) = res {
335                    heads = new_heads;
336                    continue;
337                }
338
339                // Push half the items in the current queue to the global queue.
340                for i in 0..half {
341                    let index = head.wrapping_add(i) & self.local.mask;
342                    let item = unsafe {
343                        self.local.items[usize::from(index)]
344                            .with(|slot| slot.read())
345                            .assume_init()
346                    };
347                    let _ = self.shared.0.global_queue.push(item);
348                }
349            }
350
351            let _ = self.shared.0.global_queue.push(item);
352
353            return;
354        }
355    }
356
357    /// Pop an item from the local queue, or steal from the global and sibling queues if it is
358    /// empty.
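    ///
    /// # Examples
    ///
    /// ```
    /// use work_queue::Queue;
    ///
    /// let queue: Queue<i32> = Queue::new(2, 64);
    /// let mut locals: Vec<_> = queue.local_queues().collect();
    ///
    /// // This item is never pushed to `locals[0]` directly, but `pop` finds it in the
    /// // global queue.
    /// queue.push(1);
    /// assert_eq!(locals[0].pop(), Some(1));
    /// assert_eq!(locals[1].pop(), None);
    /// ```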
359    pub fn pop(&mut self) -> Option<T> {
360        // First try to pop from the LIFO slot.
361        if let Some(item) = self.lifo_slot.take() {
362            return Some(item);
363        }
364
365        let tail = self.local_tail();
366
367        // Next, try to pop from the local queue.
368        let res = atomic_u32_fetch_update(
369            &self.local.heads,
370            // No memory orderings are necessary here as this is the only thread that mutates
371            // the data, and it's not currently mutating the data.
372            atomic::Ordering::Relaxed,
373            atomic::Ordering::Relaxed,
374            |heads| {
375                let (steal_head, head) = unpack_heads(heads);
376                if head == tail {
377                    None
378                } else if steal_head == head {
379                    // There are no current stealers; update both heads.
380                    Some(pack_heads(head.wrapping_add(1), head.wrapping_add(1)))
381                } else {
382                    // There is currently a stealer; only update the real head, as it's the
383                    // stealer's job to update the stealer head later.
384                    Some(pack_heads(steal_head, head.wrapping_add(1)))
385                }
386            },
387        );
388
389        let heads = match res {
390            // We have successfully popped something from the local queue.
391            Ok(heads) => {
392                let (_, head) = unpack_heads(heads);
393                let i = head & self.local.mask;
394                return Some(unsafe {
395                    self.local.items[usize::from(i)]
396                        .with(|ptr| ptr.read())
397                        .assume_init()
398                });
399            }
400            // The local queue is empty.
401            Err(heads) => heads,
402        };
403        let (steal_head, head) = unpack_heads(heads);
404        assert_eq!(head, tail);
405
406        // The number of free slots in the queue we can steal into.
407        let space = self.local.items.len() as u16 - head.wrapping_sub(steal_head);
408
409        // Now we will try to steal into this queue from various places.
410
411        // TODO: Potentially throttle stealing?
412        self.shared
413            .0
414            .searchers
415            // No ordering is necessary because we use this as a hint, not for safety.
416            .fetch_add(1, atomic::Ordering::Relaxed);
417
418        struct DecrementSearchers<'a>(&'a AtomicUsize);
419        impl Drop for DecrementSearchers<'_> {
420            fn drop(&mut self) {
421                self.0.fetch_sub(1, atomic::Ordering::Relaxed);
422            }
423        }
424        let _decrement_searchers = DecrementSearchers(&self.shared.0.searchers);
425
426        // If there are no threads currently stealing from the global queue, we will steal from it.
427        if self
428            .shared
429            .0
430            .stealing_global
431            .compare_exchange(
432                false,
433                true,
434                // No ordering is necessary because we use this as a hint, not for safety.
435                atomic::Ordering::Relaxed,
436                atomic::Ordering::Relaxed,
437            )
438            .is_ok()
439        {
440            let popped_item = self.shared.0.global_queue.pop();
441
442            if popped_item.is_some() {
443                // To avoid having to search for items again after we have completed this one, we
444                // fill half of our queue with items from the global queue.
445
446                let steal = cmp::min(self.local.items.len() as u16 / 2, space);
447                let mut tail = head;
448                let end_tail = head.wrapping_add(steal);
449
450                // Ensure that the following mutations of the local queue's items will not occur
451                // before stealers have finished reading.
452                u32_acquire_fence(&self.local.heads);
453
454                while tail != end_tail {
455                    match self.shared.0.global_queue.pop() {
456                        Some(item) => {
457                            let i = tail & self.local.mask;
458                            self.local.items[usize::from(i)]
459                                .with_mut(|slot| unsafe { slot.write(MaybeUninit::new(item)) });
460                        }
461                        None => break,
462                    }
463                    tail = tail.wrapping_add(1);
464                }
465                // Release is necessary to make sure the above writes are visible to any thread
466                // that reads the new tail.
467                self.local.tail.store(tail, atomic::Ordering::Release);
468            }
469
470            self.shared
471                .0
472                .stealing_global
473                .store(false, atomic::Ordering::Relaxed);
474
475            if let Some(popped_item) = popped_item {
476                return Some(popped_item);
477            }
478        }
479
480        // Steal work from sibling queues starting from a random location.
481        let queues = self.shared.0.local_queues.len();
482        let start = self.rng.gen_usize_to(queues);
483
484        'sibling_queues: for i in 0..queues {
485            let mut i = start + i;
486            if i >= queues {
487                i -= queues;
488            }
489
490            let queue = &self.shared.0.local_queues[i];
491            if ptr::eq(queue, &*self.local) {
492                continue;
493            }
494
495            // Acquire is necessary to make sure that the below load of the tail does not occur
496            // before the load of the head.
497            let mut queue_heads = queue.heads.load(atomic::Ordering::Acquire);
498
499            let (old_queue_head, mut queue_head, steal) = loop {
500                let (queue_steal_head, queue_head) = unpack_heads(queue_heads);
501
502                // If another thread is already stealing from this queue, don't steal from it.
503                if queue_steal_head != queue_head {
504                    continue 'sibling_queues;
505                }
506
507                // Acquire is necessary so we don't read items that are currently being
508                // written by the owning thread.
509                let queue_tail = queue.tail.load(atomic::Ordering::Acquire);
510
511                // The number of items that can be stolen.
512                let stealable = queue_tail.wrapping_sub(queue_head);
513
514                // The number of items we actually want to steal - this is half of their queue,
515                // rounded up.
516                let steal = cmp::min(stealable - stealable / 2, space);
517
518                if steal == 0 {
519                    continue 'sibling_queues;
520                }
521
522                let new_queue_head = queue_head.wrapping_add(steal);
523
524                // TODO: We could use compare_exchange here, which may potentially improve
525                // performance.
526                let res = queue.heads.compare_exchange_weak(
527                    queue_heads,
528                    // Only move the real head, as we still need to keep the steal head to read
529                    // from the queue.
530                    pack_heads(queue_head, new_queue_head),
531                    // Release isn't necessary here since we haven't written to any shared memory yet.
532                    atomic::Ordering::Acquire,
533                    // Acquire is necessary when the compare_exchange fails because the result is
534                    // used to update the values in `queue_heads`; see the Acquire above.
535                    atomic::Ordering::Acquire,
536                );
537
538                match res {
539                    Ok(_) => break (queue_head, new_queue_head, steal),
540                    Err(updated_queue_heads) => queue_heads = updated_queue_heads,
541                }
542            };
543
544            assert_ne!(steal, 0);
545
546            // Ensure that the following mutations of the local queue's items will not occur
547            // before stealers of our own queue have finished reading.
548            u32_acquire_fence(&self.local.heads);
549
550            // Read the first item separately, as we will be returning it.
551            let first_item = unsafe {
552                queue.items[usize::from(old_queue_head & queue.mask)]
553                    .with(|slot| slot.read())
554                    .assume_init()
555            };
556
557            // Copy over the stolen items to our queue.
558            for i in 1..steal {
559                let src = &queue.items[usize::from(old_queue_head.wrapping_add(i) & queue.mask)];
560                let dst =
561                    &self.local.items[usize::from(head.wrapping_add(i - 1) & self.local.mask)];
562
563                src.with(|src| dst.with_mut(|dst| unsafe { src.copy_to_nonoverlapping(dst, 1) }))
564            }
565
566            // Update the steal head to match the real head.
567            loop {
568                let res = queue.heads.compare_exchange_weak(
569                    pack_heads(old_queue_head, queue_head),
570                    pack_heads(queue_head, queue_head),
571                    // Release is necessary to make sure the above reads are ordered before any
572                    // other thread can write to the values.
573                    atomic::Ordering::Release,
574                    // No ordering is necessary because we're not accessing shared mutable state
575                    // after this point.
576                    atomic::Ordering::Relaxed,
577                );
578
579                match res {
580                    Ok(_) => break,
581                    Err(updated_queue_heads) => {
582                        let (updated_queue_steal_head, updated_queue_head) =
583                            unpack_heads(updated_queue_heads);
584                        assert_eq!(updated_queue_steal_head, old_queue_head);
585                        queue_head = updated_queue_head;
586                    }
587                }
588            }
589
590            if steal > 1 {
591                // Release is necessary to make sure the above writes are ordered before accessing
592                // values.
593                self.local
594                    .tail
595                    .store(tail.wrapping_add(steal - 1), atomic::Ordering::Release);
596            }
597
598            return Some(first_item);
599        }
600
601        // Lastly, pop from the global queue without guarding against contention, since there is
602        // nowhere else we can currently get items from.
603        self.shared.0.global_queue.pop()
604    }
605
606    /// Get the number of threads that are currently searching for work inside [`pop`](Self::pop).
607    ///
608    /// If this number is too high, you may wish to avoid calling [`pop`](Self::pop) to reduce
609    /// contention.
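    ///
    /// # Examples
    ///
    /// ```
    /// use work_queue::Queue;
    ///
    /// let queue: Queue<i32> = Queue::new(1, 8);
    /// let local = queue.local_queues().next().unwrap();
    ///
    /// // This reports the same counter as the owning `Queue`.
    /// assert_eq!(local.searchers(), queue.searchers());
    /// ```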
610    #[must_use]
611    pub fn searchers(&self) -> usize {
612        self.shared.searchers()
613    }
614
615    /// Get the global queue that is associated with this local queue.
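    ///
    /// # Examples
    ///
    /// ```
    /// use work_queue::Queue;
    ///
    /// let queue: Queue<i32> = Queue::new(1, 8);
    /// let mut local = queue.local_queues().next().unwrap();
    ///
    /// // Items pushed through the global handle are visible to this local queue.
    /// local.global().push(3);
    /// assert_eq!(local.pop(), Some(3));
    /// ```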
616    #[must_use]
617    pub fn global(&self) -> &Queue<T> {
618        &self.shared
619    }
620}
621
622/// An iterator over the [`LocalQueue`]s in a [`Queue`]. Created by [`Queue::local_queues`].
623#[derive(Debug)]
624#[must_use = "iterators are lazy and do nothing unless consumed"]
625pub struct LocalQueues<'a, T> {
626    shared: &'a Queue<T>,
627    index: usize,
628    hasher: DefaultHasher,
629}
630
631impl<T> Iterator for LocalQueues<'_, T> {
632    type Item = LocalQueue<T>;
633
634    fn next(&mut self) -> Option<Self::Item> {
635        let inner = self.shared.0.local_queues.get(self.index)?;
636        self.index += 1;
637
638        Some(LocalQueue {
639            lifo_slot: None,
640            // SAFETY: The `LocalQueue` stores an `Arc`, so this pointer is guaranteed to remain
641            // valid for as long as the `LocalQueue` exists.
642            local: unsafe { ValidPtr::new(inner) },
643            shared: self.shared.clone(),
644            rng: Rng {
645                state: {
646                    self.hasher.write_usize(self.index);
647                    self.hasher.finish()
648                },
649            },
650        })
651    }
652    fn size_hint(&self) -> (usize, Option<usize>) {
653        let len = self.len();
654        (len, Some(len))
655    }
656}
657
658impl<T> ExactSizeIterator for LocalQueues<'_, T> {
659    fn len(&self) -> usize {
660        self.shared.0.local_queues.len() - self.index
661    }
662}
663
664impl<T> FusedIterator for LocalQueues<'_, T> {}
665
666/// A `*const T` that is guaranteed to always be valid and non-null.
667struct ValidPtr<T: ?Sized>(NonNull<T>);
668impl<T: ?Sized> ValidPtr<T> {
669    unsafe fn new(ptr: *const T) -> Self {
670        Self(NonNull::new_unchecked(ptr as *mut T))
671    }
672}
673impl<T: ?Sized> Clone for ValidPtr<T> {
674    fn clone(&self) -> Self {
675        *self
676    }
677}
678impl<T: ?Sized> Copy for ValidPtr<T> {}
679impl<T: ?Sized> Deref for ValidPtr<T> {
680    type Target = T;
681    fn deref(&self) -> &Self::Target {
682        unsafe { self.0.as_ref() }
683    }
684}
685impl<T: ?Sized + Debug> Debug for ValidPtr<T> {
686    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
687        T::fmt(self, f)
688    }
689}
690unsafe impl<T: ?Sized + Sync> Send for ValidPtr<T> {}
691unsafe impl<T: ?Sized + Sync> Sync for ValidPtr<T> {}
692
693#[cfg(target_pointer_width = "64")]
694type DoubleUsize = u128;
695#[cfg(target_pointer_width = "32")]
696type DoubleUsize = u64;
697
698/// Wyrand RNG.
699#[derive(Debug)]
700struct Rng {
701    state: u64,
702}
703impl Rng {
704    fn gen_u64(&mut self) -> u64 {
705        self.state = self.state.wrapping_add(0xA0761D6478BD642F);
706        let t = u128::from(self.state) * u128::from(self.state ^ 0xE7037ED1A0B428DB);
707        (t >> 64) as u64 ^ t as u64
708    }
709    fn gen_usize(&mut self) -> usize {
710        self.gen_u64() as usize
711    }
712    fn gen_usize_to(&mut self, to: usize) -> usize {
713        // Adapted from https://www.pcg-random.org/posts/bounded-rands.html
714        const USIZE_BITS: usize = mem::size_of::<usize>() * 8;
715
716        let mut x = self.gen_usize();
717        let mut m = ((x as DoubleUsize * to as DoubleUsize) >> USIZE_BITS) as usize;
718        let mut l = x.wrapping_mul(to);
719        if l < to {
720            let t = to.wrapping_neg() % to;
721            while l < t {
722                x = self.gen_usize();
723                m = ((x as DoubleUsize * to as DoubleUsize) >> USIZE_BITS) as usize;
724                l = x.wrapping_mul(to);
725            }
726        }
727        m
728    }
729}
730
731fn atomic_u32_fetch_update<F>(
732    atomic: &AtomicU32,
733    set_order: atomic::Ordering,
734    fetch_order: atomic::Ordering,
735    mut f: F,
736) -> Result<u32, u32>
737where
738    F: FnMut(u32) -> Option<u32>,
739{
740    let mut prev = atomic.load(fetch_order);
741    while let Some(next) = f(prev) {
742        match atomic.compare_exchange_weak(prev, next, set_order, fetch_order) {
743            Ok(x) => return Ok(x),
744            Err(next_prev) => prev = next_prev,
745        }
746    }
747    Err(prev)
748}
749
750fn u32_acquire_fence(atomic: &AtomicU32) {
751    if cfg!(tsan) {
752        // ThreadSanitizer doesn't support fences.
753        atomic.load(atomic::Ordering::Acquire);
754    } else {
755        atomic::fence(atomic::Ordering::Acquire);
756    }
757}
758
759#[cfg(all(test, not(loom)))]
760mod tests {
761    use super::*;
762
763    use std::collections::HashSet;
764
765    #[test]
766    fn rng() {
767        let mut rng = Rng { state: 3493858 };
768
769        let mut remaining: HashSet<_> = (0..15).collect();
770
771        while !remaining.is_empty() {
772            let value = rng.gen_usize_to(15);
773            assert!(value < 15, "{} is not less than 15!", value);
774            remaining.remove(&value);
775        }
776    }
777
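    #[test]
    fn pack_unpack_heads() {
        // Sanity check that packing the two heads into a `u32` and unpacking them
        // again is lossless.
        assert_eq!(unpack_heads(pack_heads(3, 7)), (3, 7));
        assert_eq!(unpack_heads(pack_heads(0, u16::MAX)), (0, u16::MAX));
        assert_eq!(unpack_heads(pack_heads(u16::MAX, 0)), (u16::MAX, 0));
    }
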
778    #[test]
779    fn lifo_slot() {
780        let queue = Queue::new(1, 2);
781        let mut local = queue.local_queues().next().unwrap();
782
783        assert_eq!(local.pop(), None);
784        assert_eq!(local.pop(), None);
785
786        local.push(Box::new(5));
787        assert_eq!(local.pop(), Some(Box::new(5)));
788        assert_eq!(local.pop(), None);
789    }
790
791    #[test]
792    fn push_many() {
793        let queue = Queue::new(1, 2);
794        let mut local = queue.local_queues().next().unwrap();
795
796        for i in 0..4 {
797            local.push(Box::new(i));
798        }
799        assert_eq!(local.pop(), Some(Box::new(3)));
800        assert_eq!(local.pop(), Some(Box::new(1)));
801        assert_eq!(local.pop(), Some(Box::new(0)));
802        assert_eq!(local.pop(), Some(Box::new(2)));
803        assert_eq!(local.pop(), None);
804    }
805
806    #[test]
807    fn wrapping() {
808        let queue = Queue::new(1, 2);
809        let mut local = queue.local_queues().next().unwrap();
810
811        local.push_yield(Box::new(0));
812
813        for i in 0..10 {
814            local.push_yield(Box::new(i + 1));
815
816            assert_eq!(local.pop(), Some(Box::new(i)));
817        }
818
819        assert_eq!(local.pop(), Some(Box::new(10)));
820        assert_eq!(local.pop(), None);
821        assert_eq!(local.pop(), None);
822    }
823
824    #[test]
825    fn steal_global() {
826        for &size in &[2, 4, 8, 16, 32, 64] {
827            let queue = Queue::new(4, size);
828
829            for i in 0..16 {
830                queue.push(Box::new(i));
831            }
832
833            let mut local = queue.local_queues().next().unwrap();
834
835            for i in 0..16 {
836                assert_eq!(local.pop(), Some(Box::new(i)));
837            }
838
839            assert_eq!(local.pop(), None);
840        }
841    }
842
843    #[test]
844    fn steal_siblings() {
845        let queue = Queue::new(2, 64);
846
847        let mut locals: Vec<_> = queue.local_queues().collect();
848
849        locals[0].push_yield(Box::new(4));
850        locals[0].push_yield(Box::new(5));
851
852        locals[1].push(Box::new(1));
853        locals[1].push(Box::new(0));
854
855        queue.push(Box::new(2));
856        queue.push(Box::new(3));
857
858        for i in 0..6 {
859            assert_eq!(locals[1].pop(), Some(Box::new(i)));
860        }
861    }
862
863    #[test]
864    fn many_locals() {
865        let queue = <Queue<()>>::new(10, 128);
866        queue.local_queues().for_each(drop);
867    }
868
869    #[test]
870    fn searchers() {
871        let queue = Queue::new(2, 64);
872        let mut locals = queue.local_queues();
873        let mut local_a = locals.next().unwrap();
874        let mut local_b = locals.next().unwrap();
875
876        assert_eq!(local_a.searchers(), 0);
877        assert_eq!(local_b.searchers(), 0);
878
879        local_a.push(());
880        local_a.push(());
881        local_a.pop().unwrap();
882        local_a.pop().unwrap();
883        queue.push(());
884        local_b.pop().unwrap();
885        assert!(local_b.pop().is_none());
886
887        assert_eq!(local_a.searchers(), 0);
888        assert_eq!(local_b.searchers(), 0);
889
890        // This test hangs on Miri.
891        if cfg!(not(miri)) {
892            let stop = Arc::new(AtomicBool::new(false));
893
894            let handle = std::thread::spawn({
895                let stop = Arc::clone(&stop);
896                move || {
897                    while !stop.load(atomic::Ordering::Relaxed) {
898                        local_b.pop();
899                    }
900                }
901            });
902
903            loop {
904                let searchers = local_a.searchers();
905                assert!(searchers < 2);
906                if searchers == 1 {
907                    break;
908                }
909            }
910
911            stop.store(true, atomic::Ordering::Relaxed);
912            handle.join().unwrap();
913        }
914    }
915
916    #[test]
917    fn stress() {
918        let queue = Queue::new(4, 4);
919
920        if cfg!(miri) {
921            for _ in 0..3 {
922                queue.push(4);
923            }
924        } else {
925            for _ in 0..32 {
926                queue.push(6);
927            }
928        }
929
930        let threads: Vec<_> = queue
931            .local_queues()
932            .map(|mut queue| {
933                std::thread::spawn(move || {
934                    while let Some(num) = queue.pop() {
935                        for _ in 0..num {
936                            queue.push(num - 1);
937                        }
938                    }
939                })
940            })
941            .collect();
942
943        for thread in threads {
944            thread.join().unwrap();
945        }
946    }
947
948    #[test]
949    fn cobb() {
950        use std::cell::UnsafeCell;
951
952        struct State(Option<Box<[UnsafeCell<LocalQueue<Box<i32>>>]>>);
953        unsafe impl Sync for State {}
954
955        cobb::run_test(cobb::TestCfg {
956            threads: 4,
957            iterations: if cfg!(miri) { 100 } else { 1000 },
958            sub_iterations: if cfg!(miri) { 1 } else { 10 },
959            setup: || {
960                let queue = Queue::new(4, 4);
961                State(Some(
962                    queue
963                        .local_queues()
964                        .map(UnsafeCell::new)
965                        .collect::<Box<[_]>>(),
966                ))
967            },
968            test: |State(state), tctx| {
969                let local_queues = state.as_ref().unwrap();
970                let queue = unsafe { &mut *local_queues[tctx.thread_index()].get() };
971                if tctx.thread_index() < 2 {
972                    queue.push(Box::new(5));
973                } else {
974                    queue.pop();
975                }
976            },
977            teardown: |state| *state = State(None),
978            ..Default::default()
979        });
980    }
981}
982
983#[cfg(all(test, loom))]
984mod loom_tests {
985    use super::*;
986
987    fn locals<T, const N: usize>(queue: &Queue<T>) -> [LocalQueue<T>; N] {
988        array_init::from_iter(queue.local_queues()).expect("incorrect number of local queues")
989    }
990
991    #[test]
992    fn pop_none() {
993        loom::model(|| {
994            let queue: Queue<()> = Queue::new(2, 1);
995            let [mut local_1, mut local_2] = locals(&queue);
996            loom::thread::spawn(move || assert!(local_1.pop().is_none()));
997            assert!(local_2.pop().is_none());
998        });
999    }
1000
1001    #[test]
1002    fn concurrent_steal_global() {
1003        loom::model(|| {
1004            let queue: Queue<Box<i32>> = Queue::new(2, 1);
1005            let [mut local_1, mut local_2] = locals(&queue);
1006            for i in 0..2 {
1007                queue.push(Box::new(i));
1008            }
1009            loom::thread::spawn(move || {
1010                local_1.pop();
1011                local_1.pop();
1012            });
1013            local_2.pop();
1014        });
1015    }
1016
1017    #[test]
1018    fn concurrent_steal_sibling() {
1019        loom::model(|| {
1020            let queue: Queue<Box<i32>> = Queue::new(3, 1);
1021            let [mut local_1, mut local_2, mut local_3] = locals(&queue);
1022            for i in 0..4 {
1023                local_1.push(Box::new(i));
1024            }
1025            loom::thread::spawn(move || {
1026                local_2.pop();
1027                local_2.pop();
1028            });
1029            local_3.pop();
1030        });
1031    }
1032
1033    #[test]
1034    fn global_spsc() {
1035        loom::model(|| {
1036            let queue: Queue<Box<i32>> = Queue::new(1, 4);
1037            let [mut local] = locals(&queue);
1038            loom::thread::spawn(move || {
1039                for i in 0..6 {
1040                    queue.push(Box::new(i));
1041                }
1042            });
1043            for _ in 0..6 {
1044                local.pop();
1045            }
1046        });
1047    }
1048
1049    #[test]
1050    fn sibling_spsc_few() {
1051        loom::model(|| {
1052            let queue: Queue<Box<i32>> = Queue::new(2, 4);
1053            let [mut local_1, mut local_2] = locals(&queue);
1054            loom::thread::spawn(move || {
1055                for i in 0..4 {
1056                    local_1.push(Box::new(i));
1057                }
1058            });
1059            for _ in 0..4 {
1060                local_2.pop();
1061            }
1062        });
1063    }
1064
1065    #[test]
1066    fn sibling_spsc_many() {
1067        loom::model(|| {
1068            let queue: Queue<Box<i32>> = Queue::new(2, 4);
1069            let [mut local_1, mut local_2] = locals(&queue);
1070            loom::thread::spawn(move || {
1071                for i in 0..8 {
1072                    local_1.push(Box::new(i));
1073                }
1074            });
1075            local_2.pop();
1076        });
1077    }
1078}