#[cfg(not(loom))]
use std::{
  cell::UnsafeCell,
  fmt,
  fmt::Debug,
  future::Future,
  hint::unreachable_unchecked,
  mem,
  ops::Deref,
  ptr::addr_of,
  sync::atomic::{AtomicUsize, Ordering},
  task::{Context, Poll},
};
use std::{
  marker::{PhantomData, PhantomPinned},
  mem::{needs_drop, MaybeUninit},
  pin::Pin,
  task::Waker,
};

#[cfg(feature = "diesel-associations")]
use diesel::associations::BelongsTo;
#[cfg(loom)]
use loom::{
  cell::UnsafeCell,
  sync::atomic::{AtomicUsize, Ordering},
};
#[cfg(not(loom))]
use parking_lot_core::SpinWait;
use pin_project::{pin_project, pinned_drop};
#[cfg(feature = "redis-args")]
use redis::{RedisWrite, ToRedisArgs};
#[cfg(not(loom))]
use tokio::task::spawn;

use crate::{
  assignment::{BufferIter, UnboundedRange},
  queue::{LocalQueue, TaskQueue},
  BackgroundQueue, BatchReducer,
};
#[cfg(not(loom))]
use crate::{queue::QueueFull, BufferCell};

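// Bit flags packed into the `AtomicUsize` shared between a `TaskRef` and its
// `Receiver`: a value write is in progress, a value has been written, and the
// receiver has been dropped, respectively.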
const SETTING_VALUE: usize = 1 << 0;
const VALUE_SET: usize = 1 << 1;
const RX_DROPPED: usize = 1 << 2;

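/// A buffer cell for a batched task: it owns the enqueued `T::Task`, a raw
/// pointer back to the [`Receiver`] awaiting the result, and the shared atomic
/// state used to coordinate value resolution with receiver drop.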
pub struct TaskRef<T: TaskQueue> {
  state: UnsafeCell<AtomicUsize>,
  rx: UnsafeCell<MaybeUninit<*const Receiver<T>>>,
  task: UnsafeCell<MaybeUninit<T::Task>>,
}

#[cfg(not(loom))]
impl<T> Deref for TaskRef<T>
where
  T: TaskQueue,
{
  type Target = T::Task;
  fn deref(&self) -> &Self::Target {
    self.task()
  }
}

#[cfg(not(loom))]
impl<T> PartialEq for TaskRef<T>
where
  T: TaskQueue,
  T::Task: PartialEq,
{
  fn eq(&self, other: &Self) -> bool {
    self.task().eq(other.task())
  }
}

#[cfg(not(loom))]
impl<T> PartialOrd for TaskRef<T>
where
  T: TaskQueue,
  T::Task: PartialOrd,
{
  fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
    self.task().partial_cmp(other.task())
  }
}

impl<T> TaskRef<T>
where
  T: TaskQueue,
{
  pub(crate) fn new_uninit() -> Self {
    TaskRef {
      state: UnsafeCell::new(AtomicUsize::new(0)),
      rx: UnsafeCell::new(MaybeUninit::uninit()),
      task: UnsafeCell::new(MaybeUninit::uninit()),
    }
  }
  #[cfg(not(loom))]
  #[inline(always)]
  pub(crate) fn with_state<F, R>(&self, f: F) -> R
  where
    F: FnOnce(*const AtomicUsize) -> R,
  {
    f(self.state.get())
  }

  #[cfg(loom)]
  #[inline(always)]
  pub(crate) fn with_state<F, R>(&self, f: F) -> R
  where
    F: FnOnce(*const AtomicUsize) -> R,
  {
    self.state.get().with(f)
  }

  #[inline(always)]
  pub(crate) fn state_ptr(&self) -> *const AtomicUsize {
    self.with_state(std::convert::identity)
  }

  #[cfg(not(loom))]
  #[inline(always)]
  unsafe fn set_state_unsync(&self, state: usize) {
    *(*self.state.get()).get_mut() = state;
  }

  #[cfg(loom)]
  #[inline(always)]
  unsafe fn set_state_unsync(&self, state: usize) {
    self.state.get_mut().deref().with_mut(|val| *val = state);
  }

  pub(crate) unsafe fn init(&self, task: T::Task, rx: *const Receiver<T>) {
    self.set_state_unsync(0);
    self.with_rx_mut(|val| val.write(MaybeUninit::new(rx)));
    self.with_task_mut(|val| val.write(MaybeUninit::new(task)));
  }

  #[cfg(not(loom))]
  #[inline(always)]
  pub(crate) fn rx(&self) -> &Receiver<T> {
    unsafe { &**(*self.rx.get()).assume_init_ref() }
  }

  #[cfg(loom)]
  #[inline(always)]
  pub(crate) fn rx(&self) -> &Receiver<T> {
    unsafe { &**(*self.rx.get().deref()).assume_init_ref() }
  }

  #[cfg(not(loom))]
  #[inline(always)]
  pub(crate) unsafe fn with_rx_mut<F, R>(&self, f: F) -> R
  where
    F: FnOnce(*mut MaybeUninit<*const Receiver<T>>) -> R,
  {
    f(self.rx.get())
  }

  #[cfg(loom)]
  #[inline(always)]
  pub(crate) unsafe fn with_rx_mut<F, R>(&self, f: F) -> R
  where
    F: FnOnce(*mut MaybeUninit<*const Receiver<T>>) -> R,
  {
    self.rx.get_mut().with(f)
  }

  #[cfg(not(loom))]
  #[inline(always)]
  pub fn task(&self) -> &T::Task {
    unsafe { (*self.task.get()).assume_init_ref() }
  }

  #[cfg(not(loom))]
  #[inline(always)]
  pub(crate) unsafe fn with_task_mut<F, R>(&self, f: F) -> R
  where
    F: FnOnce(*mut MaybeUninit<T::Task>) -> R,
  {
    f(self.task.get())
  }

  #[cfg(loom)]
  #[inline(always)]
  pub(crate) unsafe fn with_task_mut<F, R>(&self, f: F) -> R
  where
    F: FnOnce(*mut MaybeUninit<T::Task>) -> R,
  {
    self.task.get_mut().with(f)
  }

  #[inline(always)]
  pub(crate) unsafe fn take_task_unchecked(&self) -> T::Task {
    self.with_task_mut(|val| std::mem::replace(&mut *val, MaybeUninit::uninit()).assume_init())
  }

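  /// Writes `value` into the paired [`Receiver`] and wakes it, unless the
  /// receiver was already dropped. `SETTING_VALUE` is raised for the duration
  /// of the write, then cleared and `VALUE_SET` raised in a single `fetch_xor`.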
  pub(crate) unsafe fn resolve_unchecked(&self, value: T::Value) {
    let state = self.with_state(|val| (*val).fetch_or(SETTING_VALUE, Ordering::Release));

    if (state & RX_DROPPED).eq(&0) {
      let rx = self.rx();
      rx.with_value_mut(|val| {
        val.write(MaybeUninit::new(value));
      });
      rx.waker.wake_by_ref();
      self.with_state(|val| {
        (*val).fetch_xor(SETTING_VALUE | VALUE_SET, Ordering::Release);
      });
    }
  }
}

#[cfg(not(loom))]
impl<T> Debug for TaskRef<T>
where
  T: TaskQueue,
  <T as TaskQueue>::Task: Debug,
{
  fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
    write!(f, "{:?}", self.task())
  }
}

unsafe impl<T> Send for TaskRef<T> where T: TaskQueue {}
unsafe impl<T> Sync for TaskRef<T> where T: TaskQueue {}

#[cfg_attr(docsrs, doc(cfg(feature = "diesel-associations")))]
#[cfg(feature = "diesel-associations")]
impl<T, Parent> BelongsTo<Parent> for TaskRef<T>
where
  T: TaskQueue,
  T::Task: BelongsTo<Parent>,
{
  type ForeignKey = <T::Task as BelongsTo<Parent>>::ForeignKey;

  type ForeignKeyColumn = <T::Task as BelongsTo<Parent>>::ForeignKeyColumn;

  fn foreign_key(&self) -> Option<&Self::ForeignKey> {
    self.task().foreign_key()
  }

  fn foreign_key_column() -> Self::ForeignKeyColumn {
    <T::Task as BelongsTo<Parent>>::foreign_key_column()
  }
}

#[cfg_attr(docsrs, doc(cfg(feature = "redis-args")))]
#[cfg(feature = "redis-args")]
impl<T> ToRedisArgs for TaskRef<T>
where
  T: TaskQueue,
  T::Task: ToRedisArgs,
{
  fn write_redis_args<W>(&self, out: &mut W)
  where
    W: ?Sized + RedisWrite,
  {
    self.task().write_redis_args(out)
  }
}

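/// The receiving half owned by [`BatchedTask`]: it holds a pointer to the
/// `TaskRef` state, a slot for the resolved value, and the waker to notify
/// once the value has been written.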
#[pin_project]
pub(crate) struct Receiver<T: TaskQueue> {
  state: *const AtomicUsize,
  value: UnsafeCell<MaybeUninit<T::Value>>,
  waker: Waker,
  pin: PhantomPinned,
}

impl<T> Receiver<T>
where
  T: TaskQueue,
{
  pub(crate) fn new(state: *const AtomicUsize, waker: Waker) -> Self {
    Receiver {
      state,
      value: UnsafeCell::new(MaybeUninit::uninit()),
      waker,
      pin: PhantomPinned,
    }
  }

  #[inline(always)]
  fn state(&self) -> &AtomicUsize {
    unsafe { &*self.state }
  }

  #[cfg(not(loom))]
  #[inline(always)]
  unsafe fn with_value_mut<F, R>(&self, f: F) -> R
  where
    F: FnOnce(*mut MaybeUninit<T::Value>) -> R,
  {
    f(self.value.get())
  }

  #[cfg(loom)]
  #[inline(always)]
  unsafe fn with_value_mut<F, R>(&self, f: F) -> R
  where
    F: FnOnce(*mut MaybeUninit<T::Value>) -> R,
  {
    self.value.get_mut().with(f)
  }
}

unsafe impl<T> Send for Receiver<T> where T: TaskQueue {}

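/// Progression of a [`BatchedTask`]: still holding the task before enqueue,
/// waiting on a [`Receiver`] once batched, and done after the value was taken.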
#[pin_project(project = StateProj)]
pub(crate) enum State<T: TaskQueue> {
  Unbatched { task: T::Task },
  Batched(#[pin] Receiver<T>),
  Received,
}

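/// A task that is enqueued into `T`'s thread-local queue on first poll and
/// resolved with a `T::Value` once its batch has been processed.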
#[pin_project(project = TaskProj, PinnedDrop)]
pub struct BatchedTask<T: TaskQueue, const N: usize = 1024> {
  pub(crate) state: State<T>,
}

impl<T, const N: usize> BatchedTask<T, N>
where
  T: TaskQueue,
  T: LocalQueue<N, BufferCell = TaskRef<T>>,
{
  pub fn new(task: T::Task) -> Self {
    BatchedTask {
      state: State::Unbatched { task },
    }
  }
}

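// On the first poll the task is moved into the queue: the state transitions to
// `Batched`, the address of the freshly written `Receiver` is handed to the
// buffer cell, and a batch-processing task is spawned when `enqueue` yields an
// assignment; if the queue is full the waker is parked on `queue.pending`.
// A later poll reads the resolved value exactly once.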
#[cfg(not(loom))]
impl<T, const N: usize> Future for BatchedTask<T, N>
where
  T: TaskQueue,
  T: LocalQueue<N, BufferCell = TaskRef<T>>,
{
  type Output = T::Value;

  fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
    let this = self.as_mut().project();

    match this.state {
      State::Unbatched { task: _ } => {
        T::queue().with(|queue| unsafe {
          let assignment = queue.enqueue(|task_ref| {
            let task = match mem::replace(
              this.state,
              State::Batched(Receiver::new(task_ref.state_ptr(), cx.waker().to_owned())),
            ) {
              State::Unbatched { task } => task,
              _ => unreachable_unchecked(),
            };

            let rx = match this.state {
              State::Batched(batched) => addr_of!(*batched),
              _ => unreachable_unchecked(),
            };

            task_ref.init(task, rx)
          });

          match assignment {
            Ok(Some(assignment)) => {
              spawn(async move {
                T::batch_process::<N>(assignment).await;
              });
            }
            Ok(None) => {}
            Err(QueueFull) => {
              queue.pending.push(cx.waker().to_owned());
            }
          }
        });

        Poll::Pending
      }
      State::Batched(_) => {
        let value = match mem::replace(this.state, State::Received) {
          State::Batched(rx) => unsafe { rx.with_value_mut(|val| (*val).assume_init_read()) },
          _ => unsafe { unreachable_unchecked() },
        };

        Poll::Ready(value)
      }
      State::Received => Poll::Pending,
    }
  }
}

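// Dropping a batched task marks the receiver as dropped, waits out any
// in-flight value write, and then drops the received value if one was set.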
#[cfg(not(loom))]
#[pinned_drop]
impl<T, const N: usize> PinnedDrop for BatchedTask<T, N>
where
  T: TaskQueue,
{
  fn drop(self: Pin<&mut Self>) {
    if let State::Batched(rx) = &self.state {
      let mut state = rx.state().fetch_or(RX_DROPPED, Ordering::AcqRel);
      let mut spin = SpinWait::new();

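      // Wait out a concurrent `resolve_unchecked` that is mid-write before
      // inspecting `VALUE_SET`.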
      while state & SETTING_VALUE == SETTING_VALUE {
        spin.spin();
        state = rx.state().load(Ordering::Acquire);
      }

      if needs_drop::<T::Value>() && (state & VALUE_SET).eq(&VALUE_SET) {
        unsafe {
          rx.with_value_mut(|val| {
            (*val).assume_init_drop();
          });
        }
      }
    }
  }
}

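// Loom equivalent of the drop logic above, yielding to the loom scheduler
// instead of spin-waiting.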
#[cfg(loom)]
#[pinned_drop]
impl<T, const N: usize> PinnedDrop for BatchedTask<T, N>
where
  T: TaskQueue,
{
  fn drop(self: Pin<&mut Self>) {
    if let State::Batched(rx) = &self.state {
      let mut state = rx.state().fetch_or(RX_DROPPED, Ordering::AcqRel);

      while state & SETTING_VALUE == SETTING_VALUE {
        loom::thread::yield_now();
        state = rx.state().load(Ordering::Acquire);
      }

      if needs_drop::<T::Value>() && (state & VALUE_SET).eq(&VALUE_SET) {
        unsafe {
          rx.with_value_mut(|val| {
            (*val).assume_init_drop();
          });
        }
      }
    }
  }
}

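/// Future that moves a task into `T`'s thread-local background queue,
/// resolving to whatever `UnboundedRange` the push yields, if any.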
#[pin_project(project_replace = EnqueueOwn)]
pub(crate) enum BackgroundEnqueue<'a, T: BackgroundQueue, const N: usize> {
  Pending(T::Task, PhantomData<&'a ()>),
  Enqueued,
}

impl<T, const N: usize> BackgroundEnqueue<'_, T, N>
where
  T: BackgroundQueue,
{
  pub(crate) fn new(task: T::Task) -> Self {
    BackgroundEnqueue::Pending(task, PhantomData)
  }
}

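// The pending task is taken by value via `project_replace`; if the queue is
// full the task is put back, the waker is parked on `queue.pending`, and the
// future stays pending.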
#[cfg(not(loom))]
impl<'a, T, const N: usize> Future for BackgroundEnqueue<'a, T, N>
where
  T: BackgroundQueue,
  T: LocalQueue<N, BufferCell = BufferCell<T::Task>>,
{
  type Output = Option<UnboundedRange<'a, T::Task, N>>;
  fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
    match self.as_mut().project_replace(BackgroundEnqueue::Enqueued) {
      EnqueueOwn::Pending(task, _) => T::queue().with(|queue| match unsafe { queue.push(task) } {
        Ok(assignment) => Poll::Ready(assignment),
        Err(task) => {
          queue.pending.push(cx.waker().to_owned());
          self.project_replace(BackgroundEnqueue::Pending(task, PhantomData));
          Poll::Pending
        }
      }),
      EnqueueOwn::Enqueued => Poll::Ready(None),
    }
  }
}

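/// Future that pushes a task into a thread-local batch; the future handed the
/// batch range reduces it with `F`, all others resolve to `None`.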
#[pin_project(project = ReduceProj)]
pub struct BatchReduce<'a, T, F, R, const N: usize>
where
  T: BatchReducer,
  F: for<'b> FnOnce(BufferIter<'b, T::Task, N>) -> R + Send,
{
  state: ReduceState<'a, T, F, R, N>,
  pin: PhantomPinned,
}

impl<T, F, R, const N: usize> BatchReduce<'_, T, F, R, N>
where
  T: BatchReducer,
  F: for<'a> FnOnce(BufferIter<'a, T::Task, N>) -> R + Send,
{
  pub(crate) fn new(task: T::Task, reducer: F) -> Self {
    BatchReduce {
      state: ReduceState::Unbatched { task, reducer },
      pin: PhantomPinned,
    }
  }
}

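// State machine for `BatchReduce`: not yet pushed, holding the batch range to
// reduce, or already handed off / finished.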
enum ReduceState<'a, T, F, R, const N: usize>
where
  T: BatchReducer,
  F: for<'b> FnOnce(BufferIter<'b, T::Task, N>) -> R + Send,
{
  Unbatched {
    task: T::Task,
    reducer: F,
  },
  Collecting {
    batch: UnboundedRange<'a, T::Task, N>,
    reducer: F,
  },
  Batched,
}

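// First poll pushes the task: receiving a batch range moves to `Collecting`
// and schedules an immediate re-poll, a push that yields no range resolves to
// `None`, and a full queue re-arms `Unbatched` and parks the waker. The
// follow-up poll runs the reducer over the bounded batch.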
#[cfg(not(loom))]
impl<T, F, R, const N: usize> Future for BatchReduce<'_, T, F, R, N>
where
  T: BatchReducer,
  T: LocalQueue<N, BufferCell = BufferCell<T::Task>>,
  F: for<'b> FnOnce(BufferIter<'b, T::Task, N>) -> R + Send,
{
  type Output = Option<R>;
  fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
    let this = self.as_mut().project();

    match this.state {
      ReduceState::Unbatched {
        task: _,
        reducer: _,
      } => match mem::replace(this.state, ReduceState::Batched) {
        ReduceState::Unbatched { task, reducer } => {
          T::queue().with(|queue| match unsafe { queue.push(task) } {
            Ok(Some(batch)) => {
              let _ = mem::replace(this.state, ReduceState::Collecting { batch, reducer });
              cx.waker().wake_by_ref();
              Poll::Pending
            }
            Ok(None) => Poll::Ready(None),
            Err(task) => {
              let _ = mem::replace(this.state, ReduceState::Unbatched { task, reducer });
              queue.pending.push(cx.waker().to_owned());
              Poll::Pending
            }
          })
        }
        _ => unsafe {
          unreachable_unchecked();
        },
      },
      ReduceState::Collecting {
        batch: _,
        reducer: _,
      } => match mem::replace(this.state, ReduceState::Batched) {
        ReduceState::Collecting { batch, reducer } => {
          Poll::Ready(Some(reducer(batch.into_bounded().into_iter())))
        }
        _ => unsafe {
          unreachable_unchecked();
        },
      },
      ReduceState::Batched => Poll::Ready(None),
    }
  }
}

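/// Future that pushes a task into a thread-local batch and resolves to the
/// collected tasks as a `Vec` when handed the batch range, or `None` otherwise.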
#[pin_project(project = CollectProj)]
pub struct BatchCollect<'a, T, const N: usize>
where
  T: BatchReducer,
{
  state: CollectState<'a, T, N>,
  pin: PhantomPinned,
}

impl<T, const N: usize> BatchCollect<'_, T, N>
where
  T: BatchReducer,
{
  pub(crate) fn new(task: T::Task) -> Self {
    BatchCollect {
      state: CollectState::Unbatched { task },
      pin: PhantomPinned,
    }
  }
}

enum CollectState<'a, T, const N: usize>
where
  T: BatchReducer,
{
  Unbatched {
    task: T::Task,
  },
  Collecting {
    batch: UnboundedRange<'a, T::Task, N>,
  },
  Batched,
}

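// Mirrors `BatchReduce`: the first poll pushes the task, and a follow-up poll
// drains the batch into a `Vec` via `into_bounded().to_vec()`.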
#[cfg(not(loom))]
impl<T, const N: usize> Future for BatchCollect<'_, T, N>
where
  T: BatchReducer,
  T: LocalQueue<N, BufferCell = BufferCell<T::Task>>,
{
  type Output = Option<Vec<T::Task>>;
  fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
    let this = self.as_mut().project();

    match this.state {
      CollectState::Unbatched { task: _ } => {
        match mem::replace(this.state, CollectState::Batched) {
          CollectState::Unbatched { task } => {
            T::queue().with(|queue| match unsafe { queue.push(task) } {
              Ok(Some(batch)) => {
                let _ = mem::replace(this.state, CollectState::Collecting { batch });
                cx.waker().wake_by_ref();
                Poll::Pending
              }
              Ok(None) => Poll::Ready(None),
              Err(task) => {
                let _ = mem::replace(this.state, CollectState::Unbatched { task });
                queue.pending.push(cx.waker().to_owned());
                Poll::Pending
              }
            })
          }
          _ => unsafe {
            unreachable_unchecked();
          },
        }
      }
      CollectState::Collecting { batch: _ } => {
        match mem::replace(this.state, CollectState::Batched) {
          CollectState::Collecting { batch } => Poll::Ready(Some(batch.into_bounded().to_vec())),
          _ => unsafe {
            unreachable_unchecked();
          },
        }
      }
      CollectState::Batched => Poll::Ready(None),
    }
  }
}