1use std::sync::atomic::AtomicUsize;
2use std::sync::Arc;
3use std::task::Poll;
4
5use futures::task::AtomicWaker;
6use futures::Stream;
7use sharded_queue::ShardedQueue;
8
9pub struct SharedQueueThreaded<T> {
11 queue: ShardedQueue<T>,
12 task_queue: AtomicUsize,
13 waker: AtomicWaker,
14}
15
16impl<T> SharedQueueThreaded<T> {
17 pub fn new(
21 max_concurrent_thread_count: usize,
22 ) -> std::io::Result<Arc<Self>> {
23 let waker = AtomicWaker::new();
24 Ok(Arc::new(Self {
25 queue: ShardedQueue::new(max_concurrent_thread_count),
26 task_queue: AtomicUsize::new(0),
27 waker,
28 }))
29 }
30}
31
32pub trait SharedQueueChannels<T> {
33 fn unbounded(&self) -> (Sender<T>, Receiver<T>);
34
35 fn sender(&self) -> Sender<T>;
36}
37
38impl<T> SharedQueueChannels<T> for Arc<SharedQueueThreaded<T>> {
39 fn unbounded(&self) -> (Sender<T>, Receiver<T>) {
40 let tx = self.sender();
41
42 let rx = Receiver {
43 queue: Arc::clone(self),
44 };
45
46 (tx, rx)
47 }
48
49 fn sender(&self) -> Sender<T> {
50 Sender {
51 queue: Arc::clone(self),
52 }
53 }
54}
55
56pub struct Sender<T> {
57 queue: Arc<SharedQueueThreaded<T>>,
58}
59
60impl<T> Sender<T> {
61 pub fn send(&self, item: T) {
63 self.queue
64 .task_queue
65 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
66 self.queue.queue.push_back(item);
67 self.queue.waker.wake();
68 }
69}
70
71#[derive(Clone)]
72pub struct Receiver<T> {
73 queue: Arc<SharedQueueThreaded<T>>,
74}
75
76impl<T> Stream for Receiver<T> {
77 type Item = T;
78
79 fn poll_next(
80 self: std::pin::Pin<&mut Self>,
81 cx: &mut std::task::Context<'_>,
82 ) -> std::task::Poll<Option<Self::Item>> {
83 self.queue.waker.register(cx.waker());
84
85 let old = self
86 .queue
87 .task_queue
88 .load(std::sync::atomic::Ordering::Relaxed);
89
90 if old > 0 {
91 self.queue
92 .task_queue
93 .fetch_sub(1, std::sync::atomic::Ordering::Relaxed);
94 let item = self.queue.queue.pop_front_or_spin_wait_item();
95 Poll::Ready(Some(item))
96 } else {
97 Poll::Pending
98 }
99 }
100}
101
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use futures::StreamExt;

    use super::{SharedQueueChannels, SharedQueueThreaded};

    /// Two sent items are both received, in either order. Afterwards the
    /// stream stays pending until the 10 ms timeout elapses.
    #[monoio::test_all(timer_enabled = true)]
    async fn ensure_send_receive() {
        let shared = SharedQueueThreaded::<u8>::new(2).unwrap();
        let (sender, mut receiver) = shared.unbounded();

        for value in [1, 2] {
            sender.send(value);
        }

        let first = receiver.next().await.unwrap();
        let second = receiver.next().await.unwrap();
        let third =
            monoio::time::timeout(Duration::from_millis(10), receiver.next())
                .await;

        // The sharded queue does not guarantee ordering, so compare sorted.
        let mut received = vec![first, second];
        received.sort();
        assert_eq!(received, [1, 2]);
        assert!(third.is_err());
    }
}