// sharded_thread/queue.rs

use std::sync::atomic::AtomicUsize;
use std::sync::atomic::Ordering;
use std::sync::Arc;
use std::task::Poll;

use futures::task::AtomicWaker;
use futures::Stream;
use sharded_queue::ShardedQueue;
8
9/// A queue that should be available on each thread.
10pub struct SharedQueueThreaded<T> {
11    queue: ShardedQueue<T>,
12    task_queue: AtomicUsize,
13    waker: AtomicWaker,
14}
15
16impl<T> SharedQueueThreaded<T> {
17    /// Create a new `SharedQueueThreaded` by specifing the number of thread how
18    /// can physically access the queue = the number of CPU core available
19    /// for the application.
20    pub fn new(
21        max_concurrent_thread_count: usize,
22    ) -> std::io::Result<Arc<Self>> {
23        let waker = AtomicWaker::new();
24        Ok(Arc::new(Self {
25            queue: ShardedQueue::new(max_concurrent_thread_count),
26            task_queue: AtomicUsize::new(0),
27            waker,
28        }))
29    }
30}
31
32pub trait SharedQueueChannels<T> {
33    fn unbounded(&self) -> (Sender<T>, Receiver<T>);
34
35    fn sender(&self) -> Sender<T>;
36}
37
38impl<T> SharedQueueChannels<T> for Arc<SharedQueueThreaded<T>> {
39    fn unbounded(&self) -> (Sender<T>, Receiver<T>) {
40        let tx = self.sender();
41
42        let rx = Receiver {
43            queue: Arc::clone(self),
44        };
45
46        (tx, rx)
47    }
48
49    fn sender(&self) -> Sender<T> {
50        Sender {
51            queue: Arc::clone(self),
52        }
53    }
54}
55
56pub struct Sender<T> {
57    queue: Arc<SharedQueueThreaded<T>>,
58}
59
60impl<T> Sender<T> {
61    /// Attempts to send a value to the queue
62    pub fn send(&self, item: T) {
63        self.queue
64            .task_queue
65            .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
66        self.queue.queue.push_back(item);
67        self.queue.waker.wake();
68    }
69}
70
71#[derive(Clone)]
72pub struct Receiver<T> {
73    queue: Arc<SharedQueueThreaded<T>>,
74}
75
76impl<T> Stream for Receiver<T> {
77    type Item = T;
78
79    fn poll_next(
80        self: std::pin::Pin<&mut Self>,
81        cx: &mut std::task::Context<'_>,
82    ) -> std::task::Poll<Option<Self::Item>> {
83        self.queue.waker.register(cx.waker());
84
85        let old = self
86            .queue
87            .task_queue
88            .load(std::sync::atomic::Ordering::Relaxed);
89
90        if old > 0 {
91            self.queue
92                .task_queue
93                .fetch_sub(1, std::sync::atomic::Ordering::Relaxed);
94            let item = self.queue.queue.pop_front_or_spin_wait_item();
95            Poll::Ready(Some(item))
96        } else {
97            Poll::Pending
98        }
99    }
100}
101
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use futures::StreamExt;

    use super::{SharedQueueChannels, SharedQueueThreaded};

    /// Both sent values come back (in either order), and a third `next()`
    /// stays pending until the short timeout fires.
    #[monoio::test_all(timer_enabled = true)]
    async fn ensure_send_receive() {
        let shared = SharedQueueThreaded::<u8>::new(2).unwrap();
        let (sender, mut receiver) = shared.unbounded();

        for value in [1, 2] {
            sender.send(value);
        }

        let mut received = Vec::with_capacity(2);
        for _ in 0..2 {
            received.push(receiver.next().await.unwrap());
        }
        let third =
            monoio::time::timeout(Duration::from_millis(10), receiver.next()).await;

        // The sharded queue does not guarantee FIFO order across shards.
        received.sort_unstable();
        assert_eq!(received, [1, 2]);
        assert!(third.is_err());
    }
}