stack_queue/assignment.rs

#[cfg(not(loom))]
use std::sync::atomic::Ordering;
use std::{
  marker::PhantomData,
  mem::{self, needs_drop},
  ops::Range,
};

use async_local::LocalRef;
#[cfg(loom)]
use loom::sync::atomic::Ordering;
#[cfg(not(loom))]
use tokio::task::{spawn_blocking, JoinHandle};

use crate::{
  queue::{Inner, TaskQueue, INDEX_SHIFT, PHASE},
  task::TaskRef,
  BufferCell,
};

/// The responsibility to process a yet-to-be-assigned set of tasks.
pub struct PendingAssignment<'a, T: TaskQueue, const N: usize> {
  base_slot: usize,
  queue: LocalRef<'a, Inner<TaskRef<T>, N>>,
}

impl<'a, T, const N: usize> PendingAssignment<'a, T, N>
where
  T: TaskQueue,
{
  pub(crate) fn new(base_slot: usize, queue: LocalRef<'a, Inner<TaskRef<T>, N>>) -> Self {
    PendingAssignment { base_slot, queue }
  }

  #[inline(always)]
  fn set_assignment_bounds(&self) -> Range<usize> {
    let end_slot = self.queue.slot.fetch_xor(PHASE, Ordering::Relaxed);

    (self.base_slot >> INDEX_SHIFT)..(end_slot >> INDEX_SHIFT)
  }

  /// By converting into a [`TaskAssignment`], the task range to be processed becomes bounded and
  /// any tasks enqueued afterwards belong to a new batch. Assignment of a task range can be
  /// deferred until resources such as database connections are ready, as a way to process tasks
  /// in larger batches. This operation is constant-time and wait-free.
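  ///
  /// A minimal usage sketch; `acquire_resource` and `process` are illustrative helpers rather
  /// than part of this crate:
  ///
  /// ```ignore
  /// // Defer bounding the batch until a shared resource is ready.
  /// let resource = acquire_resource().await;
  /// let assignment = batch.into_assignment();
  /// assignment.map(|task| process(&resource, task))
  /// ```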
  pub fn into_assignment(self) -> TaskAssignment<'a, T, N> {
    let task_range = self.set_assignment_bounds();
    let queue = self.queue;

    mem::forget(self);

    TaskAssignment::new(task_range, queue)
  }

  /// Move [`PendingAssignment`] into a thread where blocking is acceptable.
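  ///
  /// A usage sketch; `blocking_lookup` stands in for any blocking `Task -> Value` conversion:
  ///
  /// ```ignore
  /// batch
  ///   .with_blocking(|batch| batch.into_assignment().map(|task| blocking_lookup(task)))
  ///   .await
  /// ```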
  #[cfg(not(loom))]
  pub async fn with_blocking<F>(self, f: F) -> CompletionReceipt<T>
  where
    F: for<'b> FnOnce(PendingAssignment<'b, T, N>) -> CompletionReceipt<T> + Send + 'static,
  {
    let batch: PendingAssignment<'_, T, N> = unsafe { std::mem::transmute(self) };
    spawn_blocking(move || f(batch)).await.unwrap()
  }
}

unsafe impl<T, const N: usize> Send for PendingAssignment<'_, T, N> where T: TaskQueue {}
unsafe impl<T, const N: usize> Sync for PendingAssignment<'_, T, N> where T: TaskQueue {}

impl<T, const N: usize> Drop for PendingAssignment<'_, T, N>
where
  T: TaskQueue,
{
  fn drop(&mut self) {
    let task_range = self.set_assignment_bounds();
    let queue = self.queue;

    TaskAssignment::new(task_range, queue);
  }
}

/// Assignment of a task range yet to be processed
pub struct TaskAssignment<'a, T: TaskQueue, const N: usize> {
  task_range: Range<usize>,
  queue: LocalRef<'a, Inner<TaskRef<T>, N>>,
}

impl<'a, T, const N: usize> TaskAssignment<'a, T, N>
where
  T: TaskQueue,
{
  fn new(task_range: Range<usize>, queue: LocalRef<'a, Inner<TaskRef<T>, N>>) -> Self {
    TaskAssignment { task_range, queue }
  }

  /// Returns a pair of slices which contain, in order, the contents of the assigned task range.
  pub fn as_slices(&self) -> (&[TaskRef<T>], &[TaskRef<T>]) {
    let start = self.task_range.start & (N - 1);
    let end = self.task_range.end & (N - 1);

    if end > start {
      unsafe { (self.queue.buffer.get_unchecked(start..end), &[]) }
    } else {
      unsafe {
        (
          self.queue.buffer.get_unchecked(start..N),
          self.queue.buffer.get_unchecked(0..end),
        )
      }
    }
  }

  /// An iterator over the assigned task range
  pub fn tasks(&self) -> impl Iterator<Item = &TaskRef<T>> {
    let tasks = self.as_slices();
    tasks.0.iter().chain(tasks.1.iter())
  }

  /// Resolve the task assignment with an iterator whose items align, by index, with the
  /// assigned tasks.
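  ///
  /// A usage sketch; `bulk_lookup` stands in for any call that yields one value per task, in
  /// the same order:
  ///
  /// ```ignore
  /// let values = bulk_lookup(assignment.tasks()).await;
  /// assignment.resolve_with_iter(values)
  /// ```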
  pub fn resolve_with_iter<I>(self, iter: I) -> CompletionReceipt<T>
  where
    I: IntoIterator<Item = T::Value>,
  {
    self.tasks().zip(iter).for_each(|(task_ref, value)| unsafe {
      if needs_drop::<T::Task>() {
        drop(task_ref.take_task_unchecked());
      }

      task_ref.resolve_unchecked(value);
    });

    self.into_completion_receipt()
  }

  /// Resolve the task assignment by mapping each task into its respective value.
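  ///
  /// A usage sketch; `parse` stands in for any synchronous `Task -> Value` conversion:
  ///
  /// ```ignore
  /// assignment.map(|task| parse(task))
  /// ```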
  pub fn map<F>(self, op: F) -> CompletionReceipt<T>
  where
    F: Fn(T::Task) -> T::Value + Sync,
  {
    self.tasks().for_each(|task_ref| unsafe {
      let task = task_ref.take_task_unchecked();
      task_ref.resolve_unchecked(op(task));
    });

    self.into_completion_receipt()
  }

  #[inline(always)]
  fn deoccupy_buffer(&self) {
    self.queue.deoccupy_region(self.task_range.start & (N - 1));
  }

  fn into_completion_receipt(self) -> CompletionReceipt<T> {
    self.deoccupy_buffer();

    mem::forget(self);

    CompletionReceipt::new()
  }

  /// Move [`TaskAssignment`] into a thread where blocking is acceptable.
  #[cfg(not(loom))]
  pub async fn with_blocking<F>(self, f: F) -> CompletionReceipt<T>
  where
    F: for<'b> FnOnce(TaskAssignment<'b, T, N>) -> CompletionReceipt<T> + Send + 'static,
  {
    let batch: TaskAssignment<'_, T, N> = unsafe { std::mem::transmute(self) };
    spawn_blocking(move || f(batch)).await.unwrap()
  }
}

impl<T, const N: usize> Drop for TaskAssignment<'_, T, N>
where
  T: TaskQueue,
{
  fn drop(&mut self) {
    if needs_drop::<T::Task>() {
      self
        .tasks()
        .for_each(|task_ref| unsafe { drop(task_ref.take_task_unchecked()) });
    }

    self.deoccupy_buffer();
  }
}

unsafe impl<T, const N: usize> Send for TaskAssignment<'_, T, N> where T: TaskQueue {}
unsafe impl<T, const N: usize> Sync for TaskAssignment<'_, T, N> where T: TaskQueue {}

/// A type-state proof of completion for a task assignment
pub struct CompletionReceipt<T: TaskQueue>(PhantomData<T>);

impl<T> CompletionReceipt<T>
where
  T: TaskQueue,
{
  fn new() -> Self {
    CompletionReceipt(PhantomData)
  }
}

/// A guard granting exclusive access over an unbounded range of a [`StackQueue`](crate::StackQueue)
/// buffer
pub struct UnboundedRange<'a, T: Send + Sync + Sized + 'static, const N: usize> {
  base_slot: usize,
  queue: LocalRef<'a, Inner<BufferCell<T>, N>>,
}

impl<'a, T, const N: usize> UnboundedRange<'a, T, N>
where
  T: Send + Sync + Sized + 'static,
{
  pub(crate) fn new(base_slot: usize, queue: LocalRef<'a, Inner<BufferCell<T>, N>>) -> Self {
    UnboundedRange { base_slot, queue }
  }

  #[inline(always)]
  fn set_bounds(&self) -> Range<usize> {
    let end_slot = self.queue.slot.fetch_xor(PHASE, Ordering::Relaxed);
    (self.base_slot >> INDEX_SHIFT)..(end_slot >> INDEX_SHIFT)
  }

  /// Establish exclusive access over a [`StackQueue`](crate::StackQueue) buffer range
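  ///
  /// A usage sketch; the buffer region is released once the resulting [`BoundedRange`] (or a
  /// `Vec` or iterator derived from it) is consumed or dropped:
  ///
  /// ```ignore
  /// let items = range.into_bounded().to_vec();
  /// ```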
  pub fn into_bounded(self) -> BoundedRange<'a, T, N> {
    let range = self.set_bounds();
    let queue = self.queue;

    mem::forget(self);

    BoundedRange::new(range, queue)
  }

  /// Move [`UnboundedRange`] into a thread where blocking is acceptable.
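  ///
  /// A usage sketch; the closure runs on a blocking-friendly thread and the returned
  /// [`JoinHandle`] resolves to its result:
  ///
  /// ```ignore
  /// let items = range
  ///   .with_blocking(|range| range.into_bounded().to_vec())
  ///   .await
  ///   .unwrap();
  /// ```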
  #[cfg(not(loom))]
  pub fn with_blocking<F, R>(self, f: F) -> JoinHandle<R>
  where
    F: for<'b> FnOnce(UnboundedRange<'b, T, N>) -> R + Send + 'static,
    R: Send + 'static,
  {
    let batch: UnboundedRange<'_, T, N> = unsafe { std::mem::transmute(self) };
    spawn_blocking(move || f(batch))
  }
}

impl<T, const N: usize> Drop for UnboundedRange<'_, T, N>
where
  T: Send + Sync + Sized + 'static,
{
  fn drop(&mut self) {
    let task_range = self.set_bounds();
    let start_index = task_range.start & (N - 1);

    let queue = self.queue;

    if needs_drop::<T>() {
      for index in task_range {
        unsafe {
          queue.with_buffer_cell(|cell| (*cell).assume_init_drop(), index & (N - 1));
        }
      }
    }

    self.queue.deoccupy_region(start_index);
  }
}

unsafe impl<T, const N: usize> Send for UnboundedRange<'_, T, N> where
  T: Send + Sync + Sized + 'static
{
}
unsafe impl<T, const N: usize> Sync for UnboundedRange<'_, T, N> where
  T: Send + Sync + Sized + 'static
{
}

/// A guard granting exclusive access over a bounded range of a [`StackQueue`](crate::StackQueue)
/// buffer
pub struct BoundedRange<'a, T: Send + Sync + Sized + 'static, const N: usize> {
  range: Range<usize>,
  queue: LocalRef<'a, Inner<BufferCell<T>, N>>,
}

impl<'a, T, const N: usize> BoundedRange<'a, T, N>
where
  T: Send + Sync + Sized + 'static,
{
  fn new(range: Range<usize>, queue: LocalRef<'a, Inner<BufferCell<T>, N>>) -> Self {
    BoundedRange { range, queue }
  }

  /// Returns a pair of slices which contain, in order, the contents of the guarded task range.
  #[cfg(not(loom))]
  pub fn as_slices(&self) -> (&[T], &[T]) {
    let start = self.range.start & (N - 1);
    let end = self.range.end & (N - 1);

    if end > start {
      unsafe {
        mem::transmute::<(&[BufferCell<T>], &[BufferCell<T>]), _>((
          self.queue.buffer.get_unchecked(start..end),
          &[],
        ))
      }
    } else {
      unsafe {
        mem::transmute((
          self.queue.buffer.get_unchecked(start..N),
          self.queue.buffer.get_unchecked(0..end),
        ))
      }
    }
  }

  /// An iterator over the guarded task range
  #[cfg(not(loom))]
  pub fn iter(&self) -> impl Iterator<Item = &T> {
    let tasks = self.as_slices();
    tasks.0.iter().chain(tasks.1.iter())
  }

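  /// Move the contents of the guarded range into a [`Vec`], releasing the buffer region.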
  #[cfg(not(loom))]
  pub fn to_vec(self) -> Vec<T> {
    let items = self.as_slices();
    let front_len = items.0.len();
    let back_len = items.1.len();
    let total_len = front_len + back_len;
    let mut buffer = Vec::new();
    buffer.reserve_exact(total_len);

    unsafe {
      std::ptr::copy_nonoverlapping(items.0.as_ptr(), buffer.as_mut_ptr(), front_len);
      if back_len > 0 {
        std::ptr::copy_nonoverlapping(
          items.1.as_ptr(),
          buffer.as_mut_ptr().add(front_len),
          back_len,
        );
      }
      buffer.set_len(total_len);
    }

    self.deoccupy_buffer();

    mem::forget(self);

    buffer
  }

  /// Move [`BoundedRange`] into a thread where blocking is acceptable.
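  ///
  /// A usage sketch; `write_to_disk` is an illustrative blocking operation:
  ///
  /// ```ignore
  /// bounded
  ///   .with_blocking(|range| {
  ///     for item in range {
  ///       write_to_disk(item);
  ///     }
  ///   })
  ///   .await
  ///   .unwrap();
  /// ```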
  #[cfg(not(loom))]
  pub fn with_blocking<F, R>(self, f: F) -> JoinHandle<R>
  where
    F: for<'b> FnOnce(BoundedRange<'b, T, N>) -> R + Send + 'static,
    R: Send + 'static,
  {
    let batch: BoundedRange<'_, T, N> = unsafe { std::mem::transmute(self) };
    batch.queue.with_blocking(move |_| f(batch))
  }

  #[inline(always)]
  fn deoccupy_buffer(&self) {
    self.queue.deoccupy_region(self.range.start & (N - 1));
  }
}

impl<T, const N: usize> Drop for BoundedRange<'_, T, N>
where
  T: Send + Sync + Sized + 'static,
{
  fn drop(&mut self) {
    if needs_drop::<T>() {
      for index in self.range.clone() {
        unsafe {
          self
            .queue
            .with_buffer_cell(|cell| (*cell).assume_init_drop(), index & (N - 1));
        }
      }
    }

    self.deoccupy_buffer();
  }
}

unsafe impl<T, const N: usize> Send for BoundedRange<'_, T, N> where T: Send + Sync + Sized + 'static
{}
unsafe impl<T, const N: usize> Sync for BoundedRange<'_, T, N> where T: Send + Sync + Sized + 'static
{}

/// An iterator over a guarded range of tasks from a [`StackQueue`](crate::StackQueue) buffer
pub struct BufferIter<'a, T: Send + Sync + Sized + 'static, const N: usize> {
  current: usize,
  range: Range<usize>,
  queue: LocalRef<'a, Inner<BufferCell<T>, N>>,
}

impl<T, const N: usize> BufferIter<'_, T, N>
where
  T: Send + Sync + Sized + 'static,
{
  fn deoccupy_buffer(&self) {
    self.queue.deoccupy_region(self.range.start & (N - 1));
  }
}

impl<T, const N: usize> Iterator for BufferIter<'_, T, N>
where
  T: Send + Sync + Sized + 'static,
{
  type Item = T;

  fn next(&mut self) -> Option<Self::Item> {
    if self.current < self.range.end {
      let task = unsafe {
        self
          .queue
          .with_buffer_cell(|cell| (*cell).assume_init_read(), self.current & (N - 1))
      };

      self.current += 1;

      Some(task)
    } else {
      None
    }
  }
}

unsafe impl<T, const N: usize> Send for BufferIter<'_, T, N> where T: Send + Sync + Sized + 'static {}
unsafe impl<T, const N: usize> Sync for BufferIter<'_, T, N> where T: Send + Sync + Sized + 'static {}

impl<T, const N: usize> Drop for BufferIter<'_, T, N>
where
  T: Send + Sync + Sized + 'static,
{
  fn drop(&mut self) {
    if needs_drop::<T>() {
      while self.next().is_some() {}
    }
    self.deoccupy_buffer();
  }
}

impl<'a, T, const N: usize> IntoIterator for BoundedRange<'a, T, N>
where
  T: Send + Sync + Sized + 'static,
{
  type Item = T;
  type IntoIter = BufferIter<'a, T, N>;

  fn into_iter(self) -> Self::IntoIter {
    let iter = BufferIter {
      current: self.range.start,
      range: self.range.clone(),
      queue: self.queue,
    };

    mem::forget(self);

    iter
  }
}