stack_queue/
task.rs

1#[cfg(not(loom))]
2use std::{
3  cell::UnsafeCell,
4  fmt,
5  fmt::Debug,
6  future::Future,
7  hint::unreachable_unchecked,
8  mem,
9  ops::Deref,
10  ptr::addr_of,
11  sync::atomic::{AtomicUsize, Ordering},
12  task::{Context, Poll},
13};
14use std::{
15  marker::{PhantomData, PhantomPinned},
16  mem::{needs_drop, MaybeUninit},
17  pin::Pin,
18  task::Waker,
19};
20
21#[cfg(feature = "diesel-associations")]
22use diesel::associations::BelongsTo;
23#[cfg(loom)]
24use loom::{
25  cell::UnsafeCell,
26  sync::atomic::{AtomicUsize, Ordering},
27};
28#[cfg(not(loom))]
29use parking_lot_core::SpinWait;
30use pin_project::{pin_project, pinned_drop};
31#[cfg(feature = "redis-args")]
32use redis::{RedisWrite, ToRedisArgs};
33#[cfg(not(loom))]
34use tokio::task::spawn;
35
36use crate::{
37  assignment::{BufferIter, UnboundedRange},
38  queue::{LocalQueue, TaskQueue},
39  BackgroundQueue, BatchReducer,
40};
41#[cfg(not(loom))]
42use crate::{queue::QueueFull, BufferCell};
43
/// Flag bit in a [`TaskRef`] state word: set by `TaskRef::resolve_unchecked` before it writes the
/// value into the receiver, and cleared again once the write has finished
const SETTING_VALUE: usize = 0b001;
/// Flag bit in a [`TaskRef`] state word: toggled on by `TaskRef::resolve_unchecked` after the value
/// has been written into the receiver
const VALUE_SET: usize = 0b010;
/// Flag bit in a [`TaskRef`] state word: set when a [`BatchedTask`] is dropped while still waiting
/// on its receiver
const RX_DROPPED: usize = 0b100;
47
48/// A pointer to the pinned receiver of an enqueued [`BatchedTask`]
49pub struct TaskRef<T: TaskQueue> {
50  state: UnsafeCell<AtomicUsize>,
51  rx: UnsafeCell<MaybeUninit<*const Receiver<T>>>,
52  task: UnsafeCell<MaybeUninit<T::Task>>,
53}
54
55#[cfg(not(loom))]
56impl<T> Deref for TaskRef<T>
57where
58  T: TaskQueue,
59{
60  type Target = T::Task;
61  fn deref(&self) -> &Self::Target {
62    self.task()
63  }
64}
65
66#[cfg(not(loom))]
67impl<T> PartialEq for TaskRef<T>
68where
69  T: TaskQueue,
70  T::Task: PartialEq,
71{
72  fn eq(&self, other: &Self) -> bool {
73    self.task().eq(other.task())
74  }
75}
76
77#[cfg(not(loom))]
78impl<T> PartialOrd for TaskRef<T>
79where
80  T: TaskQueue,
81  T::Task: PartialOrd,
82{
83  fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
84    self.task().partial_cmp(other.task())
85  }
86}
87
88impl<T> TaskRef<T>
89where
90  T: TaskQueue,
91{
92  pub(crate) fn new_uninit() -> Self {
93    TaskRef {
94      state: UnsafeCell::new(AtomicUsize::new(0)),
95      rx: UnsafeCell::new(MaybeUninit::uninit()),
96      task: UnsafeCell::new(MaybeUninit::uninit()),
97    }
98  }
99  #[cfg(not(loom))]
100  #[inline(always)]
101  pub(crate) fn with_state<F, R>(&self, f: F) -> R
102  where
103    F: FnOnce(*const AtomicUsize) -> R,
104  {
105    f(self.state.get())
106  }
107
108  #[cfg(loom)]
109  #[inline(always)]
110  pub(crate) fn with_state<F, R>(&self, f: F) -> R
111  where
112    F: FnOnce(*const AtomicUsize) -> R,
113  {
114    self.state.get().with(f)
115  }
116
117  #[inline(always)]
118  pub(crate) fn state_ptr(&self) -> *const AtomicUsize {
119    self.with_state(std::convert::identity)
120  }
121
122  #[cfg(not(loom))]
123  #[inline(always)]
124  unsafe fn set_state_unsync(&self, state: usize) {
125    *(*self.state.get()).get_mut() = state;
126  }
127
128  #[cfg(loom)]
129  #[inline(always)]
130  unsafe fn set_state_unsync(&self, state: usize) {
131    self.state.get_mut().deref().with_mut(|val| *val = state);
132  }
133
134  pub(crate) unsafe fn init(&self, task: T::Task, rx: *const Receiver<T>) {
135    self.set_state_unsync(0);
136    self.with_rx_mut(|val| val.write(MaybeUninit::new(rx)));
137    self.with_task_mut(|val| val.write(MaybeUninit::new(task)));
138  }
139
140  #[cfg(not(loom))]
141  #[inline(always)]
142  pub(crate) fn rx(&self) -> &Receiver<T> {
143    unsafe { &**(*self.rx.get()).assume_init_ref() }
144  }
145
146  #[cfg(loom)]
147  #[inline(always)]
148  pub(crate) fn rx(&self) -> &Receiver<T> {
149    unsafe { &**(*self.rx.get().deref()).assume_init_ref() }
150  }
151
152  #[cfg(not(loom))]
153  #[inline(always)]
154  pub(crate) unsafe fn with_rx_mut<F, R>(&self, f: F) -> R
155  where
156    F: FnOnce(*mut MaybeUninit<*const Receiver<T>>) -> R,
157  {
158    f(self.rx.get())
159  }
160
161  #[cfg(loom)]
162  #[inline(always)]
163  pub(crate) unsafe fn with_rx_mut<F, R>(&self, f: F) -> R
164  where
165    F: FnOnce(*mut MaybeUninit<*const Receiver<T>>) -> R,
166  {
167    self.rx.get_mut().with(f)
168  }
169
170  #[cfg(not(loom))]
171  #[inline(always)]
172  pub fn task(&self) -> &T::Task {
173    unsafe { (*self.task.get()).assume_init_ref() }
174  }
175
176  #[cfg(not(loom))]
177  #[inline(always)]
178  pub(crate) unsafe fn with_task_mut<F, R>(&self, f: F) -> R
179  where
180    F: FnOnce(*mut MaybeUninit<T::Task>) -> R,
181  {
182    f(self.task.get())
183  }
184
185  #[cfg(loom)]
186  #[inline(always)]
187  pub(crate) unsafe fn with_task_mut<F, R>(&self, f: F) -> R
188  where
189    F: FnOnce(*mut MaybeUninit<T::Task>) -> R,
190  {
191    self.task.get_mut().with(f)
192  }
193
194  #[inline(always)]
195  pub(crate) unsafe fn take_task_unchecked(&self) -> T::Task {
196    self.with_task_mut(|val| std::mem::replace(&mut *val, MaybeUninit::uninit()).assume_init())
197  }
198
199  /// Set value in receiver and wake if the receiver isn't already dropped. This takes &self because
200  /// [`TaskRef`] by design is never dropped
201  pub(crate) unsafe fn resolve_unchecked(&self, value: T::Value) {
202    let state = self.with_state(|val| (*val).fetch_or(SETTING_VALUE, Ordering::Release));
203
204    if (state & RX_DROPPED).eq(&0) {
205      let rx = self.rx();
206      rx.with_value_mut(|val| {
207        val.write(MaybeUninit::new(value));
208      });
209      rx.waker.wake_by_ref();
210      self.with_state(|val| {
211        (*val).fetch_xor(SETTING_VALUE | VALUE_SET, Ordering::Release);
212      });
213    }
214  }
215}
216
217#[cfg(not(loom))]
218impl<T> Debug for TaskRef<T>
219where
220  T: TaskQueue,
221  <T as TaskQueue>::Task: Debug,
222{
223  fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
224    write!(f, "{:?}", self.task())
225  }
226}
227
228unsafe impl<T> Send for TaskRef<T> where T: TaskQueue {}
229unsafe impl<T> Sync for TaskRef<T> where T: TaskQueue {}
230
#[cfg_attr(docsrs, doc(cfg(feature = "diesel-associations")))]
#[cfg(feature = "diesel-associations")]
impl<T: TaskQueue, Parent> BelongsTo<Parent> for TaskRef<T>
where
  T::Task: BelongsTo<Parent>,
{
  type ForeignKey = <T::Task as BelongsTo<Parent>>::ForeignKey;

  type ForeignKeyColumn = <T::Task as BelongsTo<Parent>>::ForeignKeyColumn;

  /// Forwards to the stored task's foreign key
  fn foreign_key(&self) -> Option<&Self::ForeignKey> {
    <T::Task as BelongsTo<Parent>>::foreign_key(self.task())
  }

  /// Forwards to the task type's foreign key column
  fn foreign_key_column() -> Self::ForeignKeyColumn {
    <T::Task as BelongsTo<Parent>>::foreign_key_column()
  }
}
250
#[cfg_attr(docsrs, doc(cfg(feature = "redis-args")))]
#[cfg(feature = "redis-args")]
impl<T: TaskQueue> ToRedisArgs for TaskRef<T>
where
  T::Task: ToRedisArgs,
{
  /// Writes the stored task's redis arguments to `out`
  fn write_redis_args<W>(&self, out: &mut W)
  where
    W: ?Sized + RedisWrite,
  {
    <T::Task as ToRedisArgs>::write_redis_args(self.task(), out)
  }
}
265
266#[pin_project]
267pub(crate) struct Receiver<T: TaskQueue> {
268  state: *const AtomicUsize,
269  value: UnsafeCell<MaybeUninit<T::Value>>,
270  waker: Waker,
271  pin: PhantomPinned,
272}
273
274impl<T> Receiver<T>
275where
276  T: TaskQueue,
277{
278  pub(crate) fn new(state: *const AtomicUsize, waker: Waker) -> Self {
279    Receiver {
280      state,
281      value: UnsafeCell::new(MaybeUninit::uninit()),
282      waker,
283      pin: PhantomPinned,
284    }
285  }
286
287  #[inline(always)]
288  fn state(&self) -> &AtomicUsize {
289    unsafe { &*self.state }
290  }
291
292  #[cfg(not(loom))]
293  #[inline(always)]
294  unsafe fn with_value_mut<F, R>(&self, f: F) -> R
295  where
296    F: FnOnce(*mut MaybeUninit<T::Value>) -> R,
297  {
298    f(self.value.get())
299  }
300
301  #[cfg(loom)]
302  #[inline(always)]
303  unsafe fn with_value_mut<F, R>(&self, f: F) -> R
304  where
305    F: FnOnce(*mut MaybeUninit<T::Value>) -> R,
306  {
307    self.value.get_mut().with(f)
308  }
309}
310
311unsafe impl<T> Send for Receiver<T> where T: TaskQueue {}
312
313#[pin_project(project = StateProj)]
314pub(crate) enum State<T: TaskQueue> {
315  Unbatched { task: T::Task },
316  Batched(#[pin] Receiver<T>),
317  Received,
318}
319
320/// An automatically batched task
321#[pin_project(project = TaskProj, PinnedDrop)]
322pub struct BatchedTask<T: TaskQueue, const N: usize = 1024> {
323  pub(crate) state: State<T>,
324}
325
326impl<T, const N: usize> BatchedTask<T, N>
327where
328  T: TaskQueue,
329  T: LocalQueue<N, BufferCell = TaskRef<T>>,
330{
331  /// Create a new auto batched task
332  pub fn new(task: T::Task) -> Self {
333    BatchedTask {
334      state: State::Unbatched { task },
335    }
336  }
337}
338
339#[cfg(not(loom))]
340impl<T, const N: usize> Future for BatchedTask<T, N>
341where
342  T: TaskQueue,
343  T: LocalQueue<N, BufferCell = TaskRef<T>>,
344{
345  type Output = T::Value;
346
347  fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
348    let this = self.as_mut().project();
349
350    match this.state {
351      State::Unbatched { task: _ } => {
352        T::queue().with(|queue| unsafe {
353          let assignment = queue.enqueue(|task_ref| {
354            let task = match mem::replace(
355              this.state,
356              State::Batched(Receiver::new(task_ref.state_ptr(), cx.waker().to_owned())),
357            ) {
358              State::Unbatched { task } => task,
359              _ => unreachable_unchecked(),
360            };
361
362            let rx = match this.state {
363              State::Batched(batched) => addr_of!(*batched),
364              _ => unreachable_unchecked(),
365            };
366
367            task_ref.init(task, rx)
368          });
369
370          match assignment {
371            Ok(Some(assignment)) => {
372              spawn(async move {
373                T::batch_process::<N>(assignment).await;
374              });
375            }
376            Ok(None) => {}
377            Err(QueueFull) => {
378              queue.pending.push(cx.waker().to_owned());
379            }
380          }
381        });
382
383        Poll::Pending
384      }
385      State::Batched(_) => {
386        let value = match mem::replace(this.state, State::Received) {
387          State::Batched(rx) => unsafe { rx.with_value_mut(|val| (*val).assume_init_read()) },
388          _ => unsafe { unreachable_unchecked() },
389        };
390
391        Poll::Ready(value)
392      }
393      // If already received, block forever. See https://doc.rust-lang.org/std/future/trait.Future.html#panics
394      State::Received => Poll::Pending,
395    }
396  }
397}
398
399#[cfg(not(loom))]
400#[pinned_drop]
401impl<T, const N: usize> PinnedDrop for BatchedTask<T, N>
402where
403  T: TaskQueue,
404{
405  fn drop(self: Pin<&mut Self>) {
406    if let State::Batched(rx) = &self.state {
407      let mut state = rx.state().fetch_or(RX_DROPPED, Ordering::AcqRel);
408      let mut spin = SpinWait::new();
409
410      // This cannot be safely deallocated until after the value is set
411      while state & SETTING_VALUE == SETTING_VALUE {
412        spin.spin();
413        state = rx.state().load(Ordering::Acquire);
414      }
415
416      if needs_drop::<T::Task>() && (state & VALUE_SET).eq(&VALUE_SET) {
417        unsafe {
418          rx.with_value_mut(|val| {
419            (*val).assume_init_drop();
420          });
421        }
422      }
423    }
424  }
425}
426
#[cfg(loom)]
#[pinned_drop]
impl<T, const N: usize> PinnedDrop for BatchedTask<T, N>
where
  T: TaskQueue,
{
  /// Loom build of the drop logic: detaches a still-waiting receiver from its [`TaskRef`] and
  /// drops a value that was delivered but never returned from `poll`
  ///
  /// Nothing happens unless the task is in the `Batched` state.
  fn drop(self: Pin<&mut Self>) {
    if let State::Batched(rx) = &self.state {
      // Tell the resolver the receiver is going away; the previous bits show what it was doing
      let mut state = rx.state().fetch_or(RX_DROPPED, Ordering::AcqRel);

      // This cannot be safely deallocated until after the value is set
      while state & SETTING_VALUE == SETTING_VALUE {
        loom::thread::yield_now();
        state = rx.state().load(Ordering::Acquire);
      }

      // The slot holds a `T::Value`, so that is the type whose drop glue decides whether there
      // is anything to run (the guard previously asked about `T::Task`, leaking such values
      // whenever the task type itself was trivially droppable)
      if needs_drop::<T::Value>() && state & VALUE_SET == VALUE_SET {
        unsafe {
          rx.with_value_mut(|val| {
            (*val).assume_init_drop();
          });
        }
      }
    }
  }
}
453
454#[pin_project(project_replace = EnqueueOwn)]
455pub(crate) enum BackgroundEnqueue<'a, T: BackgroundQueue, const N: usize> {
456  Pending(T::Task, PhantomData<&'a ()>),
457  Enqueued,
458}
459
460impl<T, const N: usize> BackgroundEnqueue<'_, T, N>
461where
462  T: BackgroundQueue,
463{
464  pub(crate) fn new(task: T::Task) -> Self {
465    BackgroundEnqueue::Pending(task, PhantomData)
466  }
467}
468
469#[cfg(not(loom))]
470impl<'a, T, const N: usize> Future for BackgroundEnqueue<'a, T, N>
471where
472  T: BackgroundQueue,
473  T: LocalQueue<N, BufferCell = BufferCell<T::Task>>,
474{
475  type Output = Option<UnboundedRange<'a, T::Task, N>>;
476  fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
477    match self.as_mut().project_replace(BackgroundEnqueue::Enqueued) {
478      EnqueueOwn::Pending(task, _) => T::queue().with(|queue| match unsafe { queue.push(task) } {
479        Ok(assignment) => Poll::Ready(assignment),
480        Err(task) => {
481          queue.pending.push(cx.waker().to_owned());
482          self.project_replace(BackgroundEnqueue::Pending(task, PhantomData));
483          Poll::Pending
484        }
485      }),
486      EnqueueOwn::Enqueued => Poll::Ready(None),
487    }
488  }
489}
490#[pin_project(project = ReduceProj)]
491pub struct BatchReduce<'a, T, F, R, const N: usize>
492where
493  T: BatchReducer,
494  F: for<'b> FnOnce(BufferIter<'b, T::Task, N>) -> R + Send,
495{
496  state: ReduceState<'a, T, F, R, N>,
497  pin: PhantomPinned,
498}
499
500impl<T, F, R, const N: usize> BatchReduce<'_, T, F, R, N>
501where
502  T: BatchReducer,
503  F: for<'a> FnOnce(BufferIter<'a, T::Task, N>) -> R + Send,
504{
505  pub(crate) fn new(task: T::Task, reducer: F) -> Self {
506    BatchReduce {
507      state: ReduceState::Unbatched { task, reducer },
508      pin: PhantomPinned,
509    }
510  }
511}
512
513enum ReduceState<'a, T, F, R, const N: usize>
514where
515  T: BatchReducer,
516  F: for<'b> FnOnce(BufferIter<'b, T::Task, N>) -> R + Send,
517{
518  Unbatched {
519    task: T::Task,
520    reducer: F,
521  },
522  Collecting {
523    batch: UnboundedRange<'a, T::Task, N>,
524    reducer: F,
525  },
526  Batched,
527}
528
529#[cfg(not(loom))]
530impl<T, F, R, const N: usize> Future for BatchReduce<'_, T, F, R, N>
531where
532  T: BatchReducer,
533  T: LocalQueue<N, BufferCell = BufferCell<T::Task>>,
534  F: for<'b> FnOnce(BufferIter<'b, T::Task, N>) -> R + Send,
535{
536  type Output = Option<R>;
537  fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
538    let this = self.as_mut().project();
539
540    match this.state {
541      ReduceState::Unbatched {
542        task: _,
543        reducer: _,
544      } => match mem::replace(this.state, ReduceState::Batched) {
545        ReduceState::Unbatched { task, reducer } => {
546          T::queue().with(|queue| match unsafe { queue.push(task) } {
547            Ok(Some(batch)) => {
548              let _ = mem::replace(this.state, ReduceState::Collecting { batch, reducer });
549              cx.waker().wake_by_ref();
550              Poll::Pending
551            }
552            Ok(None) => Poll::Ready(None),
553            Err(task) => {
554              let _ = mem::replace(this.state, ReduceState::Unbatched { task, reducer });
555              queue.pending.push(cx.waker().to_owned());
556              Poll::Pending
557            }
558          })
559        }
560        _ => unsafe {
561          unreachable_unchecked();
562        },
563      },
564      ReduceState::Collecting {
565        batch: _,
566        reducer: _,
567      } => match mem::replace(this.state, ReduceState::Batched) {
568        ReduceState::Collecting { batch, reducer } => {
569          Poll::Ready(Some(reducer(batch.into_bounded().into_iter())))
570        }
571        _ => unsafe {
572          unreachable_unchecked();
573        },
574      },
575      ReduceState::Batched => Poll::Ready(None),
576    }
577  }
578}
579
580#[pin_project(project = CollectProj)]
581pub struct BatchCollect<'a, T, const N: usize>
582where
583  T: BatchReducer,
584{
585  state: CollectState<'a, T, N>,
586  pin: PhantomPinned,
587}
588
589impl<T, const N: usize> BatchCollect<'_, T, N>
590where
591  T: BatchReducer,
592{
593  pub(crate) fn new(task: T::Task) -> Self {
594    BatchCollect {
595      state: CollectState::Unbatched { task },
596      pin: PhantomPinned,
597    }
598  }
599}
600enum CollectState<'a, T, const N: usize>
601where
602  T: BatchReducer,
603{
604  Unbatched {
605    task: T::Task,
606  },
607  Collecting {
608    batch: UnboundedRange<'a, T::Task, N>,
609  },
610  Batched,
611}
612
613#[cfg(not(loom))]
614impl<T, const N: usize> Future for BatchCollect<'_, T, N>
615where
616  T: BatchReducer,
617  T: LocalQueue<N, BufferCell = BufferCell<T::Task>>,
618{
619  type Output = Option<Vec<T::Task>>;
620  fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
621    let this = self.as_mut().project();
622
623    match this.state {
624      CollectState::Unbatched { task: _ } => {
625        match mem::replace(this.state, CollectState::Batched) {
626          CollectState::Unbatched { task } => {
627            T::queue().with(|queue| match unsafe { queue.push(task) } {
628              Ok(Some(batch)) => {
629                let _ = mem::replace(this.state, CollectState::Collecting { batch });
630                cx.waker().wake_by_ref();
631                Poll::Pending
632              }
633              Ok(None) => Poll::Ready(None),
634              Err(task) => {
635                let _ = mem::replace(this.state, CollectState::Unbatched { task });
636                queue.pending.push(cx.waker().to_owned());
637                Poll::Pending
638              }
639            })
640          }
641          _ => unsafe {
642            unreachable_unchecked();
643          },
644        }
645      }
646      CollectState::Collecting { batch: _ } => {
647        match mem::replace(this.state, CollectState::Batched) {
648          CollectState::Collecting { batch } => Poll::Ready(Some(batch.into_bounded().to_vec())),
649          _ => unsafe {
650            unreachable_unchecked();
651          },
652        }
653      }
654      CollectState::Batched => Poll::Ready(None),
655    }
656  }
657}