witty_actors/
channel_with_priority.rs

1// Copyright (C) 2023 Quickwit, Inc.
2//
3// Quickwit is offered under the AGPL v3.0 and as commercial software.
4// For commercial licensing, contact us at hello@quickwit.io.
5//
6// AGPL:
7// This program is free software: you can redistribute it and/or modify
8// it under the terms of the GNU Affero General Public License as
9// published by the Free Software Foundation, either version 3 of the
10// License, or (at your option) any later version.
11//
12// This program is distributed in the hope that it will be useful,
13// but WITHOUT ANY WARRANTY; without even the implied warranty of
14// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15// GNU Affero General Public License for more details.
16//
17// You should have received a copy of the GNU Affero General Public License
18// along with this program. If not, see <http://www.gnu.org/licenses/>.
19
20use std::sync::atomic::{AtomicBool, Ordering};
21use std::sync::Mutex;
22
23use flume::TryRecvError;
24use thiserror::Error;
25
/// A mutex-protected optional slot paired with an atomic flag.
///
/// `has_val` mirrors whether `opt` holds a value, so emptiness can be checked
/// without taking the lock.
struct LockedOption<T> {
    opt: Mutex<Option<T>>,
    has_val: AtomicBool,
}

// Written by hand instead of `#[derive(Default)]`: the derive would add a
// spurious `T: Default` bound even though an empty slot needs no `T` at all.
impl<T> Default for LockedOption<T> {
    fn default() -> Self {
        LockedOption {
            opt: Mutex::new(None),
            has_val: AtomicBool::new(false),
        }
    }
}
31
32impl<T> LockedOption<T> {
33    pub fn none() -> Self {
34        LockedOption {
35            opt: Mutex::new(None),
36            has_val: AtomicBool::new(false),
37        }
38    }
39
40    pub fn is_some(&self) -> bool {
41        self.has_val.load(Ordering::Acquire)
42    }
43
44    pub fn is_none(&self) -> bool {
45        !self.is_some()
46    }
47
48    pub fn take(&self) -> Option<T> {
49        if !self.has_val.load(Ordering::Acquire) {
50            return None;
51        }
52        let mut lock = self.opt.lock().unwrap();
53        let val_opt = lock.take();
54        self.has_val.store(false, Ordering::Release);
55        val_opt
56    }
57
58    pub fn place(&self, val: T) {
59        let mut lock = self.opt.lock().unwrap();
60        self.has_val.store(true, Ordering::Release);
61        *lock = Some(val);
62    }
63}
64
65#[derive(Debug, Error)]
66pub enum SendError {
67    #[error("The channel is closed.")]
68    Disconnected,
69    #[error("The channel is full.")]
70    Full,
71}
72
73#[derive(Debug, Error)]
74pub enum TrySendError<M> {
75    #[error("The channel is closed.")]
76    Disconnected,
77    #[error("The channel is full.")]
78    Full(M),
79}
80
81impl<M> From<flume::TrySendError<M>> for TrySendError<M> {
82    fn from(err: flume::TrySendError<M>) -> Self {
83        match err {
84            flume::TrySendError::Full(msg) => TrySendError::Full(msg),
85            flume::TrySendError::Disconnected(_) => TrySendError::Disconnected,
86        }
87    }
88}
89
90#[derive(Clone, Copy, Debug, Error, Eq, PartialEq)]
91pub enum RecvError {
92    #[error("No message are currently available.")]
93    NoMessageAvailable,
94    #[error("All sender were dropped and no pending messages are in the channel.")]
95    Disconnected,
96}
97
98impl From<flume::RecvTimeoutError> for RecvError {
99    fn from(flume_err: flume::RecvTimeoutError) -> Self {
100        match flume_err {
101            flume::RecvTimeoutError::Timeout => Self::NoMessageAvailable,
102            flume::RecvTimeoutError::Disconnected => Self::Disconnected,
103        }
104    }
105}
106
107impl<T> From<flume::SendError<T>> for SendError {
108    fn from(_send_error: flume::SendError<T>) -> Self {
109        SendError::Disconnected
110    }
111}
112
113impl<T> From<flume::TrySendError<T>> for SendError {
114    fn from(try_send_error: flume::TrySendError<T>) -> Self {
115        match try_send_error {
116            flume::TrySendError::Full(_) => SendError::Full,
117            flume::TrySendError::Disconnected(_) => SendError::Disconnected,
118        }
119    }
120}
121
/// Capacity of the low priority queue of a channel built by [`channel`].
///
/// The high priority queue is always unbounded.
#[derive(Debug, Clone, Copy)]
pub enum QueueCapacity {
    /// At most this many low priority messages can be queued at once.
    Bounded(usize),
    /// No limit on the number of queued low priority messages.
    Unbounded,
}
127
128/// Creates a channel with the ability to send high priority messages.
129///
130/// A high priority message is guaranteed to be consumed before any
131/// low priority message sent after it.
132pub fn channel<T>(queue_capacity: QueueCapacity) -> (Sender<T>, Receiver<T>) {
133    let (high_priority_tx, high_priority_rx) = flume::unbounded();
134    let (low_priority_tx, low_priority_rx) = match queue_capacity {
135        QueueCapacity::Bounded(cap) => flume::bounded(cap),
136        QueueCapacity::Unbounded => flume::unbounded(),
137    };
138    let receiver = Receiver {
139        low_priority_rx,
140        high_priority_rx,
141        _high_priority_tx: high_priority_tx.clone(),
142        pending_low_priority_message: LockedOption::none(),
143        _clone_is_forbidden: CloneIsForbidden,
144    };
145    let sender = Sender {
146        low_priority_tx,
147        high_priority_tx,
148    };
149    (sender, receiver)
150}
151
/// Sending half of a channel created by [`channel`].
///
/// Holds one handle on each underlying queue. Low priority messages go
/// through `low_priority_tx` and high priority ones through `high_priority_tx`.
pub struct Sender<T> {
    low_priority_tx: flume::Sender<T>,
    high_priority_tx: flume::Sender<T>,
}
156
157impl<T> Sender<T> {
158    pub fn is_disconnected(&self) -> bool {
159        self.low_priority_tx.is_disconnected()
160    }
161
162    pub fn try_send_low_priority(&self, msg: T) -> Result<(), TrySendError<T>> {
163        self.low_priority_tx.try_send(msg)?;
164        Ok(())
165    }
166
167    pub async fn send_low_priority(&self, msg: T) -> Result<(), SendError> {
168        self.low_priority_tx.send_async(msg).await?;
169        Ok(())
170    }
171
172    pub fn send_high_priority(&self, msg: T) -> Result<(), SendError> {
173        self.high_priority_tx.send(msg)?;
174        Ok(())
175    }
176}
177
// Message to future generations: this marker field prevents you from naively
// making `Receiver` cloneable. `#[derive(Clone)]` on `Receiver` fails because
// this type does not implement `Clone`.
// Cloning would be wrong because `Receiver`'s drop implementation drains the
// elements in the channel, so dropping one clone would discard messages still
// meant for the others.
struct CloneIsForbidden;
182
/// Receiving half of a channel created by [`channel`].
///
/// Deliberately not cloneable: its `Drop` implementation drains both queues.
pub struct Receiver<T> {
    low_priority_rx: flume::Receiver<T>,
    high_priority_rx: flume::Receiver<T>,
    // High priority sender owned by the receiver itself, so the high priority
    // queue never reports a disconnection, even once the `Sender` is dropped.
    _high_priority_tx: flume::Sender<T>,
    // Low priority message already pulled from `low_priority_rx` but set aside
    // because a high priority message was found right after it.
    pending_low_priority_message: LockedOption<T>,
    _clone_is_forbidden: CloneIsForbidden,
}
190
191impl<T> Drop for Receiver<T> {
192    fn drop(&mut self) {
193        // Flume strangely (tokio::mpsc does not behave like this for instance)
194        // does not drop the message in the channel when all receiver are dropped.
195        //
196        // They are only dropped when both the receivers AND the sender are dropped.
197        // We fix this behavior by drainng the channel upon drop.
198        self.high_priority_rx.drain();
199        self.low_priority_rx.drain();
200    }
201}
202
203impl<T> Receiver<T> {
204    pub fn is_empty(&self) -> bool {
205        self.low_priority_rx.is_empty()
206            && self.pending_low_priority_message.is_none()
207            && self.high_priority_rx.is_empty()
208    }
209
210    pub fn try_recv_high_priority_message(&self) -> Result<T, RecvError> {
211        match self.high_priority_rx.try_recv() {
212            Ok(msg) => Ok(msg),
213            Err(TryRecvError::Disconnected) => {
214                unreachable!(
215                    "This can never happen, as the high priority Sender is owned by the Receiver."
216                );
217            }
218            Err(TryRecvError::Empty) => {
219                if self.low_priority_rx.is_disconnected() {
220                    // We check that no new high priority message were sent
221                    // in between.
222                    if let Ok(msg) = self.high_priority_rx.try_recv() {
223                        Ok(msg)
224                    } else {
225                        Err(RecvError::Disconnected)
226                    }
227                } else {
228                    Err(RecvError::NoMessageAvailable)
229                }
230            }
231        }
232    }
233
234    pub fn try_recv(&self) -> Result<T, RecvError> {
235        if let Ok(msg) = self.high_priority_rx.try_recv() {
236            return Ok(msg);
237        }
238        if let Some(pending_msg) = self.pending_low_priority_message.take() {
239            return Ok(pending_msg);
240        }
241        match self.low_priority_rx.try_recv() {
242            Ok(low_msg) => {
243                if let Ok(high_msg) = self.high_priority_rx.try_recv() {
244                    self.pending_low_priority_message.place(low_msg);
245                    Ok(high_msg)
246                } else {
247                    Ok(low_msg)
248                }
249            }
250            Err(TryRecvError::Disconnected) => {
251                if let Ok(high_msg) = self.high_priority_rx.try_recv() {
252                    Ok(high_msg)
253                } else {
254                    Err(RecvError::Disconnected)
255                }
256            }
257            Err(TryRecvError::Empty) => Err(RecvError::NoMessageAvailable),
258        }
259    }
260
261    pub async fn recv_high_priority(&self) -> T {
262        self.high_priority_rx
263            .recv_async()
264            .await
265            .expect("The Receiver owns the high priority Sender to avoid any disconnection.")
266    }
267
268    pub async fn recv(&self) -> Result<T, RecvError> {
269        if let Ok(msg) = self.try_recv_high_priority_message() {
270            return Ok(msg);
271        }
272        if let Some(pending_msg) = self.pending_low_priority_message.take() {
273            return Ok(pending_msg);
274        }
275        tokio::select! {
276            // We don't really care about fairness here.
277            // We will double check if there is a command or not anyway.
278            biased;
279            high_priority_msg_res = self.high_priority_rx.recv_async() => {
280                match high_priority_msg_res {
281                    Ok(high_priority_msg) => {
282                        Ok(high_priority_msg)
283                    },
284                    Err(_) => {
285                        unreachable!("The Receiver owns the high priority Sender to avoid any disconnection.")
286                    },
287                }
288            }
289            low_priority_msg_res = self.low_priority_rx.recv_async() => {
290                match low_priority_msg_res {
291                    Ok(low_priority_msg) => {
292                        if let Ok(high_priority_msg) = self.try_recv_high_priority_message() {
293                            self.pending_low_priority_message.place(low_priority_msg);
294                            Ok(high_priority_msg)
295                        } else {
296                            Ok(low_priority_msg)
297                        }
298                    },
299                    Err(flume::RecvError::Disconnected) => {
300                        if let Ok(high_priority_msg) = self.try_recv_high_priority_message() {
301                            Ok(high_priority_msg)
302                        } else {
303                            Err(RecvError::Disconnected)
304                        }
305                    }
306                }
307           }
308        }
309    }
310
311    /// Drain all of the pending low priority messages and return them.
312    pub fn drain_low_priority(&self) -> Vec<T> {
313        let mut messages = Vec::new();
314        while let Ok(msg) = self.low_priority_rx.try_recv() {
315            messages.push(msg);
316        }
317        messages
318    }
319}
320
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::time::Duration;

    use super::*;

    // Dropping the receiver must release the messages still queued in both
    // queues. `Arc::strong_count` tells whether the channel still holds them.
    #[tokio::test]
    async fn test_channel_with_priority_drop_receiver_drop_messages() {
        let arc_high = Arc::new(());
        let arc_low = Arc::new(());
        let (tx, rx) = super::channel(QueueCapacity::Bounded(2));
        tx.send_high_priority(arc_high.clone()).unwrap();
        tx.send_low_priority(arc_low.clone()).await.unwrap();
        assert_eq!(Arc::strong_count(&arc_high), 2);
        assert_eq!(Arc::strong_count(&arc_low), 2);
        drop(rx);
        // `tx` is still alive here, so the messages were released by the
        // receiver's drop, not by the whole channel going away.
        assert_eq!(Arc::strong_count(&arc_high), 1);
        assert_eq!(Arc::strong_count(&arc_low), 1);
    }

    // A freshly created slot is empty.
    #[test]
    fn test_locked_option_new_empty() {
        let locked_option: LockedOption<usize> = LockedOption::none();
        assert_eq!(locked_option.take(), None);
    }

    // A placed value can be taken back.
    #[test]
    fn test_locked_option_place() {
        let locked_option = LockedOption::none();
        locked_option.place(1);
        assert_eq!(locked_option.take(), Some(1));
    }

    // Placing twice overwrites: only the last value survives.
    #[test]
    fn test_locked_option_place_twice_keep_last() {
        let locked_option = LockedOption::none();
        locked_option.place(1);
        locked_option.place(2);
        assert_eq!(locked_option.take(), Some(2));
    }

    // Taking empties the slot.
    #[test]
    fn test_locked_option_place_take_twice() {
        let locked_option = LockedOption::none();
        locked_option.place(1);
        assert_eq!(locked_option.take(), Some(1));
        assert_eq!(locked_option.take(), None);
    }

    // A high priority message sent after a low priority one is received first.
    // Once both are consumed, `recv` keeps waiting (the sender is still alive).
    #[tokio::test]
    async fn test_recv_priority() -> anyhow::Result<()> {
        let (sender, receiver) = super::channel::<usize>(QueueCapacity::Unbounded);
        sender.send_low_priority(1).await?;
        sender.send_high_priority(2)?;
        assert_eq!(receiver.recv().await, Ok(2));
        assert_eq!(receiver.recv().await, Ok(1));
        assert!(
            tokio::time::timeout(Duration::from_millis(50), receiver.recv())
                .await
                .is_err()
        );
        Ok(())
    }

    // `try_recv` returns the queued low priority message, then reports emptiness.
    #[tokio::test]
    async fn test_try_recv() -> anyhow::Result<()> {
        let (sender, receiver) = super::channel::<usize>(QueueCapacity::Unbounded);
        sender.send_low_priority(1).await?;
        assert_eq!(receiver.try_recv(), Ok(1));
        assert_eq!(receiver.try_recv(), Err(RecvError::NoMessageAvailable));
        Ok(())
    }

    // The high-priority-only receive ignores queued low priority messages.
    #[tokio::test]
    async fn test_try_recv_high_priority() -> anyhow::Result<()> {
        let (sender, receiver) = super::channel::<usize>(QueueCapacity::Unbounded);
        sender.send_low_priority(1).await?;
        assert_eq!(
            receiver.try_recv_high_priority_message(),
            Err(RecvError::NoMessageAvailable)
        );
        Ok(())
    }

    // Dropping the sender does not disconnect the high priority queue (the
    // receiver owns a high priority sender), so `recv_high_priority` keeps
    // waiting instead of failing.
    #[tokio::test]
    async fn test_recv_high_priority_ignore_disconnection() -> anyhow::Result<()> {
        let (sender, receiver) = super::channel::<usize>(QueueCapacity::Unbounded);
        std::mem::drop(sender);
        assert!(
            tokio::time::timeout(Duration::from_millis(100), receiver.recv_high_priority())
                .await
                .is_err()
        );
        Ok(())
    }

    // With the sender dropped and nothing queued, `recv` reports disconnection.
    #[tokio::test]
    async fn test_recv_disconnect() -> anyhow::Result<()> {
        let (sender, receiver) = super::channel::<usize>(QueueCapacity::Unbounded);
        std::mem::drop(sender);
        assert_eq!(receiver.recv().await, Err(RecvError::Disconnected));
        Ok(())
    }

    // With a live sender and nothing queued, `try_recv` reports no message.
    #[tokio::test]
    async fn test_recv_timeout_simple() -> anyhow::Result<()> {
        let (_sender, receiver) = super::channel::<usize>(QueueCapacity::Unbounded);
        assert!(matches!(
            receiver.try_recv(),
            Err(RecvError::NoMessageAvailable)
        ));
        Ok(())
    }

    // A spawned task sends a high then a low priority message after a short
    // delay, then finishes, dropping the sender.
    // NOTE(review): the `try_recv` assertions assume the spawned task has run
    // to completion by then, which appears to rely on the default
    // single-threaded `#[tokio::test]` runtime; confirm.
    #[tokio::test]
    async fn test_try_recv_priority_corner_case() -> anyhow::Result<()> {
        let (sender, receiver) = super::channel::<usize>(QueueCapacity::Unbounded);
        tokio::task::spawn(async move {
            tokio::time::sleep(Duration::from_millis(10)).await;
            sender.send_high_priority(1)?;
            sender.send_low_priority(2).await?;
            Result::<(), SendError>::Ok(())
        });
        assert_eq!(receiver.recv().await, Ok(1));
        assert_eq!(receiver.try_recv(), Ok(2));
        assert!(matches!(receiver.try_recv(), Err(RecvError::Disconnected)));
        Ok(())
    }

    // `try_recv` serves the high priority message before the older low
    // priority one.
    #[tokio::test]
    async fn test_try_recv_high_low() {
        let (tx, rx) = super::channel::<usize>(QueueCapacity::Unbounded);
        tx.send_low_priority(1).await.unwrap();
        tx.send_high_priority(2).unwrap();
        assert_eq!(rx.try_recv(), Ok(2));
        assert_eq!(rx.try_recv(), Ok(1));
        assert_eq!(rx.try_recv(), Err(RecvError::NoMessageAvailable));
    }

    // The high-priority-only receive drains only the high priority queue. The
    // low priority message stays available to `try_recv`.
    #[tokio::test]
    async fn test_try_recv_high() {
        let (tx, rx) = super::channel::<usize>(QueueCapacity::Unbounded);
        tx.send_low_priority(1).await.unwrap();
        tx.send_high_priority(2).unwrap();
        assert_eq!(rx.try_recv_high_priority_message(), Ok(2));
        assert_eq!(
            rx.try_recv_high_priority_message(),
            Err(RecvError::NoMessageAvailable)
        );
        assert_eq!(rx.try_recv(), Ok(1));
        assert_eq!(rx.try_recv(), Err(RecvError::NoMessageAvailable));
    }
}