1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
use crate::sync::primitive::{Arc, AtomicUsize, AtomicWaker, Ordering};
use core::{
future::Future,
pin::Pin,
task::{Context, Poll},
};
use crossbeam_utils::CachePadded;
/// Creates a connected worker channel and returns its `(Sender, Receiver)` halves.
///
/// Both halves share a freshly defaulted `State`. The returned `Sender` is the one handle
/// that `State::default` accounts for. The `Receiver` starts out holding no credits.
pub fn channel() -> (Sender, Receiver) {
    let shared = Arc::new(State::default());
    (
        Sender(Arc::clone(&shared)),
        Receiver {
            state: shared,
            credits: 0,
        },
    )
}
/// The receiving half of the worker channel, held by the worker.
///
/// The worker waits on this handle until there is work to do.
pub struct Receiver {
    /// Jobs taken from the senders that have not yet been passed to `finish`.
    credits: usize,
    /// State shared with the `Sender` handles created alongside this receiver.
    state: Arc<State>,
}
impl Receiver {
    /// Acquires work to be processed for the Receiver
    ///
    /// Resolves to `Some(n)`, where `n > 0` is the number of jobs this receiver currently
    /// holds. Credits acquired earlier but not yet passed to [`Receiver::finish`] are counted
    /// again.
    ///
    /// `None` is returned when there are no more active Senders and no credits are left.
    #[inline]
    pub async fn acquire(&mut self) -> Option<usize> {
        Acquire(self).await
    }

    /// Polls work to be processed for the receiver
    ///
    /// Returns:
    /// - `Poll::Ready(Some(n))` when `n > 0` credits are held.
    /// - `Poll::Ready(None)` when no active Senders remain and no credits are left.
    /// - `Poll::Pending` otherwise. The waker in `cx` is registered, so a later
    ///   `Sender::submit` or Sender drop wakes the task.
    #[inline]
    pub fn poll_acquire(&mut self, cx: &mut Context) -> Poll<Option<usize>> {
        let state = &*self.state;

        macro_rules! acquire {
            () => {{
                // take the credits that we've been given by the senders
                self.credits += state.remaining.swap(0, Ordering::Acquire);

                // if we have any credits then return
                if self.credits > 0 {
                    return Poll::Ready(Some(self.credits));
                }
            }};
        }

        // first try to acquire credits
        acquire!();

        // if we didn't get any credits then register the waker
        state.receiver.register(cx.waker());

        // Sample the sender count *before* the final credit check. A sender may `submit` and
        // then drop between a credit check and this load. If the count were read after the
        // last check, we could observe zero senders and return `None` while those credits sat
        // unclaimed in `remaining`.
        //
        // Each Sender decrements the count with a Release `fetch_sub` after its Release
        // `fetch_add`s to `remaining`. Observing zero with this Acquire load therefore makes
        // those additions visible to the swap below.
        let senders_closed = state.senders.load(Ordering::Acquire) == 0;

        // make one last effort to acquire credits in case a sender submitted some while we were
        // registering the waker (or right before the last sender dropped)
        acquire!();

        // No credits left and nobody can submit more: we're done
        if senders_closed {
            return Poll::Ready(None);
        }

        // Any later `submit` or Sender drop calls `wake` on the waker registered above
        Poll::Pending
    }

