1use std::sync::atomic::{AtomicBool, Ordering};
21use std::sync::Mutex;
22
23use flume::TryRecvError;
24use thiserror::Error;
25
/// A slot holding an optional `T` that can be filled and emptied through a shared reference.
///
/// `has_val` mirrors whether the mutex currently holds a value. This lets `is_some` / `is_none`
/// and the empty fast path of `take` answer without locking. Both fields are only written while
/// the mutex guard is held.
struct LockedOption<T> {
    opt: Mutex<Option<T>>,
    has_val: AtomicBool,
}

// Implemented by hand rather than derived: `#[derive(Default)]` would add a `T: Default`
// bound that an empty slot does not need.
impl<T> Default for LockedOption<T> {
    fn default() -> Self {
        Self::none()
    }
}

impl<T> LockedOption<T> {
    /// Creates an empty slot.
    pub fn none() -> Self {
        LockedOption {
            opt: Mutex::new(None),
            has_val: AtomicBool::new(false),
        }
    }

    /// Returns `true` if the slot currently holds a value (reads the flag, no locking).
    pub fn is_some(&self) -> bool {
        self.has_val.load(Ordering::Acquire)
    }

    /// Returns `true` if the slot is currently empty (reads the flag, no locking).
    pub fn is_none(&self) -> bool {
        !self.is_some()
    }

    /// Removes and returns the stored value, leaving the slot empty.
    ///
    /// # Panics
    ///
    /// Panics if the inner mutex is poisoned.
    pub fn take(&self) -> Option<T> {
        // Fast path: avoid taking the lock when the flag says there is nothing to take.
        if !self.has_val.load(Ordering::Acquire) {
            return None;
        }
        let mut lock = self.opt.lock().unwrap();
        let val_opt = lock.take();
        // Still under the lock, so a concurrent `place` cannot interleave with this reset.
        self.has_val.store(false, Ordering::Release);
        val_opt
    }

