s2n_quic_core/sync/
worker.rs

1// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use crate::sync::primitive::{Arc, AtomicUsize, AtomicWaker, Ordering};
5use core::{
6    future::Future,
7    pin::Pin,
8    task::{Context, Poll},
9};
10use crossbeam_utils::CachePadded;
11
12/// Creates a worker channel with a Sender and Receiver
13pub fn channel() -> (Sender, Receiver) {
14    let state = Arc::new(State::default());
15    let sender = Sender(state.clone());
16    let receiver = Receiver { state, credits: 0 };
17    (sender, receiver)
18}
19
20/// A handle to the receiver side of the worker channel
21///
22/// This handle is used by the worker to wake up when there is work to do.
23pub struct Receiver {
24    state: Arc<State>,
25    credits: usize,
26}
27
28impl Receiver {
29    /// Acquires work to be processed for the Receiver
30    ///
31    /// `None` is returned when there are no more active Senders.
32    #[inline]
33    pub async fn acquire(&mut self) -> Option<usize> {
34        Acquire(self).await
35    }
36
37    /// Polls work to be processed for the receiver
38    ///
39    /// `None` is returned when there are no more active Senders.
40    #[inline]
41    pub fn poll_acquire(&mut self, cx: &mut Context) -> Poll<Option<usize>> {
42        let state = &*self.state;
43
44        macro_rules! acquire {
45            () => {{
46                // take the credits that we've been given by the senders
47                self.credits += state.remaining.swap(0, Ordering::Acquire);
48
49                // if we have any credits then return
50                if self.credits > 0 {
51                    return Poll::Ready(Some(self.credits));
52                }
53            }};
54        }
55
56        // first try to acquire credits
57        acquire!();
58
59        // if we didn't get any credits then register the waker
60        state.receiver.register(cx.waker());
61
62        // make one last effort to acquire credits in case a sender submitted some while we were
63        // registering the waker
64        acquire!();
65
66        // If we're the only ones with a handle to the state then we're done
67        if state.senders.load(Ordering::Acquire) == 0 {
68            return Poll::Ready(None);
69        }
70
71        Poll::Pending
72    }
73
74    /// Marks `count` jobs as finished
75    #[inline]
76    pub fn finish(&mut self, count: usize) {
77        debug_assert!(self.credits >= count);
78        // decrement the number of credits we have
79        self.credits -= count;
80    }
81}
82
83/// A handle to submit work to be done to a worker receiver
84///
85/// Multiple Sender handles can be created with `.clone()`.
86#[derive(Clone)]
87pub struct Sender(Arc<State>);
88
89impl Sender {
90    /// Submits `count` jobs to be executed by the worker receiver
91    #[inline]
92    pub fn submit(&self, count: usize) {
93        let state = &*self.0;
94
95        // increment the work counter
96        state.remaining.fetch_add(count, Ordering::Release);
97
98        // wake up the receiver if possible
99        state.receiver.wake();
100    }
101}
102
103impl Drop for Sender {
104    #[inline]
105    fn drop(&mut self) {
106        let state = &*self.0;
107
108        state.senders.fetch_sub(1, Ordering::Release);
109
110        // wake up the receiver to notify that one of the senders has dropped
111        state.receiver.wake();
112    }
113}
114
115struct State {
116    remaining: CachePadded<AtomicUsize>,
117    receiver: AtomicWaker,
118    senders: CachePadded<AtomicUsize>,
119}
120
121impl Default for State {
122    fn default() -> Self {
123        Self {
124            remaining: Default::default(),
125            receiver: Default::default(),
126            senders: AtomicUsize::new(1).into(),
127        }
128    }
129}
130
131struct Acquire<'a>(&'a mut Receiver);
132
133impl Future for Acquire<'_> {
134    type Output = Option<usize>;
135
136    #[inline]
137    fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
138        self.0.poll_acquire(cx)
139    }
140}
141
#[cfg(test)]
mod tests {
    use super::*;
    use crate::testing::loom;

    /// Runs a single sender and a single receiver:
    /// - the sender submits `iterations` batches of `send_batch_size` credits;
    /// - the receiver finishes at most `recv_batch_size` credits at a time.
    ///
    /// It checks that every submitted credit is finished exactly once.
    fn loom_scenario(iterations: usize, send_batch_size: usize, recv_batch_size: usize) {
        assert_ne!(send_batch_size, 0);
        assert_ne!(recv_batch_size, 0);

        let expected = iterations * send_batch_size;

        loom::model(move || {
            let (tx, mut rx) = channel();

            let producer = loom::thread::spawn(move || {
                (0..iterations).for_each(|_| {
                    tx.submit(send_batch_size);
                    loom::hint::spin_loop();
                });
            });

            let consumer = loom::thread::spawn(move || {
                loom::future::block_on(async move {
                    let mut finished = 0;

                    while let Some(available) = rx.acquire().await {
                        assert_ne!(available, 0);

                        let mut left = available;
                        while left != 0 {
                            let chunk = left.min(recv_batch_size);
                            rx.finish(chunk);
                            finished += chunk;
                            left -= chunk;
                        }
                    }

                    assert_eq!(finished, expected);
                })
            });

            // loom tests will still run after returning so we don't need to join
            if cfg!(not(loom)) {
                producer.join().unwrap();
                consumer.join().unwrap();
            }
        });
    }

    /// Async loom tests seem to spin forever if the number of iterations is higher than 1.
    /// Ideally, this value would be a bit bigger to test more permutations of orderings.
    const ITERATIONS: usize = if cfg!(loom) { 1 } else { 100 };
    const SEND_BATCH_SIZE: usize = if cfg!(loom) { 2 } else { 8 };
    const RECV_BATCH_SIZE: usize = if cfg!(loom) { 2 } else { 8 };

    #[test]
    fn loom_no_items() {
        loom_scenario(0, 1, 1);
    }

    #[test]
    fn loom_single_item() {
        loom_scenario(ITERATIONS, 1, 1);
    }

    #[test]
    fn loom_send_batch() {
        loom_scenario(ITERATIONS, SEND_BATCH_SIZE, 1);
    }

    #[test]
    fn loom_recv_batch() {
        loom_scenario(ITERATIONS, 1, RECV_BATCH_SIZE);
    }

    #[test]
    fn loom_both_batch() {
        loom_scenario(ITERATIONS, SEND_BATCH_SIZE, RECV_BATCH_SIZE);
    }
}