    /// Marks `count` jobs as finished
    ///
    /// `count` must not exceed the credits currently held. This is checked with a
    /// `debug_assert!` in debug builds.
    #[inline]
    pub fn finish(&mut self, count: usize) {
        debug_assert!(self.credits >= count);

        // decrement the number of credits we have
        self.credits -= count;
    }
}
/// A handle to submit work to be done to a worker receiver
///
/// Multiple Sender handles can be created with `.clone()`.
#[derive(Clone)]
pub struct Sender(Arc<State>);
impl Sender {
/// Submits `count` jobs to be executed by the worker receiver
#[inline]
pub fn submit(&self, count: usize) {
let state = &*self.0;
// increment the work counter
state.remaining.fetch_add(count, Ordering::Release);
// wake up the receiver if possible
state.receiver.wake();
}
}
impl Drop for Sender {
#[inline]
fn drop(&mut self) {
let state = &*self.0;
state.senders.fetch_sub(1, Ordering::Release);
// wake up the receiver to notify that one of the senders has dropped
state.receiver.wake();
}
}
/// State shared between the Sender handles and the Receiver of one worker channel.
struct State {
    /// Credits submitted by senders that the receiver has not taken yet.
    remaining: CachePadded<AtomicUsize>,
    /// Waker registered by the receiver. It is woken on every submit and every Sender drop.
    receiver: AtomicWaker,
    /// Sender count. It starts at one, and the receiver treats zero as "no more senders".
    senders: CachePadded<AtomicUsize>,
}
impl Default for State {
fn default() -> Self {
Self {
remaining: Default::default(),
receiver: Default::default(),
senders: AtomicUsize::new(1).into(),
}
}
}
/// Future backing `Receiver::acquire`. Each poll delegates to `Receiver::poll_acquire`.
struct Acquire<'a>(&'a mut Receiver);

impl Future for Acquire<'_> {
    type Output = Option<usize>;

    #[inline]
    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
        // `Acquire` only holds a `&mut Receiver`, so it is `Unpin` and can be unwrapped freely
        let receiver = &mut *self.get_mut().0;
        receiver.poll_acquire(cx)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::testing::loom;

    /// Runs one thread that submits `iterations` batches of `send_batch_size` jobs, alongside
    /// one thread that acquires and finishes them in chunks of at most `recv_batch_size`. It
    /// then asserts that the finished total equals the submitted total.
    fn loom_scenario(iterations: usize, send_batch_size: usize, recv_batch_size: usize) {
        assert_ne!(send_batch_size, 0);
        assert_ne!(recv_batch_size, 0);

        loom::model(move || {
            let (send, mut recv) = channel();

            let producer = loom::thread::spawn(move || {
                (0..iterations).for_each(|_| {
                    send.submit(send_batch_size);
                    loom::hint::spin_loop();
                });
            });

            let consumer = loom::thread::spawn(move || {
                loom::future::block_on(async move {
                    let mut finished = 0;

                    while let Some(available) = recv.acquire().await {
                        assert_ne!(available, 0);

                        let mut left = available;
                        while left != 0 {
                            let chunk = left.min(recv_batch_size);
                            recv.finish(chunk);
                            finished += chunk;
                            left -= chunk;
                        }
                    }

                    assert_eq!(finished, iterations * send_batch_size);
                })
            });

            // loom tests will still run after returning so we don't need to join
            if !cfg!(loom) {
                producer.join().unwrap();
                consumer.join().unwrap();
            }
        });
    }

    /// Async loom tests seem to spin forever if the number of iterations is higher than 1.
    /// Ideally, this value would be a bit bigger to test more permutations of orderings.
    const ITERATIONS: usize = if cfg!(loom) { 1 } else { 100 };
    const SEND_BATCH_SIZE: usize = if cfg!(loom) { 2 } else { 8 };
    const RECV_BATCH_SIZE: usize = if cfg!(loom) { 2 } else { 8 };

    #[test]
    fn loom_no_items() {
        loom_scenario(0, 1, 1);
    }

    #[test]
    fn loom_single_item() {
        loom_scenario(ITERATIONS, 1, 1);
    }

    #[test]
    fn loom_send_batch() {
        loom_scenario(ITERATIONS, SEND_BATCH_SIZE, 1);
    }

    #[test]
    fn loom_recv_batch() {
        loom_scenario(ITERATIONS, 1, RECV_BATCH_SIZE);
    }

    #[test]
    fn loom_both_batch() {
        loom_scenario(ITERATIONS, SEND_BATCH_SIZE, RECV_BATCH_SIZE);
    }
}