1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0

use crate::sync::primitive::{Arc, AtomicUsize, AtomicWaker, Ordering};
use core::{
    future::Future,
    pin::Pin,
    task::{Context, Poll},
};
use crossbeam_utils::CachePadded;

/// Creates a connected worker channel and returns its `Sender` and `Receiver` halves
///
/// The shared state starts with no outstanding work and a sender count of one, which accounts
/// for the `Sender` returned here.
pub fn channel() -> (Sender, Receiver) {
    let shared = Arc::new(State::default());
    (
        Sender(Arc::clone(&shared)),
        Receiver {
            state: shared,
            credits: 0,
        },
    )
}

/// The receiving half of a worker channel
///
/// The worker owns this handle and uses it to wait until senders have submitted work.
pub struct Receiver {
    /// State shared with every `Sender` of this channel
    state: Arc<State>,
    /// Jobs handed out by `poll_acquire` that have not yet been passed to `finish`
    credits: usize,
}

impl Receiver {
    /// Acquires work to be processed for the Receiver
    ///
    /// Resolves to `Some(credits)` once at least one job is available, where `credits` is the
    /// total number of jobs acquired but not yet passed to [`Receiver::finish`].
    ///
    /// `None` is returned when there are no more active Senders and all submitted work has
    /// been acquired.
    #[inline]
    pub async fn acquire(&mut self) -> Option<usize> {
        Acquire(self).await
    }

    /// Polls work to be processed for the receiver
    ///
    /// Returns `Poll::Ready(Some(credits))` while there are unfinished jobs, where `credits` is
    /// the total number of jobs acquired but not yet passed to [`Receiver::finish`].
    ///
    /// `Poll::Ready(None)` is returned when there are no more active Senders and all submitted
    /// work has been acquired. Otherwise the waker from `cx` is registered and `Poll::Pending`
    /// is returned; senders wake it on every `submit` and when they are dropped.
    #[inline]
    pub fn poll_acquire(&mut self, cx: &mut Context) -> Poll<Option<usize>> {
        let state = &*self.state;

        macro_rules! acquire {
            () => {{
                // take the credits that we've been given by the senders
                self.credits += state.remaining.swap(0, Ordering::Acquire);

                // if we have any credits then return
                if self.credits > 0 {
                    return Poll::Ready(Some(self.credits));
                }
            }};
        }

        // first try to acquire credits
        acquire!();

        // if we didn't get any credits then register the waker
        state.receiver.register(cx.waker());

        // make one last effort to acquire credits in case a sender submitted some while we were
        // registering the waker
        acquire!();

        // If every Sender has been dropped then no more work can arrive
        if state.senders.load(Ordering::Acquire) == 0 {
            // A sender may have submitted work *after* the swap above and then been dropped
            // before the load we just performed. The Acquire load of zero synchronizes with the
            // Release decrement done by every dropped sender, and each sender's `submit` calls
            // happen before its own drop, so one more swap is guaranteed to observe all
            // submitted work. Returning `None` without it would silently lose those credits.
            acquire!();

            return Poll::Ready(None);
        }

        Poll::Pending
    }

    /// Marks `count` jobs as finished
    ///
    /// `count` must not exceed the credits currently held; debug builds check this with a
    /// `debug_assert!`.
    #[inline]
    pub fn finish(&mut self, count: usize) {
        debug_assert!(self.credits >= count);
        // decrement the number of credits we have
        self.credits -= count;
    }
}

/// A handle to submit work to be done to a worker receiver
///
/// Multiple Sender handles can be created with `.clone()`.
#[derive(Clone)]
pub struct Sender(Arc<State>);

impl Sender {
    /// Submits `count` jobs to be executed by the worker receiver
    #[inline]
    pub fn submit(&self, count: usize) {
        let state = &*self.0;

        // increment the work counter
        state.remaining.fetch_add(count, Ordering::Release);

        // wake up the receiver if possible
        state.receiver.wake();
    }
}

impl Drop for Sender {
    /// Unregisters this handle as an active sender and notifies the receiver
    ///
    /// The receiver is woken so it can re-check the sender count. Once that count reaches zero
    /// it reports that the channel has closed.
    #[inline]
    fn drop(&mut self) {
        let shared = &self.0;
        shared.senders.fetch_sub(1, Ordering::Release);
        shared.receiver.wake();
    }
}

/// State shared between the senders and the receiver of a worker channel
///
/// The two counters are wrapped in `CachePadded`, presumably so that senders bumping
/// `remaining` and senders updating `senders` don't contend on the same cache line.
struct State {
    /// Jobs submitted by senders that the receiver has not yet swapped out
    remaining: CachePadded<AtomicUsize>,
    /// Waker for the receiver task; woken on every submission and on every sender drop
    receiver: AtomicWaker,
    /// Active sender count; starts at one and is decremented in `Sender`'s `Drop`
    senders: CachePadded<AtomicUsize>,
}

