stack_queue/queue.rs

use std::{
  array,
  fmt::Debug,
  future::Future,
  mem::MaybeUninit,
  ops::{BitAnd, Deref},
  task::Waker,
};
#[cfg(not(loom))]
use std::{
  cell::UnsafeCell,
  sync::atomic::{AtomicUsize, Ordering},
  thread::LocalKey,
};

use async_local::{AsContext, Context};
use crossbeam_deque::{Steal, Stealer, Worker};
use crossbeam_utils::CachePadded;
#[cfg(loom)]
use loom::{
  cell::UnsafeCell,
  sync::atomic::{AtomicUsize, Ordering},
  thread::LocalKey,
};
#[cfg(not(loom))]
use tokio::task::spawn;

use crate::{
  assignment::{BufferIter, CompletionReceipt, PendingAssignment, UnboundedRange},
  helpers::*,
  task::{BatchCollect, BatchReduce, BatchedTask, TaskRef},
  MAX_BUFFER_LEN, MIN_BUFFER_LEN,
};

pub(crate) const PHASE: usize = 1;

#[cfg(target_pointer_width = "64")]
pub(crate) const INDEX_SHIFT: usize = 32;
#[cfg(target_pointer_width = "32")]
pub(crate) const INDEX_SHIFT: usize = 16;
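
// Note on how these constants are used below: the write index is packed into the bits at and
// above `INDEX_SHIFT` (hence the `fetch_add(1 << INDEX_SHIFT, ..)` in `enqueue` and `push`),
// while bit 0 is the `PHASE` bit; a writer compares the phase bit of the freshly loaded slot
// against its locally cached copy (`(base_slot ^ prev_slot) & PHASE`) to detect that its write
// starts a new batch.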

#[doc(hidden)]
pub struct BufferCell<T: Send + Sync + Sized + 'static>(UnsafeCell<MaybeUninit<T>>);

impl<T> Deref for BufferCell<T>
where
  T: Send + Sync + Sized + 'static,
{
  type Target = UnsafeCell<MaybeUninit<T>>;
  fn deref(&self) -> &Self::Target {
    &self.0
  }
}

impl<T> BufferCell<T>
where
  T: Send + Sync + Sized + 'static,
{
  fn new_uninit() -> Self {
    BufferCell(UnsafeCell::new(MaybeUninit::uninit()))
  }
}

unsafe impl<T> Send for BufferCell<T> where T: Send + Sync + Sized + 'static {}
unsafe impl<T> Sync for BufferCell<T> where T: Send + Sync + Sized + 'static {}

/// Auto-batched queue whereby each task resolves to a value
///
/// # Example
///
/// ```rust
/// struct EchoQueue;
///
/// #[local_queue(buffer_size = 64)]
/// impl TaskQueue for EchoQueue {
///   type Task = usize;
///   type Value = usize;
///
///   async fn batch_process<const N: usize>(
///     batch: PendingAssignment<'_, Self, N>,
///   ) -> CompletionReceipt<Self> {
///     batch.into_assignment().map(|val| val)
///   }
/// }
/// ```
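///
/// With an impl like the one above, a task can be awaited directly, e.g.
/// `let echoed: usize = EchoQueue::auto_batch(42).await;` (this assumes a Tokio runtime, as in
/// the tests at the bottom of this module).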
pub trait TaskQueue: Send + Sync + Sized + 'static {
  type Task: Send + Sync + Sized + 'static;
  type Value: Send;

  fn batch_process<const N: usize>(
    assignment: PendingAssignment<'_, Self, N>,
  ) -> impl Future<Output = CompletionReceipt<Self>> + Send;

  fn auto_batch<const N: usize>(task: Self::Task) -> BatchedTask<Self, N>
  where
    Self: LocalQueue<N, BufferCell = TaskRef<Self>>,
  {
    BatchedTask::new(task)
  }
}

/// Fire-and-forget auto-batched queue
///
/// # Example
///
/// ```rust
/// struct EchoQueue;
///
/// #[local_queue]
/// impl BackgroundQueue for EchoQueue {
///   type Task = (usize, oneshot::Sender<usize>);
///
///   async fn batch_process<const N: usize>(tasks: UnboundedRange<'_, Self::Task, N>) {
///     for (val, tx) in tasks.into_bounded().into_iter() {
///       tx.send(val).ok();
///     }
///   }
/// }
/// ```
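///
/// With an impl like the one above, `EchoQueue::auto_batch((42, tx))` enqueues the task and
/// returns immediately; the batch itself runs on a spawned Tokio task, which is why
/// [`BackgroundQueue::auto_batch`] must be called from within the runtime.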
pub trait BackgroundQueue: Send + Sync + Sized + 'static {
  type Task: Send + Sync + Sized + 'static;

  fn batch_process<const N: usize>(
    tasks: UnboundedRange<'_, Self::Task, N>,
  ) -> impl Future<Output = ()> + Send;

  /// Process a task in the background
  ///
  /// # Panics
  ///
  /// Panics if called from **outside** of the Tokio runtime
  #[cfg(not(loom))]
  fn auto_batch<const N: usize>(task: Self::Task)
  where
    Self: LocalQueue<N, BufferCell = BufferCell<Self::Task>>,
  {
    StackQueue::background_process::<Self>(task);
  }
}

/// Auto-batched queue whereby tasks are reduced or collected
///
/// # Example
///
/// ```rust
/// struct Accumulator;
///
/// #[local_queue]
/// impl BatchReducer for Accumulator {
///   type Task = usize;
/// }
///
/// let sum: Option<usize> = Accumulator::batch_reduce(9000, |batch| batch.sum::<usize>()).await;
/// ```
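///
/// As exercised by the tests at the bottom of this module, the reducer receives the batched
/// tasks as a single iterator, and each per-task future resolves to an `Option`, typically
/// combined with `unwrap_or_default` (for [`BatchReducer::batch_reduce`]) or `map_or` (for
/// [`BatchReducer::batch_collect`]).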
pub trait BatchReducer: Send + Sync + Sized + 'static {
  type Task: Send + Sync + Sized + 'static;

  /// Reduce over tasks batched in an async context
  fn batch_reduce<'a, const N: usize, F, R>(
    task: Self::Task,
    reducer: F,
  ) -> BatchReduce<'a, Self, F, R, N>
  where
    Self: LocalQueue<N, BufferCell = BufferCell<Self::Task>>,
    F: for<'b> FnOnce(BufferIter<'b, Self::Task, N>) -> R + Send,
  {
    BatchReduce::new(task, reducer)
  }

  /// Collect tasks batched in an async context
  fn batch_collect<'a, const N: usize>(task: Self::Task) -> BatchCollect<'a, Self, N>
  where
    Self: LocalQueue<N, BufferCell = BufferCell<Self::Task>>,
  {
    BatchCollect::new(task)
  }
}

/// Thread-local context for enqueuing tasks on [`StackQueue`].
///
/// This can be implemented by using the [`local_queue`](crate::local_queue) macro on any
/// [`TaskQueue`], [`BackgroundQueue`] or [`BatchReducer`] impl.
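///
/// A minimal sketch of the kind of impl the macro expands to (names and buffer size are
/// illustrative):
///
/// ```ignore
/// impl LocalQueue<64> for EchoQueue {
///   type BufferCell = TaskRef<EchoQueue>;
///
///   fn queue() -> &'static LocalKey<StackQueue<Self::BufferCell, 64>> {
///     thread_local! {
///       static QUEUE: StackQueue<TaskRef<EchoQueue>, 64> = StackQueue::default();
///     }
///     &QUEUE
///   }
/// }
/// ```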
pub trait LocalQueue<const N: usize> {
  type BufferCell: Send + Sync + Sized + 'static;

  fn queue() -> &'static LocalKey<StackQueue<Self::BufferCell, N>>;
}

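// Shared queue state referenced by in-flight batches: `slot` packs the write index and phase
// bit, `occupancy` tracks how many batches currently hold each buffer region, `buffer` stores
// the task cells, and `stealer` is the consumer side of the `pending` waker queue owned by
// `StackQueue`, drained in `Inner::deoccupy_region` when the occupancy count clears.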
#[doc(hidden)]
pub struct Inner<T: Sync + Sized + 'static, const N: usize = 1024> {
  pub(crate) slot: CachePadded<AtomicUsize>,
  pub(crate) occupancy: CachePadded<AtomicUsize>,
  pub(crate) buffer: [T; N],
  pub(crate) stealer: Stealer<Waker>,
}

impl<T, const N: usize> From<Stealer<Waker>> for Inner<BufferCell<T>, N>
where
  T: Send + Sync + Sized + 'static,
{
  fn from(stealer: Stealer<Waker>) -> Self {
    let buffer = array::from_fn(|_| BufferCell::new_uninit());

    Inner {
      slot: CachePadded::new(AtomicUsize::new(0)),
      occupancy: CachePadded::new(AtomicUsize::new(0)),
      buffer,
      stealer,
    }
  }
}

impl<T, const N: usize> From<Stealer<Waker>> for Inner<TaskRef<T>, N>
where
  T: TaskQueue,
{
  fn from(stealer: Stealer<Waker>) -> Self {
    let buffer = array::from_fn(|_| TaskRef::new_uninit());

    Inner {
      slot: CachePadded::new(AtomicUsize::new(0)),
      occupancy: CachePadded::new(AtomicUsize::new(0)),
      buffer,
      stealer,
    }
  }
}

impl<T, const N: usize> Inner<T, N>
where
  T: Sync + Sized,
{
  #[inline(always)]
  pub(crate) fn deoccupy_region(&self, index: usize) {
    let one_shifted = one_shifted::<N>(index);

    if self
      .occupancy
      .fetch_sub(one_shifted, Ordering::AcqRel)
      .eq(&one_shifted)
    {
      let mut batch_limit = region_size::<N>();

      while batch_limit.gt(&0) {
        match self.stealer.steal() {
          Steal::Empty => break,
          Steal::Success(waker) => waker.wake(),
          Steal::Retry => continue,
        }

        batch_limit -= 1;
      }
    }
  }
}

impl<T, const N: usize> Inner<BufferCell<T>, N>
where
  T: Send + Sync + Sized + 'static,
{
  #[cfg(not(loom))]
  #[inline(always)]
  pub(crate) unsafe fn with_buffer_cell<F, R>(&self, f: F, index: usize) -> R
  where
    F: FnOnce(*mut MaybeUninit<T>) -> R,
  {
    let cell = self.buffer.get_unchecked(index);
    f(cell.get())
  }

  #[cfg(loom)]
  #[inline(always)]
  pub(crate) unsafe fn with_buffer_cell<F, R>(&self, f: F, index: usize) -> R
  where
    F: FnOnce(*mut MaybeUninit<T>) -> R,
  {
    let cell = self.buffer.get_unchecked(index);
    cell.get_mut().with(f)
  }
}

unsafe impl<T, const N: usize> Sync for Inner<T, N> where T: Sync + Sized + 'static {}

#[derive(Debug)]
pub(crate) struct QueueFull;

/// Task queue designed for facilitating heapless auto-batching of tasks
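///
/// The buffer size `N` must be a power of two within `MIN_BUFFER_LEN..=MAX_BUFFER_LEN`, as
/// checked by the debug assertions in the [`Default`] impl below.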
#[derive(AsContext)]
pub struct StackQueue<T: Sync + Sized + 'static, const N: usize = 1024> {
  slot: CachePadded<UnsafeCell<usize>>,
  occupancy: CachePadded<UnsafeCell<usize>>,
  inner: Context<Inner<T, N>>,
  pub(crate) pending: Worker<Waker>,
}

impl<T, const N: usize> Default for StackQueue<T, N>
where
  T: Sync + Sized + 'static,
  Inner<T, N>: From<Stealer<Waker>>,
{
  fn default() -> Self {
    debug_assert_eq!(
      N,
      N.next_power_of_two(),
      "StackQueue buffer size must be a power of 2"
    );
    debug_assert!(N >= MIN_BUFFER_LEN);
    debug_assert!(N <= MAX_BUFFER_LEN);

    let pending = Worker::new_fifo();

    StackQueue {
      slot: CachePadded::new(UnsafeCell::new(PHASE)),
      occupancy: CachePadded::new(UnsafeCell::new(0)),
      inner: Context::new(Inner::from(pending.stealer())),
      pending,
    }
  }
}

impl<T, const N: usize> StackQueue<T, N>
where
  T: Sync + Sized + 'static,
{
  #[cfg(not(loom))]
  #[inline(always)]
  unsafe fn with_slot<F, R>(&self, f: F) -> R
  where
    F: FnOnce(*const usize) -> R,
  {
    f(self.slot.get())
  }

  #[cfg(loom)]
  #[inline(always)]
  unsafe fn with_slot<F, R>(&self, f: F) -> R
  where
    F: FnOnce(*const usize) -> R,
  {
    self.slot.get().with(f)
  }

  #[cfg(not(loom))]
  #[inline(always)]
  unsafe fn with_slot_mut<F, R>(&self, f: F) -> R
  where
    F: FnOnce(*mut usize) -> R,
  {
    f(self.slot.get())
  }

  #[cfg(loom)]
  #[inline(always)]
  unsafe fn with_slot_mut<F, R>(&self, f: F) -> R
  where
    F: FnOnce(*mut usize) -> R,
  {
    self.slot.get_mut().with(f)
  }

  #[cfg(not(loom))]
  #[inline(always)]
  unsafe fn with_occupancy<F, R>(&self, f: F) -> R
  where
    F: FnOnce(*const usize) -> R,
  {
    f(self.occupancy.get())
  }

  #[cfg(loom)]
  #[inline(always)]
  unsafe fn with_occupancy<F, R>(&self, f: F) -> R
  where
    F: FnOnce(*const usize) -> R,
  {
    self.occupancy.get().with(f)
  }

  #[cfg(not(loom))]
  #[inline(always)]
  unsafe fn with_occupancy_mut<F, R>(&self, f: F) -> R
  where
    F: FnOnce(*mut usize) -> R,
  {
    f(self.occupancy.get())
  }

  #[cfg(loom)]
  #[inline(always)]
  unsafe fn with_occupancy_mut<F, R>(&self, f: F) -> R
  where
    F: FnOnce(*mut usize) -> R,
  {
    self.occupancy.get_mut().with(f)
  }

  #[inline(always)]
  fn current_write_index(&self) -> usize {
    // An UnsafeCell suffices for the index counter because where the current task is written is
    // independent of when a phase change hands a new range of the buffer to a new task batch;
    // atomic synchronization determines ownership, not the write location.
    unsafe { self.with_slot(|val| slot_index::<N>(*val)) }
  }

  #[inline(always)]
  fn check_regional_occupancy(&self, index: usize) -> Result<(), QueueFull> {
    let region_mask = region_mask::<N>(index);

    // If the cached occupancy is stale, the region can be incorrectly treated as full, but never
    // incorrectly treated as free, so this check avoids the overhead of an atomic load as long as
    // regions are cleared within a full cycle
    let regional_occupancy =
      unsafe { self.with_occupancy(|occupancy| (*occupancy).bitand(region_mask)) };

    if regional_occupancy.eq(&0) {
      return Ok(());
    }

    // Usually this slow path won't occur because occupancy syncs when new batches are created
    let occupancy = self.inner.occupancy.load(Ordering::Acquire);
    let regional_occupancy = occupancy.bitand(region_mask);

    unsafe {
      self.with_occupancy_mut(move |val| *val = occupancy);
    }

    if regional_occupancy.eq(&0) {
      Ok(())
    } else {
      Err(QueueFull)
    }
  }

  #[inline(always)]
  fn occupy_region(&self, index: usize) {
    // Add one relative to the current region. In the unlikely event of an overflow, the next
    // region checkpoint will result in QueueFull until fewer than 256 task batches exist.
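    // Each batch increments its region's counter here via `one_shifted`; `Inner::deoccupy_region`
    // performs the matching decrement once the batch releases the region.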
    let shifted_add = one_shifted::<N>(index);

    let occupancy = self
      .inner
      .occupancy
      .fetch_add(shifted_add, Ordering::AcqRel)
      .wrapping_add(shifted_add);

    unsafe {
      self.with_occupancy_mut(move |val| *val = occupancy);
    }
  }

  #[inline(always)]
  unsafe fn replace_slot(&self, slot: usize) -> usize {
    self.with_slot_mut(move |val| std::mem::replace(&mut *val, slot))
  }
}

impl<T, const N: usize> StackQueue<TaskRef<T>, N>
where
  T: TaskQueue,
{
  pub(crate) unsafe fn enqueue<'a, F>(
    &self,
    write_with: F,
  ) -> Result<Option<PendingAssignment<'a, T, N>>, QueueFull>
  where
    F: FnOnce(&TaskRef<T>),
  {
    let write_index = self.current_write_index();

    // Region sizes are always a power of 2, so this acts as an optimized modulus operation
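    // i.e. `write_index & (region_size - 1) == write_index % region_size`, which is zero exactly
    // when `write_index` lands on a region boundary.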
    if write_index.bitand(region_size::<N>() - 1).eq(&0) {
      self.check_regional_occupancy(write_index)?;
    }

    write_with(self.inner.buffer.get_unchecked(write_index));

    let base_slot = self
      .inner
      .slot
      .fetch_add(1 << INDEX_SHIFT, Ordering::Relaxed);

    let prev_slot = self.replace_slot(base_slot.wrapping_add(1 << INDEX_SHIFT));

    if ((base_slot ^ prev_slot) & PHASE).eq(&0) {
      Ok(None)
    } else {
      self.occupy_region(write_index);

      let queue = self.inner.local_ref();

      Ok(Some(PendingAssignment::new(base_slot, queue)))
    }
  }
}

impl<T, const N: usize> StackQueue<BufferCell<T>, N>
where
  T: Send + Sync + Sized + 'static,
{
  pub(crate) unsafe fn push<'a>(&self, task: T) -> Result<Option<UnboundedRange<'a, T, N>>, T> {
    let write_index = self.current_write_index();

    // Region sizes are always a power of 2, so this acts as an optimized modulus operation
    if write_index.bitand(region_size::<N>() - 1).eq(&0)
      && self.check_regional_occupancy(write_index).is_err()
    {
      return Err(task);
    }

    self
      .inner
      .with_buffer_cell(|cell| cell.write(MaybeUninit::new(task)), write_index);

    let base_slot = self
      .inner
      .slot
      .fetch_add(1 << INDEX_SHIFT, Ordering::Relaxed);

    let prev_slot = self.replace_slot(base_slot.wrapping_add(1 << INDEX_SHIFT));

    if ((base_slot ^ prev_slot) & PHASE).eq(&0) {
      Ok(None)
    } else {
      self.occupy_region(write_index);

      let queue = self.inner.local_ref();

      Ok(Some(UnboundedRange::new(base_slot, queue)))
    }
  }

  #[cfg(not(loom))]
  fn background_process<Q>(task: T)
  where
    Q: BackgroundQueue<Task = T> + LocalQueue<N, BufferCell = BufferCell<T>>,
  {
    use crate::task::BackgroundEnqueue;

    Q::queue().with(|queue| match unsafe { queue.push(task) } {
      Ok(Some(assignment)) => {
        spawn(async move {
          Q::batch_process::<N>(assignment).await;
        });
      }
      Ok(None) => {}
      Err(task) => {
        spawn(async move {
          if let Some(assignment) = BackgroundEnqueue::<'_, Q, N>::new(task).await {
            Q::batch_process::<N>(assignment).await;
          }
        });
      }
    });
  }
}

#[cfg(test)]
mod test {
  #[cfg(not(loom))]
  use std::{thread, time::Duration};

  #[cfg(not(loom))]
  use futures::{stream::FuturesUnordered, StreamExt};
  #[cfg(not(loom))]
  use tokio::{
    sync::{oneshot, Barrier},
    task::{spawn, yield_now},
  };

  use crate::{
    assignment::{CompletionReceipt, PendingAssignment},
    local_queue, TaskQueue,
  };
  #[cfg(not(loom))]
  use crate::{queue::UnboundedRange, BackgroundQueue};

  struct EchoQueue;

  #[local_queue(buffer_size = 64)]
  impl TaskQueue for EchoQueue {
    type Task = usize;
    type Value = usize;

    async fn batch_process<const N: usize>(
      batch: PendingAssignment<'_, Self, N>,
    ) -> CompletionReceipt<Self> {
      batch.into_assignment().map(|val| val)
    }
  }

  #[cfg(not(loom))]
  #[cfg_attr(not(loom), tokio::test(crate = "async_local"))]
  async fn it_process_tasks() {
    use rand::{distributions::Standard, prelude::*};
    let mut rng = rand::thread_rng();

    let seed: Vec<usize> = (&mut rng).sample_iter(Standard).take(1 << 16).collect();

    let expected_total: u128 = seed
      .iter()
      .fold(0, |total, val| total.wrapping_add(*val as u128));

    let mut seed = seed.into_iter();
    let mut total: u128 = 0;

    while seed.len().gt(&0) {
      let mut tasks: FuturesUnordered<_> = (&mut seed)
        .take(rng.gen_range(0..1 << 13))
        .map(EchoQueue::auto_batch)
        .collect();

      while let Some(val) = tasks.next().await {
        total = total.wrapping_add(val as u128);
      }
    }

    assert_eq!(total, expected_total);
  }

  #[cfg(not(loom))]
  #[cfg_attr(not(loom), tokio::test(crate = "async_local", flavor = "multi_thread"))]
  async fn it_cycles() {
    for i in 0..1 << 16 {
      EchoQueue::auto_batch(i).await;
    }
  }

  #[cfg(not(loom))]
  struct SlowQueue;

  #[cfg(not(loom))]
  #[local_queue(buffer_size = 64)]
  impl TaskQueue for SlowQueue {
    type Task = usize;
    type Value = usize;

    async fn batch_process<const N: usize>(
      batch: PendingAssignment<'_, Self, N>,
    ) -> CompletionReceipt<Self> {
      batch
        .with_blocking(|batch| {
          let assignment = batch.into_assignment();
          thread::sleep(Duration::from_millis(50));
          assignment.map(|task| task)
        })
        .await
    }
  }

  #[cfg(not(loom))]
  #[cfg_attr(not(loom), tokio::test(crate = "async_local", flavor = "multi_thread"))]
  async fn it_has_drop_safety() {
    let handle = spawn(async {
      SlowQueue::auto_batch(0).await;
    });

    yield_now().await;

    handle.abort();
  }

  #[cfg(not(loom))]
  struct YieldQueue;

  #[cfg(not(loom))]
  #[local_queue(buffer_size = 64)]
  impl TaskQueue for YieldQueue {
    type Task = usize;
    type Value = usize;

    async fn batch_process<const N: usize>(
      batch: PendingAssignment<'_, Self, N>,
    ) -> CompletionReceipt<Self> {
      let assignment = batch.into_assignment();

      yield_now().await;

      assignment.map(|val| val)
    }
  }

  #[cfg(not(loom))]
  #[cfg_attr(not(loom), tokio::test(crate = "async_local", flavor = "multi_thread"))]
  async fn it_negotiates_receiver_drop() {
    use std::sync::Arc;

    let tasks: FuturesUnordered<_> = (0..8192)
      .map(|i| async move {
        let barrier = Arc::new(Barrier::new(2));

        let task_barrier = barrier.clone();

        let handle = tokio::task::spawn(async move {
          task_barrier.wait().await;
          YieldQueue::auto_batch(i).await;
        });

        barrier.wait().await;
        yield_now().await;

        handle.abort()
      })
      .collect();

    tasks.collect::<Vec<_>>().await;
  }

  #[cfg(loom)]
  #[test]
  fn stack_queue_drops() {
    use crate::{BufferCell, StackQueue};

    loom::model(|| {
      let queue: StackQueue<BufferCell<usize>, 64> = StackQueue::default();
      drop(queue);
    });
  }

  #[cfg(loom)]
  #[test]
  fn the_occupancy_model_synchronizes() {
    use loom::{
      sync::{
        atomic::{AtomicUsize, Ordering},
        Arc,
      },
      thread,
    };

    loom::model(|| {
      let occupancy = Arc::new(AtomicUsize::new(0));

      assert_eq!(occupancy.fetch_add(1, Ordering::AcqRel), 0);

      {
        let occupancy = occupancy.clone();
        thread::spawn(move || {
          occupancy.fetch_sub(1, Ordering::Release);
        })
      }
      .join()
      .unwrap();

      assert_eq!(occupancy.load(Ordering::Acquire), 0);
    });
  }

  #[cfg(loom)]
  #[test]
  fn it_manages_occupancy() {
    use crate::{queue::UnboundedRange, BufferCell, StackQueue};

    let expected_total = (0..256).into_iter().sum::<usize>();

    loom::model(move || {
      let queue: StackQueue<BufferCell<usize>, 64> = StackQueue::default();
      let mut batch: Option<UnboundedRange<usize, 64>> = None;
      let mut i = 0;
      let mut total = 0;

      while i < 256 {
        match unsafe { queue.push(i) } {
          Ok(Some(unbounded_batch)) => {
            batch = Some(unbounded_batch);
            i += 1;
          }
          Ok(None) => {
            i += 1;
          }
          Err(_) => {
            if let Some(batch) = batch.take() {
              total += batch.into_bounded().into_iter().sum::<usize>();
            } else {
              panic!("queue full despite buffer unoccupied");
            }
            continue;
          }
        }
      }

      if let Some(batch) = batch.take() {
        total += batch.into_bounded().into_iter().sum::<usize>();
      }

      assert_eq!(total, expected_total);
    });
  }

  #[cfg(loom)]
  #[test]
  fn it_negotiates_receiver_drop() {
    use std::{hint::unreachable_unchecked, ptr::addr_of};

    use futures::pin_mut;
    use futures_test::task::noop_waker;
    use loom::sync::{Arc, Condvar, Mutex};

    use crate::task::{BatchedTask, Receiver, State, TaskRef};

    loom::model(|| {
      let task: Arc<TaskRef<EchoQueue>> = Arc::new(TaskRef::new_uninit());
      let barrier = Arc::new((Mutex::new(false), Condvar::new()));

      let resolver_handle = {
        let task = task.clone();
        let barrier = barrier.clone();

        loom::thread::spawn(move || {
          let (lock, cvar) = &*barrier;
          let mut task_initialized = lock.lock().unwrap();
          while !*task_initialized {
            task_initialized = cvar.wait(task_initialized).unwrap();
          }

          unsafe {
            task.resolve_unchecked(9001);
          }
        })
      };

      let receiver_handle = {
        loom::thread::spawn(move || {
          let waker = noop_waker();

          let rx: Receiver<EchoQueue> = Receiver::new(task.state_ptr(), waker);

          let auto_batched_task: BatchedTask<EchoQueue, 256> = BatchedTask {
            state: State::Batched(rx),
          };

          pin_mut!(auto_batched_task);

          let rx = match &auto_batched_task.state {
            State::Batched(rx) => {
              addr_of!(*rx)
            }
            _ => unsafe { unreachable_unchecked() },
          };

          unsafe {
            task.init(9001, rx);
          };

          let (lock, cvar) = &*barrier;
          let mut task_initialized = lock.lock().unwrap();
          *task_initialized = true;
          cvar.notify_one();

          #[allow(clippy::drop_non_drop)]
          drop(auto_batched_task);
        })
      };

      resolver_handle.join().unwrap();
      receiver_handle.join().unwrap();
    });
  }

  #[cfg(not(loom))]
  struct EchoBackgroundQueue;

  #[cfg(not(loom))]
  #[local_queue]
  impl BackgroundQueue for EchoBackgroundQueue {
    type Task = (usize, oneshot::Sender<usize>);

    async fn batch_process<const N: usize>(tasks: UnboundedRange<'_, Self::Task, N>) {
      let tasks = tasks.into_bounded().into_iter();

      for (val, tx) in tasks {
        tx.send(val).ok();
      }
    }
  }

  #[cfg(not(loom))]
  #[cfg_attr(not(loom), tokio::test(crate = "async_local", flavor = "multi_thread"))]
  async fn it_process_background_tasks() {
    #[allow(clippy::needless_collect)]
    let receivers: Vec<_> = (0..10_usize)
      .map(|i| {
        let (tx, rx) = oneshot::channel::<usize>();
        EchoBackgroundQueue::auto_batch((i, tx));
        rx
      })
      .collect();

    for (i, rx) in receivers.into_iter().enumerate() {
      assert_eq!(rx.await, Ok(i));
    }
  }

  #[cfg(not(loom))]
  #[cfg_attr(not(loom), tokio::test(crate = "async_local", flavor = "multi_thread"))]
  async fn it_batch_reduces() {
    use crate::BatchReducer;

    struct Accumulator;

    #[local_queue]
    impl BatchReducer for Accumulator {
      type Task = usize;
    }

    let tasks: FuturesUnordered<_> = (0..1 << 16)
      .map(|i| Accumulator::batch_reduce(i, |iter| iter.sum::<usize>()))
      .collect();

    let total = tasks
      .fold(0_usize, |total, value| async move {
        total + value.unwrap_or_default()
      })
      .await;

    assert_eq!(total, (0..1 << 16).sum());
  }

  #[cfg(not(loom))]
  #[cfg_attr(not(loom), tokio::test(crate = "async_local", flavor = "multi_thread"))]
  async fn it_batch_collects() {
    use crate::BatchReducer;

    struct Accumulator;

    #[local_queue]
    impl BatchReducer for Accumulator {
      type Task = usize;
    }

    let mut tasks: FuturesUnordered<_> = (0..1 << 16).map(Accumulator::batch_collect).collect();

    let mut total = 0;

    while let Some(batch) = tasks.next().await {
      total += batch.map_or(0, |batch| batch.into_iter().sum::<usize>());
    }

    assert_eq!(total, (0..1 << 16).sum());
  }
}