swap_queue/
lib.rs

//!
//! A lock-free, thread-owned queue whose tasks are taken by stealers in their entirety via buffer swapping. It is meant to be used as a [`thread_local`] static paired with [`tokio::task::spawn`] as a constant-time take-all batching mechanism that outperforms [`crossbeam_deque::Worker`](https://docs.rs/crossbeam-deque/0.8.1/crossbeam_deque/struct.Worker.html) and [`tokio::sync::mpsc`] for batching.
//!
//! ## Example
//!
//! ```
//! use swap_queue::Worker;
//! use tokio::{
//!   runtime::Handle,
//!   sync::oneshot::{channel, Sender},
//! };
//!
//! // Jemalloc makes this library substantially faster
//! #[global_allocator]
//! static GLOBAL: jemallocator::Jemalloc = jemallocator::Jemalloc;
//!
//! // Worker needs to be thread local because it is !Sync
//! thread_local! {
//!   static QUEUE: Worker<(u64, Sender<u64>)> = Worker::new();
//! }
//!
//! // Within an async context this batches optimally and without overhead: the spawned
//! // take-all task only runs after the tasks already scheduled have pushed their items
//! async fn push_echo(i: u64) -> u64 {
//!   {
//!     let (tx, rx) = channel();
//!
//!     QUEUE.with(|queue| {
//!       // A new stealer is returned whenever the buffer is new or was empty
//!       if let Some(stealer) = queue.push((i, tx)) {
//!         Handle::current().spawn(async move {
//!           // Take the underlying buffer in its entirety; the next push will return a new Stealer
//!           let batch = stealer.take().await;
//!
//!           // Some sort of batched operation, such as a database query
//!
//!           batch.into_iter().for_each(|(i, tx)| {
//!             tx.send(i).ok();
//!           });
//!         });
//!       }
//!     });
//!
//!     rx
//!   }
//!   .await
//!   .unwrap()
//! }
//! ```
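//!
//! For purely synchronous use on the worker's own thread, [`Stealer::take_blocking`] takes the
//! batch without awaiting (and never actually blocks in that case):
//!
//! ```
//! use swap_queue::Worker;
//!
//! let w = Worker::new();
//! let s = w.push(1).unwrap();
//! w.push(2);
//! w.push(3);
//!
//! assert_eq!(s.take_blocking(), vec![1, 2, 3]);
//! ```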

use crossbeam_epoch::{self as epoch, Atomic, Owned};
use crossbeam_utils::CachePadded;

use futures::executor::block_on;
use std::{cell::Cell, fmt, marker::PhantomData, mem, ptr, sync::Arc};
use tokio::sync::oneshot::{channel, Receiver, Sender};

#[cfg(loom)]
use loom::sync::atomic::{AtomicUsize, Ordering};

#[cfg(not(loom))]
use std::sync::atomic::{AtomicUsize, Ordering};

// Current buffer index
const BUFFER_IDX: usize = 1 << 0;

// Designates that a write is in progress
const WRITE_IN_PROGRESS: usize = 1 << 1;

// Designates how many bits are set aside for flags
const FLAGS_SHIFT: usize = 1;

// The slot counter is incremented twice per write (once to mark the write as in progress
// and once to mark completion), so we shift an extra bit to extract the length
const LENGTH_SHIFT: usize = FLAGS_SHIFT + 1;
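
// Worked example of the slot layout: each push performs two fetch_adds of `1 << FLAGS_SHIFT`
// (one before the write and one after), so after three completed pushes into buffer 0 the
// slot value is 12 (0b1100):
//   12 & BUFFER_IDX        == 0  -> buffer 0 is current
//   12 & WRITE_IN_PROGRESS == 0  -> no write is in progress
//   12 >> LENGTH_SHIFT     == 3  -> three tasks have been written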

// Minimum buffer capacity.
const MIN_CAP: usize = 64;

/// A buffer that holds tasks in a worker queue.
///
/// This is just a pointer to the buffer and its length - dropping an instance of this struct will
/// *not* deallocate the buffer.
struct Buffer<T> {
  /// Slot that represents the index offset and buffer idx
  slot: usize,

  /// Pointer to the allocated memory.
  ptr: *mut T,

  /// Capacity of the buffer. Always a power of two.
  cap: usize,
}

unsafe impl<T: Send> Send for Buffer<T> {}
unsafe impl<T: Send> Sync for Buffer<T> {}

impl<T> Buffer<T> {
  /// Allocates a new buffer with the specified capacity.
  fn alloc(slot: usize, cap: usize) -> Buffer<T> {
    debug_assert_eq!(cap, cap.next_power_of_two());

    let mut v = Vec::with_capacity(cap);
    let ptr = v.as_mut_ptr();
    mem::forget(v);

    Buffer { slot, ptr, cap }
  }

  /// Deallocates the buffer.
  unsafe fn dealloc(self) {
    drop(Vec::from_raw_parts(self.ptr, 0, self.cap));
  }

  /// Returns a pointer to the task at the specified `index`.
  unsafe fn at(&self, index: usize) -> *mut T {
    // `self.cap` is always a power of two.
    self.ptr.offset((index & (self.cap - 1)) as isize)
  }

  /// Writes `task` into the specified `index`.
  unsafe fn write(&self, index: usize, task: T) {
    ptr::write_volatile(self.at(index), task)
  }

  /// Converts the first `length` written tasks into a `Vec<T>`, taking ownership of the allocation.
  unsafe fn to_vec(self, length: usize) -> Vec<T> {
    let Buffer { ptr, cap, .. } = self;
    Vec::from_raw_parts(ptr, length, cap)
  }
}

impl<T> Clone for Buffer<T> {
  fn clone(&self) -> Buffer<T> {
    Buffer {
      slot: self.slot,
      ptr: self.ptr,
      cap: self.cap,
    }
  }
}

impl<T> Copy for Buffer<T> {}

/// Number of writes between two slot values, accounting for wrap-around of the slot counter
fn slot_delta(a: usize, b: usize) -> usize {
  if a < b {
    ((usize::MAX - b) >> LENGTH_SHIFT) + (a >> LENGTH_SHIFT)
  } else {
    (a >> LENGTH_SHIFT) - (b >> LENGTH_SHIFT)
  }
}

/// Queue state shared between the Worker and its Stealers: the atomic slot and the two
/// swappable buffers
struct Inner<T> {
  slot: AtomicUsize,
  buffers: (
    CachePadded<Atomic<Buffer<T>>>,
    CachePadded<Atomic<Buffer<T>>>,
  ),
}

impl<T> Inner<T> {
  /// Selects the buffer designated by the BUFFER_IDX bit of `slot`
  fn get_buffer(&self, slot: usize) -> &CachePadded<Atomic<Buffer<T>>> {
    if slot & BUFFER_IDX == 0 {
      &self.buffers.0
    } else {
      &self.buffers.1
    }
  }
}

enum Flavor {
  Unbounded,
  AutoBatched { batch_size: usize },
}

/// A thread-owned worker queue that writes to a swappable buffer using atomic slotting
///
/// # Examples
///
/// ```
/// use swap_queue::Worker;
///
/// let w = Worker::new();
/// let s = w.push(1).unwrap();
/// w.push(2);
/// w.push(3);
/// // this is non-blocking because it's called on the same thread as the Worker; a write in progress is not possible
/// assert_eq!(s.take_blocking(), vec![1, 2, 3]);
///
/// let s = w.push(4).unwrap();
/// w.push(5);
/// w.push(6);
/// // this is identical to [`Stealer::take_blocking`]
/// let batch: Vec<_> = s.into();
/// assert_eq!(batch, vec![4, 5, 6]);
/// ```
pub struct Worker<T> {
  flavor: Flavor,
  /// A reference to the inner representation of the queue.
  inner: Arc<CachePadded<Inner<T>>>,
  /// A copy of the current buffer for quick access.
  buffer: Cell<Buffer<T>>,
  /// Send handle corresponding to the current Stealer
  tx: Cell<Option<Sender<Vec<T>>>>,
  /// Indicates that the worker cannot be shared among threads.
  _marker: PhantomData<*mut ()>,
}

unsafe impl<T: Send> Send for Worker<T> {}

impl<T> Worker<T> {
  /// Creates a new Worker queue.
  ///
  /// # Examples
  ///
  /// ```
  /// use swap_queue::Worker;
  ///
  /// let w = Worker::<i32>::new();
  /// ```
  pub fn new() -> Worker<T> {
    let buffer = Buffer::alloc(0, MIN_CAP);

    let inner = Arc::new(CachePadded::new(Inner {
      slot: AtomicUsize::new(0),
      buffers: (
        CachePadded::new(Atomic::new(buffer)),
        CachePadded::new(Atomic::null()),
      ),
    }));

    Worker {
      flavor: Flavor::Unbounded,
      inner,
      buffer: Cell::new(buffer),
      tx: Cell::new(None),
      _marker: PhantomData,
    }
  }

  /// Creates an auto-batched Worker queue that swaps out fixed-length buffers. When the buffer
  /// reaches capacity it is swapped out and ownership is taken by the returned Stealer. The
  /// batch size must be a power of 2 that is at least 64
  ///
  /// # Examples
  ///
  /// ```
  /// use swap_queue::Worker;
  ///
  /// let w = Worker::<i32>::auto_batched(64);
  /// ```
  pub fn auto_batched(batch_size: usize) -> Worker<T> {
    debug_assert!(batch_size.ge(&64), "batch_size must be at least 64");
    debug_assert_eq!(
      batch_size,
      batch_size.next_power_of_two(),
      "batch_size must be a power of 2"
    );

    // The initial buffer is allocated at the batch size so that a full batch exactly fills it
    let buffer = Buffer::alloc(0, batch_size);

    let inner = Arc::new(CachePadded::new(Inner {
      slot: AtomicUsize::new(0),
      buffers: (
        CachePadded::new(Atomic::new(buffer)),
        CachePadded::new(Atomic::null()),
      ),
    }));

    Worker {
      flavor: Flavor::AutoBatched { batch_size },
      inner,
      buffer: Cell::new(buffer),
      tx: Cell::new(None),
      _marker: PhantomData,
    }
  }

  /// Resizes the internal buffer, doubling its capacity.
  unsafe fn resize(&self, buffer: &mut Buffer<T>, slot: usize) {
    let length = slot_delta(slot, buffer.slot);

    // Allocate a new buffer and copy data from the old buffer to the new one.
    let new = Buffer::alloc(buffer.slot, buffer.cap * 2);

    ptr::copy_nonoverlapping(buffer.at(0), new.at(0), length);

    self.buffer.set(new);

    let old = std::mem::replace(buffer, new);

    self
      .inner
      .get_buffer(slot)
      .store(Owned::new(new), Ordering::Release);

    old.dealloc();
  }

  fn replace_buffer(&self, buffer: &mut Buffer<T>, slot: usize, cap: usize) -> Buffer<T> {
    let new = Buffer::alloc(slot, cap);

    self
      .inner
      .get_buffer(slot)
      .store(Owned::new(new), Ordering::Release);

    self.buffer.set(new);

    std::mem::replace(buffer, new)
  }

  /// Writes to the next slot, swapping buffers as necessary and returning a Stealer at the start of a new batch
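  ///
  /// # Examples
  ///
  /// ```
  /// use swap_queue::Worker;
  ///
  /// let w = Worker::new();
  ///
  /// // The push that starts a new batch hands back a Stealer
  /// let stealer = w.push(1).unwrap();
  ///
  /// // Pushes into an already-started batch do not
  /// assert!(w.push(2).is_none());
  /// assert!(w.push(3).is_none());
  ///
  /// assert_eq!(stealer.take_blocking(), vec![1, 2, 3]);
  /// ```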
  pub fn push(&self, task: T) -> Option<Stealer<T>> {
    let slot = self
      .inner
      .slot
      .fetch_add(1 << FLAGS_SHIFT, Ordering::Relaxed);

    let mut buffer = self.buffer.get();

    // BUFFER_IDX bit changed, therefore buffer was stolen
    if ((slot ^ buffer.slot) & BUFFER_IDX).eq(&BUFFER_IDX) {
      buffer = Buffer::alloc(slot, buffer.cap);

      self
        .inner
        .get_buffer(slot)
        .store(Owned::new(buffer), Ordering::Release);

      self.buffer.set(buffer);

      unsafe {
        buffer.write(0, task);
      }

      // There can be no stealer at this point, so no need to check IDX XOR
      self
        .inner
        .slot
        .fetch_add(1 << FLAGS_SHIFT, Ordering::Relaxed);

      let (tx, rx) = channel();
      self.tx.set(Some(tx));

      Some(Stealer::Taker(StealHandle {
        rx,
        inner: self.inner.clone(),
      }))
    } else {
      let index = slot_delta(slot, buffer.slot);

      match &self.flavor {
        Flavor::Unbounded if index.eq(&buffer.cap) => {
          unsafe {
            self.resize(&mut buffer, slot);
            buffer.write(index, task);
          }

          let slot = self
            .inner
            .slot
            .fetch_add(1 << FLAGS_SHIFT, Ordering::Relaxed);

          // The Stealer expressed intention to take the buffer by changing the buffer index,
          // and is waiting on the Worker to send the buffer once the write in progress completes
          if ((slot ^ buffer.slot) & BUFFER_IDX).eq(&BUFFER_IDX) {
            let (tx, rx) = channel();
            let tx = self.tx.replace(Some(tx)).unwrap();

            // Send the buffer as a Vec to the receiver
            tx.send(unsafe { buffer.to_vec(index) }).ok();

            Some(Stealer::Taker(StealHandle {
              rx,
              inner: self.inner.clone(),
            }))
          } else {
            None
          }
        }
        Flavor::AutoBatched { batch_size } if index.eq(batch_size) => {
          let old = self.replace_buffer(&mut buffer, slot, *batch_size);
          let batch = unsafe { old.to_vec(*batch_size) };

          unsafe {
            buffer.write(0, task);
          }

          let slot = self
            .inner
            .slot
            .fetch_add(1 << FLAGS_SHIFT, Ordering::Relaxed);

          if ((slot ^ buffer.slot) & BUFFER_IDX).eq(&BUFFER_IDX) {
            let (tx, rx) = channel();
            let tx = self.tx.replace(Some(tx)).unwrap();

            tx.send(batch).ok();

            Some(Stealer::Taker(StealHandle {
              rx,
              inner: self.inner.clone(),
            }))
          } else {
            Some(Stealer::Owner(batch))
          }
        }
        _ if index.eq(&0) => {
          unsafe {
            buffer.write(0, task);
          }

          self
            .inner
            .slot
            .fetch_add(1 << FLAGS_SHIFT, Ordering::Relaxed);

          let (tx, rx) = channel();
          self.tx.set(Some(tx));

          Some(Stealer::Taker(StealHandle {
            rx,
            inner: self.inner.clone(),
          }))
        }
        _ => {
          unsafe {
            buffer.write(index, task);
          }

          let slot = self
            .inner
            .slot
            .fetch_add(1 << FLAGS_SHIFT, Ordering::Relaxed);

          if ((slot ^ buffer.slot) & BUFFER_IDX).eq(&BUFFER_IDX) {
            let (tx, rx) = channel();
            let tx = self.tx.replace(Some(tx)).unwrap();

            // Send the buffer as a Vec to the receiver
            tx.send(unsafe { buffer.to_vec(index) }).ok();

            Some(Stealer::Taker(StealHandle {
              rx,
              inner: self.inner.clone(),
            }))
          } else {
            None
          }
        }
      }
    }
  }
}

impl<T> Default for Worker<T> {
  fn default() -> Self {
    Self::new()
  }
}

impl<T> fmt::Debug for Worker<T> {
  fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
    f.pad("Worker { .. }")
  }
}

impl<T> Drop for Worker<T> {
  fn drop(&mut self) {
    // By leaving the queue in a permanent write-in-progress state, the Stealer will always
    // receive the buffer from the oneshot::Sender
    let slot = self
      .inner
      .slot
      .fetch_add(1 << FLAGS_SHIFT, Ordering::Relaxed);

    let buffer = self.buffer.get();

    // Is the buffer still current? (If not, the Stealer has already taken it)
    if slot & BUFFER_IDX == buffer.slot & BUFFER_IDX {
      let length = slot_delta(slot, buffer.slot);

      // Send to the Stealer if able
      if let Some(tx) = self.tx.replace(None) {
        if let Err(queue) = tx.send(unsafe { buffer.to_vec(length) }) {
          drop(queue);
        }
      } else {
        // Otherwise deallocate everything
        unsafe {
          // Go through the buffer from front to back and drop all tasks in the queue.
          for i in 0..length {
            buffer.at(i).drop_in_place();
          }

          // Free the memory allocated by the buffer.
          buffer.dealloc();
        }
      }
    }
  }
}

#[doc(hidden)]
pub struct StealHandle<T> {
  /// Buffer receiver to be used when waiting on writes
  rx: Receiver<Vec<T>>,
  /// A reference to the inner representation of the queue.
  inner: Arc<CachePadded<Inner<T>>>,
}

/// Stealers swap out and take ownership of buffers in their entirety from Workers
pub enum Stealer<T> {
  /// The Stealer was created with an owned batch that can simply be unwrapped
  Owner(Vec<T>),
  /// A StealHandle swaps out the buffer directly, or awaits the Worker sending it when the
  /// in-progress write completes
  Taker(StealHandle<T>),
}

unsafe impl<T: Send> Send for Stealer<T> {}
unsafe impl<T: Send> Sync for Stealer<T> {}

impl<T> Stealer<T> {
  /// Takes the entire queue by swapping out the underlying buffer and converting it into a `Vec<T>`, or by waiting to receive the buffer from the Worker if a write was in progress.
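  ///
  /// # Examples
  ///
  /// ```
  /// use swap_queue::Worker;
  ///
  /// let w = Worker::new();
  /// let s = w.push(1).unwrap();
  /// w.push(2);
  ///
  /// // Outside of an async context the future can be driven with any executor,
  /// // for example the `futures` crate's `block_on`
  /// let batch = futures::executor::block_on(s.take());
  /// assert_eq!(batch, vec![1, 2]);
  /// ```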
  pub async fn take(self) -> Vec<T> {
    match self {
      Stealer::Owner(batch) => batch,
      Stealer::Taker(StealHandle { rx, inner }) => {
        let slot = inner.slot.fetch_xor(BUFFER_IDX, Ordering::Relaxed);

        // The Worker will see that the buffer has swapped when confirming the length increment
        if slot & WRITE_IN_PROGRESS == WRITE_IN_PROGRESS {
          // The Worker can never be dropped mid-write, therefore RecvError cannot occur
          rx.await.unwrap()
        } else {
          let guard = &epoch::pin();

          let buffer = inner.get_buffer(slot).load_consume(guard);

          unsafe {
            let buffer = *buffer.into_owned();
            buffer.to_vec(slot_delta(slot, buffer.slot))
          }
        }
      }
    }
  }

  /// Takes the entire queue by swapping out the underlying buffer and converting it into a `Vec<T>`, or by blocking to receive the buffer from the Worker if a write was in progress. This is always non-blocking when called on the same thread as the Worker
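  ///
  /// # Examples
  ///
  /// ```
  /// use swap_queue::Worker;
  ///
  /// let w = Worker::new();
  /// let s = w.push(1).unwrap();
  /// w.push(2);
  /// w.push(3);
  ///
  /// // Called from the Worker's own thread, this never blocks
  /// assert_eq!(s.take_blocking(), vec![1, 2, 3]);
  /// ```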
  pub fn take_blocking(self) -> Vec<T> {
    match self {
      Stealer::Owner(batch) => batch,
      Stealer::Taker(StealHandle { rx, inner }) => {
        let slot = inner.slot.fetch_xor(BUFFER_IDX, Ordering::Relaxed);

        // The Worker will see that the buffer has swapped when confirming the length increment.
        // A write cannot be in progress when this is called from the same thread as the queue
        if slot & WRITE_IN_PROGRESS == WRITE_IN_PROGRESS {
          // The Worker can never be dropped mid-write, therefore RecvError cannot occur
          block_on(rx).unwrap()
        } else {
          let guard = &epoch::pin();

          let buffer = inner.get_buffer(slot).load_consume(guard);

          unsafe {
            let buffer = *buffer.into_owned();
            buffer.to_vec(slot_delta(slot, buffer.slot))
          }
        }
      }
    }
  }
}

/// Uses [`Stealer::take_blocking`]; non-blocking when called on the same thread as Worker
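///
/// ```
/// use swap_queue::Worker;
///
/// let w = Worker::new();
/// let s = w.push(1).unwrap();
/// w.push(2);
///
/// // Equivalent to calling `take_blocking` on the Stealer
/// let batch: Vec<_> = s.into();
/// assert_eq!(batch, vec![1, 2]);
/// ```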
impl<T> From<Stealer<T>> for Vec<T> {
  fn from(stealer: Stealer<T>) -> Self {
    stealer.take_blocking()
  }
}

impl<T> fmt::Debug for Stealer<T> {
  fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
    f.pad("Stealer { .. }")
  }
}

#[cfg(test)]
mod tests {
  use super::*;

  #[cfg(loom)]
  use loom::thread;

  #[cfg(not(loom))]
  use std::thread;

  macro_rules! model {
    ($test:block) => {
      #[cfg(loom)]
      loom::model(|| $test);

      #[cfg(not(loom))]
      $test
    };
  }

  #[test]
  fn slot_wraps_around() {
    let delta = slot_delta(1 << LENGTH_SHIFT, usize::MAX);

    assert_eq!(delta, 1);
  }

  #[test]
  fn it_resizes() {
    model!({
      let queue = Worker::new();
      let stealer = queue.push(0).unwrap();

      for i in 1..128 {
        queue.push(i);
      }

      let batch = stealer.take_blocking();
      let expected = (0..128).collect::<Vec<i32>>();

      assert_eq!(batch, expected);
    });
  }

  #[test]
  fn it_makes_new_stealer_per_batch() {
    model!({
      let queue = Worker::new();
      let stealer = queue.push(0).unwrap();

      queue.push(1);
      queue.push(2);

      assert_eq!(stealer.take_blocking(), vec![0, 1, 2]);

      let stealer = queue.push(3).unwrap();
      queue.push(4);
      queue.push(5);

      assert_eq!(stealer.take_blocking(), vec![3, 4, 5]);
    });
  }
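
  // The push that starts a batch is the only one that hands back a Stealer;
  // later pushes into the same batch return None
  #[test]
  fn it_returns_stealer_only_at_batch_start() {
    model!({
      let queue = Worker::new();

      let stealer = queue.push(0).unwrap();
      assert!(queue.push(1).is_none());
      assert!(queue.push(2).is_none());

      assert_eq!(stealer.take_blocking(), vec![0, 1, 2]);
    });
  }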

  #[test]
  fn it_auto_batches() {
    model!({
      let queue = Worker::auto_batched(64);
      let mut stealers: Vec<Stealer<i32>> = vec![];

      for i in 0..128 {
        if let Some(stealer) = queue.push(i) {
          stealers.push(stealer);
        }
      }

      let batch: Vec<i32> = stealers
        .into_iter()
        .rev()
        .flat_map(|stealer| stealer.take_blocking())
        .collect();

      let expected = (0..128).collect::<Vec<i32>>();

      assert_eq!(batch, expected);
    });
  }

  #[cfg(not(loom))]
  #[tokio::test]
  async fn stealer_takes() {
    let queue = Worker::new();
    let stealer = queue.push(0).unwrap();

    for i in 1..1024 {
      queue.push(i);
    }

    let batch = stealer.take().await;
    let expected = (0..1024).collect::<Vec<i32>>();

    assert_eq!(batch, expected);
  }

  #[test]
  fn stealer_takes_blocking() {
    model!({
      let queue = Worker::new();
      let stealer = queue.push(0).unwrap();

      for i in 1..128 {
        queue.push(i);
      }

      thread::spawn(move || {
        stealer.take_blocking();
      })
      .join()
      .unwrap();
    });
  }

  #[cfg(not(loom))]
  #[tokio::test]
  async fn worker_drops() {
    let queue = Worker::new();
    let stealer = queue.push(0).unwrap();

    for i in 1..128 {
      queue.push(i);
    }

    drop(queue);

    let batch = stealer.take().await;
    let expected = (0..128).collect::<Vec<i32>>();

    assert_eq!(batch, expected);
  }

  #[cfg(loom)]
  #[tokio::test]
  async fn worker_drops() {
    loom::model(|| {
      let queue = Worker::new();
      let stealer = queue.push(0).unwrap();

      for i in 1..128 {
        queue.push(i);
      }

      drop(queue);

      let batch = stealer.take_blocking();
      let expected = (0..128).collect::<Vec<i32>>();

      assert_eq!(batch, expected);
    });
  }
}