use crossbeam_epoch::{self as epoch, Atomic, Owned};
use crossbeam_utils::CachePadded;

use futures::executor::block_on;
use std::{cell::Cell, fmt, marker::PhantomData, mem, ptr, sync::Arc};
use tokio::sync::oneshot::{channel, Receiver, Sender};

#[cfg(loom)]
use loom::sync::atomic::{AtomicUsize, Ordering};

#[cfg(not(loom))]
use std::sync::atomic::{AtomicUsize, Ordering};

// Bit 0 of the packed slot value designates which of the two buffers is
// currently active; stealers flip it with `fetch_xor` to claim a batch.
const BUFFER_IDX: usize = 1 << 0;

// Bit 1 is set while a push is between its "start write" and "finish write"
// counter increments.
const WRITE_IN_PROGRESS: usize = 1 << 1;

// Number of bits reserved for flags; writers bump the counter by
// `1 << FLAGS_SHIFT` at the start and at the end of every push.
const FLAGS_SHIFT: usize = 1;

// Shifting the packed value right by `LENGTH_SHIFT` yields the number of
// completed pushes.
const LENGTH_SHIFT: usize = FLAGS_SHIFT + 1;

// Minimum buffer capacity.
const MIN_CAP: usize = 64;

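// Layout of the packed `slot` word, as implied by the constants above:
//
//   bit 0    - active buffer index
//   bit 1    - write in progress
//   bits 2.. - completed-push count
//
// For example: each push performs two `fetch_add(1 << FLAGS_SHIFT)`
// increments, adding 4 in total, so after three completed pushes the counter
// holds 12 and `12 >> LENGTH_SHIFT == 3`.
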
// A fixed-capacity ring buffer owned on the worker side.
struct Buffer<T> {
  // Packed slot value at the time this buffer was installed; the base offset
  // for computing indices.
  slot: usize,

  // Pointer to the heap allocation backing the buffer.
  ptr: *mut T,

  // Capacity of the buffer; always a power of two.
  cap: usize,
}

unsafe impl<T: Send> Send for Buffer<T> {}
unsafe impl<T: Send> Sync for Buffer<T> {}

impl<T> Buffer<T> {
  // Allocates a new buffer with the specified capacity, which must be a power
  // of two.
  fn alloc(slot: usize, cap: usize) -> Buffer<T> {
    debug_assert_eq!(cap, cap.next_power_of_two());

    let mut v = Vec::with_capacity(cap);
    let ptr = v.as_mut_ptr();
    mem::forget(v);

    Buffer { slot, ptr, cap }
  }

  // Deallocates the buffer without running destructors on its contents.
  unsafe fn dealloc(self) {
    drop(Vec::from_raw_parts(self.ptr, 0, self.cap));
  }

  // Returns a pointer to the slot at `index`; `cap` is a power of two, so
  // masking with `cap - 1` wraps the index.
  unsafe fn at(&self, index: usize) -> *mut T {
    self.ptr.add(index & (self.cap - 1))
  }

  // Writes `task` into the slot at `index`.
  unsafe fn write(&self, index: usize, task: T) {
    ptr::write_volatile(self.at(index), task)
  }

  // Converts the buffer into a `Vec` that takes ownership of the allocation
  // and of its first `length` initialized items.
  unsafe fn to_vec(self, length: usize) -> Vec<T> {
    let Buffer { ptr, cap, .. } = self;
    Vec::from_raw_parts(ptr, length, cap)
  }
}

impl<T> Clone for Buffer<T> {
  fn clone(&self) -> Buffer<T> {
    Buffer {
      slot: self.slot,
      ptr: self.ptr,
      cap: self.cap,
    }
  }
}

impl<T> Copy for Buffer<T> {}

// Computes the number of completed pushes between the packed slot values `a`
// and `b`, accounting for wrap-around of the counter.
fn slot_delta(a: usize, b: usize) -> usize {
  if a < b {
    ((usize::MAX - b) >> LENGTH_SHIFT) + (a >> LENGTH_SHIFT)
  } else {
    (a >> LENGTH_SHIFT) - (b >> LENGTH_SHIFT)
  }
}

// State shared between a `Worker` and its `Stealer`s.
struct Inner<T> {
  // Packed counter: buffer index (bit 0), write-in-progress flag (bit 1), and
  // completed-push count (the remaining bits).
  slot: AtomicUsize,
  // The two buffers that the worker and its stealers swap between.
  buffers: (
    CachePadded<Atomic<Buffer<T>>>,
    CachePadded<Atomic<Buffer<T>>>,
  ),
}

impl<T> Inner<T> {
  // Selects the buffer designated by the buffer-index bit of `slot`.
  fn get_buffer(&self, slot: usize) -> &CachePadded<Atomic<Buffer<T>>> {
    if slot & BUFFER_IDX == 0 {
      &self.buffers.0
    } else {
      &self.buffers.1
    }
  }
}

// Queue flavor: grow without bound, or hand off fixed-size batches as they
// fill.
enum Flavor {
  Unbounded,
  AutoBatched { batch_size: usize },
}

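/// A single-producer queue whose contents can only be taken as a whole batch.
/// The push that starts a batch returns a [`Stealer`] that takes ownership of
/// that entire batch.
///
/// A minimal usage sketch (the crate path `swap_queue` is assumed here):
///
/// ```ignore
/// use swap_queue::Worker;
///
/// let queue = Worker::new();
///
/// // The push that starts a batch hands back the batch's stealer.
/// let stealer = queue.push(1).unwrap();
/// queue.push(2);
/// queue.push(3);
///
/// assert_eq!(stealer.take_blocking(), vec![1, 2, 3]);
/// ```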
pub struct Worker<T> {
  // Whether the queue is unbounded or auto-batched.
  flavor: Flavor,
  // State shared with stealers.
  inner: Arc<CachePadded<Inner<T>>>,
  // The buffer the worker is currently writing into.
  buffer: Cell<Buffer<T>>,
  // Sender half of the handoff channel for the current batch.
  tx: Cell<Option<Sender<Vec<T>>>>,
  // `*mut ()` keeps `Worker` `!Sync`: it is a single-producer handle.
  _marker: PhantomData<*mut ()>,
}

unsafe impl<T: Send> Send for Worker<T> {}

impl<T> Worker<T> {
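  /// Creates a new unbounded `Worker`: the active buffer doubles in capacity
  /// whenever it fills.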
  pub fn new() -> Worker<T> {
    let buffer = Buffer::alloc(0, MIN_CAP);

    let inner = Arc::new(CachePadded::new(Inner {
      slot: AtomicUsize::new(0),
      buffers: (
        CachePadded::new(Atomic::new(buffer)),
        CachePadded::new(Atomic::null()),
      ),
    }));

    Worker {
      flavor: Flavor::Unbounded,
      inner,
      buffer: Cell::new(buffer),
      tx: Cell::new(None),
      _marker: PhantomData,
    }
  }

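  /// Creates a `Worker` that automatically hands off a batch every
  /// `batch_size` pushes. `batch_size` must be a power of two no smaller than
  /// 64 (checked only by `debug_assert!`).
  ///
  /// A sketch of the intended use (crate path `swap_queue` assumed):
  ///
  /// ```ignore
  /// use swap_queue::Worker;
  ///
  /// let queue = Worker::auto_batched(64);
  ///
  /// // The push that begins or completes a batch returns its stealer.
  /// let stealers: Vec<_> = (0..128).filter_map(|i| queue.push(i)).collect();
  /// ```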
  pub fn auto_batched(batch_size: usize) -> Worker<T> {
    debug_assert!(batch_size >= 64, "batch_size must be at least 64");
    debug_assert_eq!(
      batch_size,
      batch_size.next_power_of_two(),
      "batch_size must be a power of 2"
    );

    let buffer = Buffer::alloc(0, MIN_CAP);

    let inner = Arc::new(CachePadded::new(Inner {
      slot: AtomicUsize::new(0),
      buffers: (
        CachePadded::new(Atomic::new(buffer)),
        CachePadded::new(Atomic::null()),
      ),
    }));

    Worker {
      flavor: Flavor::AutoBatched { batch_size },
      inner,
      buffer: Cell::new(buffer),
      tx: Cell::new(None),
      _marker: PhantomData,
    }
  }

  // Doubles the capacity of the current buffer, copying the existing tasks
  // into the new allocation before publishing it.
  unsafe fn resize(&self, buffer: &mut Buffer<T>, slot: usize) {
    let length = slot_delta(slot, buffer.slot);

    let new = Buffer::alloc(buffer.slot, buffer.cap * 2);

    ptr::copy_nonoverlapping(buffer.at(0), new.at(0), length);

    self.buffer.set(new);

    let old = mem::replace(buffer, new);

    self
      .inner
      .get_buffer(slot)
      .store(Owned::new(new), Ordering::Release);

    old.dealloc();
  }

  // Publishes a fresh buffer of capacity `cap` and returns the old one.
  fn replace_buffer(&self, buffer: &mut Buffer<T>, slot: usize, cap: usize) -> Buffer<T> {
    let new = Buffer::alloc(slot, cap);

    self
      .inner
      .get_buffer(slot)
      .store(Owned::new(new), Ordering::Release);

    self.buffer.set(new);

    mem::replace(buffer, new)
  }

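  /// Writes a task to the queue. Returns `Some(Stealer)` when this push
  /// starts a new batch (or, for auto-batched queues, when it completes one);
  /// the stealer takes ownership of that entire batch.
  ///
  /// A sketch (crate path `swap_queue` assumed):
  ///
  /// ```ignore
  /// use swap_queue::Worker;
  ///
  /// let queue = Worker::new();
  ///
  /// let stealer = queue.push(0).unwrap();
  /// assert!(queue.push(1).is_none()); // the batch already has a stealer
  /// ```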
  pub fn push(&self, task: T) -> Option<Stealer<T>> {
    // Start the write; this sets the write-in-progress bit.
    let slot = self
      .inner
      .slot
      .fetch_add(1 << FLAGS_SHIFT, Ordering::Relaxed);

    let mut buffer = self.buffer.get();

    // The buffer-index bit changed, so a stealer has taken the previous
    // buffer; start a new batch in a fresh buffer.
    if (slot ^ buffer.slot) & BUFFER_IDX == BUFFER_IDX {
      buffer = Buffer::alloc(slot, buffer.cap);

      self
        .inner
        .get_buffer(slot)
        .store(Owned::new(buffer), Ordering::Release);

      self.buffer.set(buffer);

      unsafe {
        buffer.write(0, task);
      }

      // Finish the write; this clears the write-in-progress bit.
      self
        .inner
        .slot
        .fetch_add(1 << FLAGS_SHIFT, Ordering::Relaxed);

      let (tx, rx) = channel();
      self.tx.set(Some(tx));

      Some(Stealer::Taker(StealHandle {
        rx,
        inner: self.inner.clone(),
      }))
    } else {
      let index = slot_delta(slot, buffer.slot);

      match &self.flavor {
        Flavor::Unbounded if index == buffer.cap => {
          // The buffer is full; double its capacity before writing.
          unsafe {
            self.resize(&mut buffer, slot);
            buffer.write(index, task);
          }

          // Finish the write.
          let slot = self
            .inner
            .slot
            .fetch_add(1 << FLAGS_SHIFT, Ordering::Relaxed);

          if (slot ^ buffer.slot) & BUFFER_IDX == BUFFER_IDX {
            // A stealer claimed the buffer mid-write and is awaiting the
            // handoff; send it the batch.
            let (tx, rx) = channel();
            let tx = self.tx.replace(Some(tx)).unwrap();

            tx.send(unsafe { buffer.to_vec(index) }).ok();

            Some(Stealer::Taker(StealHandle {
              rx,
              inner: self.inner.clone(),
            }))
          } else {
            None
          }
        }
        Flavor::AutoBatched { batch_size } if index == *batch_size => {
          // The batch is full; swap in a fresh buffer and take the old one as
          // a completed batch.
          let old = self.replace_buffer(&mut buffer, slot, *batch_size);
          let batch = unsafe { old.to_vec(*batch_size) };

          unsafe {
            buffer.write(0, task);
          }

          // Finish the write.
          let slot = self
            .inner
            .slot
            .fetch_add(1 << FLAGS_SHIFT, Ordering::Relaxed);

          if (slot ^ buffer.slot) & BUFFER_IDX == BUFFER_IDX {
            // A stealer claimed the buffer mid-write; send it the completed
            // batch and return a stealer for the batch that just began.
            let (tx, rx) = channel();
            let tx = self.tx.replace(Some(tx)).unwrap();

            tx.send(batch).ok();

            Some(Stealer::Taker(StealHandle {
              rx,
              inner: self.inner.clone(),
            }))
          } else {
            // No steal in progress: hand the completed batch back directly.
            Some(Stealer::Owner(batch))
          }
        }
        _ if index == 0 => {
          // First write into this buffer: create the handoff channel and
          // return the stealer for the new batch.
          unsafe {
            buffer.write(0, task);
          }

          // Finish the write.
          self
            .inner
            .slot
            .fetch_add(1 << FLAGS_SHIFT, Ordering::Relaxed);

          let (tx, rx) = channel();
          self.tx.set(Some(tx));

          Some(Stealer::Taker(StealHandle {
            rx,
            inner: self.inner.clone(),
          }))
        }
        _ => {
          unsafe {
            buffer.write(index, task);
          }

          // Finish the write.
          let slot = self
            .inner
            .slot
            .fetch_add(1 << FLAGS_SHIFT, Ordering::Relaxed);

          if (slot ^ buffer.slot) & BUFFER_IDX == BUFFER_IDX {
            // A stealer claimed the buffer mid-write and is awaiting the
            // handoff; send it the batch.
            let (tx, rx) = channel();
            let tx = self.tx.replace(Some(tx)).unwrap();

            tx.send(unsafe { buffer.to_vec(index) }).ok();

            Some(Stealer::Taker(StealHandle {
              rx,
              inner: self.inner.clone(),
            }))
          } else {
            None
          }
        }
      }
    }
  }
}

impl<T> Default for Worker<T> {
  fn default() -> Self {
    Self::new()
  }
}

impl<T> fmt::Debug for Worker<T> {
  fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
    f.pad("Worker { .. }")
  }
}

impl<T> Drop for Worker<T> {
  fn drop(&mut self) {
    // Setting the write-in-progress bit makes any concurrent stealer wait on
    // the handoff channel instead of reading the buffer directly.
    let slot = self
      .inner
      .slot
      .fetch_add(1 << FLAGS_SHIFT, Ordering::Relaxed);

    let buffer = self.buffer.get();

    // If the buffer-index bits match, the current buffer has not been claimed
    // by a stealer, so the worker still owns it.
    if slot & BUFFER_IDX == buffer.slot & BUFFER_IDX {
      let length = slot_delta(slot, buffer.slot);

      if let Some(tx) = self.tx.take() {
        // A stealer exists for this batch: send it the remaining tasks.
        if let Err(queue) = tx.send(unsafe { buffer.to_vec(length) }) {
          drop(queue);
        }
      } else {
        // No stealer: drop the tasks and free the buffer.
        unsafe {
          for i in 0..length {
            buffer.at(i).drop_in_place();
          }

          buffer.dealloc();
        }
      }
    }
  }
}

#[doc(hidden)]
pub struct StealHandle<T> {
  // Receiver for batches handed off while a write was in progress.
  rx: Receiver<Vec<T>>,
  // State shared with the worker.
  inner: Arc<CachePadded<Inner<T>>>,
}

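/// Handle for taking a batch of tasks from a [`Worker`], either owned
/// outright (`Owner`) or resolved against the shared queue state (`Taker`).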
pub enum Stealer<T> {
  // The batch was already completed, so it is owned outright.
  Owner(Vec<T>),
  // The batch must be taken from the shared state, or received from the
  // worker if a write was in progress.
  Taker(StealHandle<T>),
}

unsafe impl<T: Send> Send for Stealer<T> {}
unsafe impl<T: Send> Sync for Stealer<T> {}

impl<T> Stealer<T> {
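  /// Takes the batch, waiting for an in-progress write to finish if one was
  /// underway when the buffer was claimed.
  ///
  /// A sketch (crate path `swap_queue` assumed):
  ///
  /// ```ignore
  /// use swap_queue::Worker;
  ///
  /// async fn demo() {
  ///   let queue = Worker::new();
  ///   let stealer = queue.push(1).unwrap();
  ///   queue.push(2);
  ///
  ///   assert_eq!(stealer.take().await, vec![1, 2]);
  /// }
  /// ```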
  pub async fn take(self) -> Vec<T> {
    match self {
      Stealer::Owner(batch) => batch,
      Stealer::Taker(StealHandle { rx, inner }) => {
        // Flip the buffer-index bit to claim the current buffer.
        let slot = inner.slot.fetch_xor(BUFFER_IDX, Ordering::Relaxed);

        if slot & WRITE_IN_PROGRESS == WRITE_IN_PROGRESS {
          // A write was in flight: wait for the worker to send the batch.
          rx.await.unwrap()
        } else {
          let guard = &epoch::pin();

          let buffer = inner.get_buffer(slot).load_consume(guard);

          unsafe {
            // `Buffer` is `Copy`: the epoch-owned wrapper is freed here while
            // the copy keeps the underlying allocation.
            let buffer = *buffer.into_owned();
            buffer.to_vec(slot_delta(slot, buffer.slot))
          }
        }
      }
    }
  }

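  /// Blocking equivalent of [`take`](Stealer::take). This blocks via
  /// `futures::executor::block_on`, so it should not be called from within an
  /// async context.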
  pub fn take_blocking(self) -> Vec<T> {
    match self {
      Stealer::Owner(batch) => batch,
      Stealer::Taker(StealHandle { rx, inner }) => {
        // Flip the buffer-index bit to claim the current buffer.
        let slot = inner.slot.fetch_xor(BUFFER_IDX, Ordering::Relaxed);

        if slot & WRITE_IN_PROGRESS == WRITE_IN_PROGRESS {
          // A write was in flight: block until the worker sends the batch.
          block_on(rx).unwrap()
        } else {
          let guard = &epoch::pin();

          let buffer = inner.get_buffer(slot).load_consume(guard);

          unsafe {
            let buffer = *buffer.into_owned();
            buffer.to_vec(slot_delta(slot, buffer.slot))
          }
        }
      }
    }
  }
}

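/// Resolves the stealer by blocking; equivalent to calling `take_blocking`.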
impl<T> From<Stealer<T>> for Vec<T> {
  fn from(stealer: Stealer<T>) -> Self {
    stealer.take_blocking()
  }
}

impl<T> fmt::Debug for Stealer<T> {
  fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
    f.pad("Stealer { .. }")
  }
}

#[cfg(test)]
mod tests {
  use super::*;

  #[cfg(loom)]
  use loom::thread;

  #[cfg(not(loom))]
  use std::thread;

  // Runs the test body under `loom::model` when checking with loom, and
  // directly otherwise.
  macro_rules! model {
    ($test:block) => {
      #[cfg(loom)]
      loom::model(|| $test);

      #[cfg(not(loom))]
      $test
    };
  }

  #[test]
  fn slot_wraps_around() {
    // A delta computed across `usize::MAX` should wrap to one completed push.
    let delta = slot_delta(1 << LENGTH_SHIFT, usize::MAX);

    assert_eq!(delta, 1);
  }

  #[test]
  fn it_resizes() {
    model!({
      let queue = Worker::new();
      let stealer = queue.push(0).unwrap();

      for i in 1..128 {
        queue.push(i);
      }

      let batch = stealer.take_blocking();
      let expected = (0..128).collect::<Vec<i32>>();

      assert_eq!(batch, expected);
    });
  }

  #[test]
  fn it_makes_new_stealer_per_batch() {
    model!({
      let queue = Worker::new();
      let stealer = queue.push(0).unwrap();

      queue.push(1);
      queue.push(2);

      assert_eq!(stealer.take_blocking(), vec![0, 1, 2]);

      let stealer = queue.push(3).unwrap();
      queue.push(4);
      queue.push(5);

      assert_eq!(stealer.take_blocking(), vec![3, 4, 5]);
    });
  }

  #[test]
  fn it_auto_batches() {
    model!({
      let queue = Worker::auto_batched(64);
      let mut stealers: Vec<Stealer<i32>> = vec![];

      for i in 0..128 {
        if let Some(stealer) = queue.push(i) {
          stealers.push(stealer);
        }
      }

      // The first stealer takes the current (second) batch when resolved,
      // while the second is an `Owner` of the first batch, so reverse the
      // order to restore push order.
      let batch: Vec<i32> = stealers
        .into_iter()
        .rev()
        .flat_map(|stealer| stealer.take_blocking())
        .collect();

      let expected = (0..128).collect::<Vec<i32>>();

      assert_eq!(batch, expected);
    });
  }

  #[cfg(not(loom))]
  #[tokio::test]
  async fn stealer_takes() {
    let queue = Worker::new();
    let stealer = queue.push(0).unwrap();

    for i in 1..1024 {
      queue.push(i);
    }

    let batch = stealer.take().await;
    let expected = (0..1024).collect::<Vec<i32>>();

    assert_eq!(batch, expected);
  }

  #[test]
  fn stealer_takes_blocking() {
    model!({
      let queue = Worker::new();
      let stealer = queue.push(0).unwrap();

      for i in 1..128 {
        queue.push(i);
      }

      thread::spawn(move || {
        stealer.take_blocking();
      })
      .join()
      .unwrap();
    });
  }

  #[cfg(not(loom))]
  #[tokio::test]
  async fn worker_drops() {
    let queue = Worker::new();
    let stealer = queue.push(0).unwrap();

    for i in 1..128 {
      queue.push(i);
    }

    drop(queue);

    let batch = stealer.take().await;
    let expected = (0..128).collect::<Vec<i32>>();

    assert_eq!(batch, expected);
  }

  #[cfg(loom)]
  #[test]
  fn worker_drops() {
    loom::model(|| {
      let queue = Worker::new();
      let stealer = queue.push(0).unwrap();

      for i in 1..128 {
        queue.push(i);
      }

      drop(queue);

      let batch = stealer.take_blocking();
      let expected = (0..128).collect::<Vec<i32>>();

      assert_eq!(batch, expected);
    });
  }
}