    /// Stores `val`, replacing (and dropping) any value already in the slot.
    ///
    /// # Panics
    ///
    /// Panics if the inner mutex is poisoned.
    pub fn place(&self, val: T) {
        let mut lock = self.opt.lock().unwrap();
        self.has_val.store(true, Ordering::Release);
        *lock = Some(val);
    }
}
64
65#[derive(Debug, Error)]
66pub enum SendError {
67 #[error("The channel is closed.")]
68 Disconnected,
69 #[error("The channel is full.")]
70 Full,
71}
72
73#[derive(Debug, Error)]
74pub enum TrySendError<M> {
75 #[error("The channel is closed.")]
76 Disconnected,
77 #[error("The channel is full.")]
78 Full(M),
79}
80
81impl<M> From<flume::TrySendError<M>> for TrySendError<M> {
82 fn from(err: flume::TrySendError<M>) -> Self {
83 match err {
84 flume::TrySendError::Full(msg) => TrySendError::Full(msg),
85 flume::TrySendError::Disconnected(_) => TrySendError::Disconnected,
86 }
87 }
88}
89
90#[derive(Clone, Copy, Debug, Error, Eq, PartialEq)]
91pub enum RecvError {
92 #[error("No message are currently available.")]
93 NoMessageAvailable,
94 #[error("All sender were dropped and no pending messages are in the channel.")]
95 Disconnected,
96}
97
98impl From<flume::RecvTimeoutError> for RecvError {
99 fn from(flume_err: flume::RecvTimeoutError) -> Self {
100 match flume_err {
101 flume::RecvTimeoutError::Timeout => Self::NoMessageAvailable,
102 flume::RecvTimeoutError::Disconnected => Self::Disconnected,
103 }
104 }
105}
106
107impl<T> From<flume::SendError<T>> for SendError {
108 fn from(_send_error: flume::SendError<T>) -> Self {
109 SendError::Disconnected
110 }
111}
112
113impl<T> From<flume::TrySendError<T>> for SendError {
114 fn from(try_send_error: flume::TrySendError<T>) -> Self {
115 match try_send_error {
116 flume::TrySendError::Full(_) => SendError::Full,
117 flume::TrySendError::Disconnected(_) => SendError::Disconnected,
118 }
119 }
120}
121
/// Capacity of the low priority lane of a channel created with `channel`.
///
/// The high priority lane is always unbounded; this setting only affects low priority
/// messages. Capacities can be compared with `==`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum QueueCapacity {
    /// At most this many low priority messages can be queued (passed to `flume::bounded`).
    Bounded(usize),
    /// Low priority messages are queued without limit (uses `flume::unbounded`).
    Unbounded,
}
127
128pub fn channel<T>(queue_capacity: QueueCapacity) -> (Sender<T>, Receiver<T>) {
133 let (high_priority_tx, high_priority_rx) = flume::unbounded();
134 let (low_priority_tx, low_priority_rx) = match queue_capacity {
135 QueueCapacity::Bounded(cap) => flume::bounded(cap),
136 QueueCapacity::Unbounded => flume::unbounded(),
137 };
138 let receiver = Receiver {
139 low_priority_rx,
140 high_priority_rx,
141 _high_priority_tx: high_priority_tx.clone(),
142 pending_low_priority_message: LockedOption::none(),
143 _clone_is_forbidden: CloneIsForbidden,
144 };
145 let sender = Sender {
146 low_priority_tx,
147 high_priority_tx,
148 };
149 (sender, receiver)
150}
151
152pub struct Sender<T> {
153 low_priority_tx: flume::Sender<T>,
154 high_priority_tx: flume::Sender<T>,
155}
156
157impl<T> Sender<T> {
158 pub fn is_disconnected(&self) -> bool {
159 self.low_priority_tx.is_disconnected()
160 }
161
162 pub fn try_send_low_priority(&self, msg: T) -> Result<(), TrySendError<T>> {
163 self.low_priority_tx.try_send(msg)?;
164 Ok(())
165 }
166
167 pub async fn send_low_priority(&self, msg: T) -> Result<(), SendError> {
168 self.low_priority_tx.send_async(msg).await?;
169 Ok(())
170 }
171
172 pub fn send_high_priority(&self, msg: T) -> Result<(), SendError> {
173 self.high_priority_tx.send(msg)?;
174 Ok(())
175 }
176}
177
/// Marker type that deliberately does not implement `Clone`.
///
/// Embedding it as a field of `Receiver` keeps `Receiver` from deriving `Clone`, so a receiver
/// cannot be duplicated that way.
struct CloneIsForbidden;
182
/// Receiving half of a priority channel created with `channel`.
///
/// The receive methods favor messages from the high priority lane over the low priority one.
pub struct Receiver<T> {
    /// Low priority lane. Its disconnection is what the receive methods treat as "every sender
    /// is gone".
    low_priority_rx: flume::Receiver<T>,
    /// High priority lane.
    high_priority_rx: flume::Receiver<T>,
    // A sender on the high priority lane owned by the receiver itself. It keeps that lane from
    // ever reporting disconnection while the receiver exists; the `unreachable!`/`expect`
    // calls in the receive methods rely on this.
    _high_priority_tx: flume::Sender<T>,
    // A low priority message that was already pulled from its lane but set aside because a
    // high priority message was available at the same time. It is handed out before anything
    // still queued in the low priority lane.
    pending_low_priority_message: LockedOption<T>,
    // Non-`Clone` marker: prevents `Receiver` from deriving `Clone`.
    _clone_is_forbidden: CloneIsForbidden,
}
190
191impl<T> Drop for Receiver<T> {
192 fn drop(&mut self) {
193 self.high_priority_rx.drain();
199 self.low_priority_rx.drain();
200 }
201}
202
203impl<T> Receiver<T> {
204 pub fn is_empty(&self) -> bool {
205 self.low_priority_rx.is_empty()
206 && self.pending_low_priority_message.is_none()
207 && self.high_priority_rx.is_empty()
208 }
209
210 pub fn try_recv_high_priority_message(&self) -> Result<T, RecvError> {
211 match self.high_priority_rx.try_recv() {
212 Ok(msg) => Ok(msg),
213 Err(TryRecvError::Disconnected) => {
214 unreachable!(
215 "This can never happen, as the high priority Sender is owned by the Receiver."
216 );
217 }
218 Err(TryRecvError::Empty) => {
219 if self.low_priority_rx.is_disconnected() {
220 if let Ok(msg) = self.high_priority_rx.try_recv() {
223 Ok(msg)
224 } else {
225 Err(RecvError::Disconnected)
226 }
227 } else {
228 Err(RecvError::NoMessageAvailable)
229 }
230 }
231 }
232 }
233
234 pub fn try_recv(&self) -> Result<T, RecvError> {
235 if let Ok(msg) = self.high_priority_rx.try_recv() {
236 return Ok(msg);
237 }
238 if let Some(pending_msg) = self.pending_low_priority_message.take() {
239 return Ok(pending_msg);
240 }
241 match self.low_priority_rx.try_recv() {
242 Ok(low_msg) => {
243 if let Ok(high_msg) = self.high_priority_rx.try_recv() {
244 self.pending_low_priority_message.place(low_msg);
245 Ok(high_msg)
246 } else {
247 Ok(low_msg)
248 }
249 }
250 Err(TryRecvError::Disconnected) => {
251 if let Ok(high_msg) = self.high_priority_rx.try_recv() {
252 Ok(high_msg)
253 } else {
254 Err(RecvError::Disconnected)
255 }
256 }
257 Err(TryRecvError::Empty) => Err(RecvError::NoMessageAvailable),
258 }
259 }
260
261 pub async fn recv_high_priority(&self) -> T {
262 self.high_priority_rx
263 .recv_async()
264 .await
265 .expect("The Receiver owns the high priority Sender to avoid any disconnection.")
266 }
267
268 pub async fn recv(&self) -> Result<T, RecvError> {
269 if let Ok(msg) = self.try_recv_high_priority_message() {
270 return Ok(msg);
271 }
272 if let Some(pending_msg) = self.pending_low_priority_message.take() {
273 return Ok(pending_msg);
274 }
275 tokio::select! {
276 biased;
279 high_priority_msg_res = self.high_priority_rx.recv_async() => {
280 match high_priority_msg_res {
281 Ok(high_priority_msg) => {
282 Ok(high_priority_msg)
283 },
284 Err(_) => {
285 unreachable!("The Receiver owns the high priority Sender to avoid any disconnection.")
286 },
287 }
288 }
289 low_priority_msg_res = self.low_priority_rx.recv_async() => {
290 match low_priority_msg_res {
291 Ok(low_priority_msg) => {
292 if let Ok(high_priority_msg) = self.try_recv_high_priority_message() {
293 self.pending_low_priority_message.place(low_priority_msg);
294 Ok(high_priority_msg)
295 } else {
296 Ok(low_priority_msg)
297 }
298 },
299 Err(flume::RecvError::Disconnected) => {
300 if let Ok(high_priority_msg) = self.try_recv_high_priority_message() {
301 Ok(high_priority_msg)
302 } else {
303 Err(RecvError::Disconnected)
304 }
305 }
306 }
307 }
308 }
309 }
310
311 pub fn drain_low_priority(&self) -> Vec<T> {
313 let mut messages = Vec::new();
314 while let Ok(msg) = self.low_priority_rx.try_recv() {
315 messages.push(msg);
316 }
317 messages
318 }
319}
320
321#[cfg(test)]
322mod tests {
323 use std::sync::Arc;
324 use std::time::Duration;
325
326 use super::*;
327
328 #[tokio::test]
329 async fn test_channel_with_priority_drop_receiver_drop_messages() {
330 let arc_high = Arc::new(());
331 let arc_low = Arc::new(());
332 let (tx, rx) = super::channel(QueueCapacity::Bounded(2));
333 tx.send_high_priority(arc_high.clone()).unwrap();
334 tx.send_low_priority(arc_low.clone()).await.unwrap();
335 assert_eq!(Arc::strong_count(&arc_high), 2);
336 assert_eq!(Arc::strong_count(&arc_low), 2);
337 drop(rx);
338 assert_eq!(Arc::strong_count(&arc_high), 1);
339 assert_eq!(Arc::strong_count(&arc_low), 1);
340 }
341
342 #[test]
343 fn test_locked_option_new_empty() {
344 let locked_option: LockedOption<usize> = LockedOption::none();
345 assert_eq!(locked_option.take(), None);
346 }
347
348 #[test]
349 fn test_locked_option_place() {
350 let locked_option = LockedOption::none();
351 locked_option.place(1);
352 assert_eq!(locked_option.take(), Some(1));
353 }
354
355 #[test]
356 fn test_locked_option_place_twice_keep_last() {
357 let locked_option = LockedOption::none();
358 locked_option.place(1);
359 locked_option.place(2);
360 assert_eq!(locked_option.take(), Some(2));
361 }
362
363 #[test]
364 fn test_locked_option_place_take_twice() {
365 let locked_option = LockedOption::none();
366 locked_option.place(1);
367 assert_eq!(locked_option.take(), Some(1));
368 assert_eq!(locked_option.take(), None);
369 }
370
371 #[tokio::test]
372 async fn test_recv_priority() -> anyhow::Result<()> {
373 let (sender, receiver) = super::channel::<usize>(QueueCapacity::Unbounded);
374 sender.send_low_priority(1).await?;
375 sender.send_high_priority(2)?;
376 assert_eq!(receiver.recv().await, Ok(2));
377 assert_eq!(receiver.recv().await, Ok(1));
378 assert!(
379 tokio::time::timeout(Duration::from_millis(50), receiver.recv())
380 .await
381 .is_err()
382 );
383 Ok(())
384 }
385
386 #[tokio::test]
387 async fn test_try_recv() -> anyhow::Result<()> {
388 let (sender, receiver) = super::channel::<usize>(QueueCapacity::Unbounded);
389 sender.send_low_priority(1).await?;
390 assert_eq!(receiver.try_recv(), Ok(1));
391 assert_eq!(receiver.try_recv(), Err(RecvError::NoMessageAvailable));
392 Ok(())
393 }
394
395 #[tokio::test]
396 async fn test_try_recv_high_priority() -> anyhow::Result<()> {
397 let (sender, receiver) = super::channel::<usize>(QueueCapacity::Unbounded);
398 sender.send_low_priority(1).await?;
399 assert_eq!(
400 receiver.try_recv_high_priority_message(),
401 Err(RecvError::NoMessageAvailable)
402 );
403 Ok(())
404 }
405
406 #[tokio::test]
407 async fn test_recv_high_priority_ignore_disconnection() -> anyhow::Result<()> {
408 let (sender, receiver) = super::channel::<usize>(QueueCapacity::Unbounded);
409 std::mem::drop(sender);
410 assert!(
411 tokio::time::timeout(Duration::from_millis(100), receiver.recv_high_priority())
412 .await
413 .is_err()
414 );
415 Ok(())
416 }
417
418 #[tokio::test]
419 async fn test_recv_disconnect() -> anyhow::Result<()> {
420 let (sender, receiver) = super::channel::<usize>(QueueCapacity::Unbounded);
421 std::mem::drop(sender);
422 assert_eq!(receiver.recv().await, Err(RecvError::Disconnected));
423 Ok(())
424 }
425
426 #[tokio::test]
427 async fn test_recv_timeout_simple() -> anyhow::Result<()> {
428 let (_sender, receiver) = super::channel::<usize>(QueueCapacity::Unbounded);
429 assert!(matches!(
430 receiver.try_recv(),
431 Err(RecvError::NoMessageAvailable)
432 ));
433 Ok(())
434 }
435
436 #[tokio::test]
437 async fn test_try_recv_priority_corner_case() -> anyhow::Result<()> {
438 let (sender, receiver) = super::channel::<usize>(QueueCapacity::Unbounded);
439 tokio::task::spawn(async move {
440 tokio::time::sleep(Duration::from_millis(10)).await;
441 sender.send_high_priority(1)?;
442 sender.send_low_priority(2).await?;
443 Result::<(), SendError>::Ok(())
444 });
445 assert_eq!(receiver.recv().await, Ok(1));
446 assert_eq!(receiver.try_recv(), Ok(2));
447 assert!(matches!(receiver.try_recv(), Err(RecvError::Disconnected)));
448 Ok(())
449 }
450
451 #[tokio::test]
452 async fn test_try_recv_high_low() {
453 let (tx, rx) = super::channel::<usize>(QueueCapacity::Unbounded);
454 tx.send_low_priority(1).await.unwrap();
455 tx.send_high_priority(2).unwrap();
456 assert_eq!(rx.try_recv(), Ok(2));
457 assert_eq!(rx.try_recv(), Ok(1));
458 assert_eq!(rx.try_recv(), Err(RecvError::NoMessageAvailable));
459 }
460
461 #[tokio::test]
462 async fn test_try_recv_high() {
463 let (tx, rx) = super::channel::<usize>(QueueCapacity::Unbounded);
464 tx.send_low_priority(1).await.unwrap();
465 tx.send_high_priority(2).unwrap();
466 assert_eq!(rx.try_recv_high_priority_message(), Ok(2));
467 assert_eq!(
468 rx.try_recv_high_priority_message(),
469 Err(RecvError::NoMessageAvailable)
470 );
471 assert_eq!(rx.try_recv(), Ok(1));
472 assert_eq!(rx.try_recv(), Err(RecvError::NoMessageAvailable));
473 }
474}