impl Default for State {
    fn default() -> Self {
        Self {
            remaining: Default::default(),
            receiver: Default::default(),
            senders: AtomicUsize::new(1).into(),
        }
    }
}

/// Future returned by [`Receiver::acquire`]; each poll delegates to [`Receiver::poll_acquire`]
struct Acquire<'a>(&'a mut Receiver);

impl Future for Acquire<'_> {
    type Output = Option<usize>;

    #[inline]
    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
        // `Acquire` only holds a `&mut Receiver`, which is `Unpin`, so the pin can be
        // unwrapped safely with `get_mut`.
        let Self(receiver) = self.get_mut();
        receiver.poll_acquire(cx)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    // NOTE(review): presumably a shim that resolves to the real `loom` crate under `cfg(loom)`
    // and to std-backed equivalents otherwise — confirm against `crate::testing`.
    use crate::testing::loom;

    /// Drives one sender thread and one receiver thread over a fresh channel inside
    /// `loom::model`.
    ///
    /// The sender calls `submit(send_batch_size)` `iterations` times. The receiver loops on
    /// `acquire()` until it yields `None`, finishing acquired credits in chunks of at most
    /// `recv_batch_size`. It then asserts that exactly `iterations * send_batch_size` jobs were
    /// finished.
    ///
    /// NOTE(review): only a single `Sender` is exercised; cloned senders are not covered here.
    fn loom_scenario(iterations: usize, send_batch_size: usize, recv_batch_size: usize) {
        // a zero `recv_batch_size` would make the inner finish loop below never terminate;
        // a zero `send_batch_size` would submit no work at all
        assert_ne!(send_batch_size, 0);
        assert_ne!(recv_batch_size, 0);

        loom::model(move || {
            let (send, mut recv) = channel();

            // `send` is moved into this thread and dropped when it finishes, which is what lets
            // the receiver's `acquire` eventually return `None`
            let sender = loom::thread::spawn(move || {
                for _ in 0..iterations {
                    send.submit(send_batch_size);
                    loom::hint::spin_loop();
                }
            });

            let receiver = loom::thread::spawn(move || {
                loom::future::block_on(async move {
                    let mut total = 0;
                    while let Some(mut count) = recv.acquire().await {
                        // `poll_acquire` only returns `Some` when it holds a non-zero credit count
                        assert_ne!(count, 0);

                        // finish the acquired credits in chunks of at most `recv_batch_size`
                        while count > 0 {
                            let to_finish = count.min(recv_batch_size);
                            recv.finish(to_finish);
                            total += to_finish;
                            count -= to_finish;
                        }
                    }

                    // every submitted job must have been delivered before the channel closed
                    assert_eq!(total, iterations * send_batch_size);
                })
            });

            // loom tests will still run after returning so we don't need to join
            if cfg!(not(loom)) {
                sender.join().unwrap();
                receiver.join().unwrap();
            }
        });
    }

    /// Async loom tests seem to spin forever if the number of iterations is higher than 1.
    /// Ideally, this value would be a bit bigger to test more permutations of orderings.
    const ITERATIONS: usize = if cfg!(loom) { 1 } else { 100 };
    /// Jobs submitted per `submit` call; kept small under loom to bound the state space
    const SEND_BATCH_SIZE: usize = if cfg!(loom) { 2 } else { 8 };
    /// Maximum jobs finished per `finish` call; kept small under loom to bound the state space
    const RECV_BATCH_SIZE: usize = if cfg!(loom) { 2 } else { 8 };

    /// The sender submits nothing; the receiver must still observe closure and see zero jobs
    #[test]
    fn loom_no_items() {
        loom_scenario(0, 1, 1);
    }

    /// Single-job submissions finished one at a time
    #[test]
    fn loom_single_item() {
        loom_scenario(ITERATIONS, 1, 1);
    }

    /// Batched submissions finished one at a time
    #[test]
    fn loom_send_batch() {
        loom_scenario(ITERATIONS, SEND_BATCH_SIZE, 1);
    }

    /// Single-job submissions finished in batches
    #[test]
    fn loom_recv_batch() {
        loom_scenario(ITERATIONS, 1, RECV_BATCH_SIZE);
    }

    /// Batched submissions finished in batches
    #[test]
    fn loom_both_batch() {
        loom_scenario(ITERATIONS, SEND_BATCH_SIZE, RECV_BATCH_SIZE);
    }
}