1use crate::{Error, Result};
2use std::sync::Arc;
3
4pub struct EventPermit(
6 #[allow(dead_code)] Option<tokio::sync::OwnedSemaphorePermit>,
7);
8
9impl std::fmt::Debug for EventPermit {
10 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
11 f.debug_struct("EventPermit").finish()
12 }
13}
14
15pub struct EventSend<E: From<Error>> {
18 limit: Arc<tokio::sync::Semaphore>,
19 send: tokio::sync::mpsc::UnboundedSender<(E, EventPermit)>,
20}
21
22impl<E: From<Error>> Clone for EventSend<E> {
23 fn clone(&self) -> Self {
24 Self {
25 limit: self.limit.clone(),
26 send: self.send.clone(),
27 }
28 }
29}
30
31impl<E: From<Error>> EventSend<E> {
32 pub fn new(limit: u32) -> (Self, EventRecv<E>) {
34 let limit = Arc::new(tokio::sync::Semaphore::new(limit as usize));
35 let (send, recv) = tokio::sync::mpsc::unbounded_channel();
36 (EventSend { limit, send }, EventRecv(recv))
37 }
38
39 pub fn try_permit(&self) -> Option<EventPermit> {
41 match self.limit.clone().try_acquire_owned() {
42 Ok(p) => Some(EventPermit(Some(p))),
43 _ => None,
44 }
45 }
46
47 pub async fn send(&self, evt: E) -> Result<()> {
49 let permit = self
50 .limit
51 .clone()
52 .acquire_owned()
53 .await
54 .map_err(|_| Error::id("Closed"))?;
55 self.send
56 .send((evt, EventPermit(Some(permit))))
57 .map_err(|_| Error::id("Closed"))
58 }
59
60 pub fn send_permit(&self, evt: E, permit: EventPermit) -> Result<()> {
62 self.send
63 .send((evt, permit))
64 .map_err(|_| Error::id("Closed"))
65 }
66
67 pub fn send_err(&self, err: impl Into<Error>) {
69 let _ = self.send.send((err.into().into(), EventPermit(None)));
70 }
71}
72
73pub struct EventRecv<E: From<Error>>(
76 tokio::sync::mpsc::UnboundedReceiver<(E, EventPermit)>,
77);
78
79impl<E: From<Error>> std::fmt::Debug for EventRecv<E> {
80 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
81 f.debug_struct("EventRecv").finish()
82 }
83}
84
85impl<E: From<Error>> EventRecv<E> {
86 pub async fn recv(&mut self) -> Option<E> {
88 self.0.recv().await.map(|r| r.0)
89 }
90}
91
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// A second `send` must wait while the only permit is held by a queued event.
    #[tokio::test(flavor = "multi_thread")]
    async fn event_limit() {
        let (s, _r) = <EventSend<Error>>::new(1);

        s.send(Error::id("yo").into()).await.unwrap();

        assert!(tokio::time::timeout(
            Duration::from_millis(10),
            s.send(Error::id("yo").into()),
        )
        .await
        .is_err());
    }

    /// Receiving an event releases its permit back to the senders.
    #[tokio::test(flavor = "multi_thread")]
    async fn recv_releases_capacity() {
        let (s, mut r) = <EventSend<Error>>::new(1);

        s.send(Error::id("yo").into()).await.unwrap();
        assert!(s.try_permit().is_none());

        assert!(r.recv().await.is_some());
        assert!(s.try_permit().is_some());
    }

    /// `send_permit` delivers an event using capacity reserved in advance.
    #[tokio::test(flavor = "multi_thread")]
    async fn send_with_reserved_permit() {
        let (s, mut r) = <EventSend<Error>>::new(1);

        let permit = s
            .try_permit()
            .expect("a limit of 1 leaves one permit free");
        s.send_permit(Error::id("yo").into(), permit).unwrap();
        assert!(s.try_permit().is_none());

        assert!(r.recv().await.is_some());
    }

    /// Error events bypass the limit and do not consume capacity.
    #[tokio::test(flavor = "multi_thread")]
    async fn send_err_bypasses_limit() {
        let (s, mut r) = <EventSend<Error>>::new(1);

        s.send_err(Error::id("boom"));
        assert!(s.try_permit().is_some());

        assert!(r.recv().await.is_some());
    }

    /// Sending after the receiver is gone reports an error instead of hanging.
    #[tokio::test(flavor = "multi_thread")]
    async fn send_after_recv_dropped_fails() {
        let (s, r) = <EventSend<Error>>::new(1);
        drop(r);

        assert!(s.send(Error::id("yo").into()).await.is_err());
    }
}