tiny_actor/channel/
sending.rs

1use crate::*;
2use concurrent_queue::PushError;
3use event_listener::EventListener;
4use futures::{Future, FutureExt};
5use std::{
6    pin::Pin,
7    task::{Context, Poll},
8};
9use tokio::time::Sleep;
10
11impl<M> Channel<M> {
12    pub fn send(&self, msg: M) -> SendFut<'_, M> {
13        SendFut::new(self, msg)
14    }
15
16    pub fn send_now(&self, msg: M) -> Result<(), TrySendError<M>> {
17        Ok(self.push_msg(msg)?)
18    }
19
20    pub fn try_send(&self, msg: M) -> Result<(), TrySendError<M>> {
21        match self.capacity() {
22            Capacity::Bounded(_) => Ok(self.push_msg(msg)?),
23            Capacity::Unbounded(backoff) => match backoff.get_timeout(self.msg_count()) {
24                Some(_) => Err(TrySendError::Full(msg)),
25                None => Ok(self.push_msg(msg)?),
26            },
27        }
28    }
29
30    pub fn send_blocking(&self, mut msg: M) -> Result<(), SendError<M>> {
31        match self.capacity() {
32            Capacity::Bounded(_) => loop {
33                msg = match self.push_msg(msg) {
34                    Ok(()) => {
35                        return Ok(());
36                    }
37                    Err(PushError::Closed(msg)) => {
38                        return Err(SendError(msg));
39                    }
40                    Err(PushError::Full(msg)) => msg,
41                };
42
43                self.get_send_listener().wait();
44            },
45            Capacity::Unbounded(backoff) => {
46                let timeout = backoff.get_timeout(self.msg_count());
47                if let Some(timeout) = timeout {
48                    std::thread::sleep(timeout);
49                }
50                self.push_msg(msg).map_err(|e| match e {
51                    PushError::Full(_) => unreachable!("unbounded"),
52                    PushError::Closed(msg) => SendError(msg),
53                })
54            }
55        }
56    }
57}
58
59/// The send-future, this can be `.await`-ed to send the message.
60#[derive(Debug)]
61pub struct SendFut<'a, M> {
62    channel: &'a Channel<M>,
63    msg: Option<M>,
64    fut: Option<InnerSendFut>,
65}
66
67/// Listener for a bounded channel, sleep for an unbounded channel.
68#[derive(Debug)]
69enum InnerSendFut {
70    Listener(EventListener),
71    Sleep(Pin<Box<Sleep>>),
72}
73
74impl Unpin for InnerSendFut {}
75impl Future for InnerSendFut {
76    type Output = ();
77
78    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
79        match &mut *self {
80            InnerSendFut::Listener(listener) => listener.poll_unpin(cx),
81            InnerSendFut::Sleep(sleep) => sleep.poll_unpin(cx),
82        }
83    }
84}
85
86impl<'a, M> SendFut<'a, M> {
87    pub(crate) fn new(channel: &'a Channel<M>, msg: M) -> Self {
88        match &channel.capacity {
89            Capacity::Bounded(_) => SendFut {
90                channel,
91                msg: Some(msg),
92                fut: None,
93            },
94            Capacity::Unbounded(back_pressure) => SendFut {
95                channel,
96                msg: Some(msg),
97                fut: back_pressure
98                    .get_timeout(channel.msg_count())
99                    .map(|timeout| InnerSendFut::Sleep(Box::pin(tokio::time::sleep(timeout)))),
100            },
101        }
102    }
103
104    fn poll_bounded_send(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), SendError<M>>> {
105        macro_rules! try_send {
106            ($msg:ident) => {
107                match self.channel.try_send($msg) {
108                    Ok(()) => return Poll::Ready(Ok(())),
109                    Err(e) => match e {
110                        TrySendError::Closed(msg) => return Poll::Ready(Err(SendError(msg))),
111                        TrySendError::Full(msg_new) => $msg = msg_new,
112                    },
113                }
114            };
115        }
116
117        let mut msg = self.msg.take().unwrap();
118
119        try_send!(msg);
120
121        loop {
122            // Otherwise, we create the future if it doesn't exist yet.
123            if self.fut.is_none() {
124                self.fut = Some(InnerSendFut::Listener(self.channel.get_send_listener()))
125            }
126
127            try_send!(msg);
128
129            // Poll it once, and return if pending, otherwise we loop again.
130            match self.fut.as_mut().unwrap().poll_unpin(cx) {
131                Poll::Ready(()) => {
132                    try_send!(msg);
133                    self.fut = None
134                }
135                Poll::Pending => {
136                    self.msg = Some(msg);
137                    return Poll::Pending;
138                }
139            }
140        }
141    }
142
143    fn poll_unbounded_send(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), SendError<M>>> {
144        if let Some(fut) = &mut self.fut {
145            match fut.poll_unpin(cx) {
146                Poll::Ready(()) => self.poll_push_unbounded(),
147                Poll::Pending => Poll::Pending,
148            }
149        } else {
150            self.poll_push_unbounded()
151        }
152    }
153
154    fn poll_push_unbounded(&mut self) -> Poll<Result<(), SendError<M>>> {
155        let msg = self.msg.take().unwrap();
156        match self.channel.push_msg(msg) {
157            Ok(()) => Poll::Ready(Ok(())),
158            Err(PushError::Closed(msg)) => Poll::Ready(Err(SendError(msg))),
159            Err(PushError::Full(_msg)) => unreachable!(),
160        }
161    }
162}
163
164impl<'a, M> Unpin for SendFut<'a, M> {}
165
166impl<'a, M> Future for SendFut<'a, M> {
167    type Output = Result<(), SendError<M>>;
168
169    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
170        match self.channel.capacity() {
171            Capacity::Bounded(_) => self.poll_bounded_send(cx),
172            Capacity::Unbounded(_) => self.poll_unbounded_send(cx),
173        }
174    }
175}
176
#[cfg(test)]
mod test {
    use std::{sync::Arc, time::Duration};

    use tokio::time::Instant;

    use crate::*;

    /// Non-blocking sends succeed on channels that have room.
    #[test]
    fn try_send_with_space() {
        let channel = Channel::<()>::new(1, 1, Capacity::Bounded(10));
        channel.try_send(()).unwrap();
        channel.send_now(()).unwrap();
        assert_eq!(channel.msg_count(), 2);

        let channel = Channel::<()>::new(1, 1, Capacity::Unbounded(BackPressure::disabled()));
        channel.try_send(()).unwrap();
        channel.send_now(()).unwrap();
        assert_eq!(channel.msg_count(), 2);
    }

    /// With back-pressure active, `try_send` refuses while `send_now` still pushes.
    #[test]
    fn try_send_unbounded_full() {
        let channel = Channel::<()>::new(
            1,
            1,
            Capacity::Unbounded(BackPressure::linear(0, Duration::from_secs(1))),
        );
        assert_eq!(channel.try_send(()), Err(TrySendError::Full(())));
        assert_eq!(channel.send_now(()), Ok(()));
        assert_eq!(channel.msg_count(), 1);
    }

    /// A full bounded channel rejects both `try_send` and `send_now`.
    #[test]
    fn try_send_bounded_full() {
        let channel = Channel::<()>::new(1, 1, Capacity::Bounded(1));
        channel.try_send(()).unwrap();
        assert_eq!(channel.try_send(()), Err(TrySendError::Full(())));
        assert_eq!(channel.send_now(()), Err(TrySendError::Full(())));
        assert_eq!(channel.msg_count(), 1);
    }

    /// Awaited sends complete on channels that have room.
    #[tokio::test]
    async fn send_with_space() {
        let channel = Channel::<()>::new(1, 1, Capacity::Bounded(10));
        channel.send(()).await.unwrap();
        assert_eq!(channel.msg_count(), 1);

        let channel = Channel::<()>::new(1, 1, Capacity::Unbounded(BackPressure::disabled()));
        channel.send(()).await.unwrap();
        assert_eq!(channel.msg_count(), 1);
    }

    /// Awaited sends on an unbounded channel are delayed by the back-pressure timeout.
    #[tokio::test]
    async fn send_unbounded_full() {
        let channel = Channel::<()>::new(
            1,
            1,
            Capacity::Unbounded(BackPressure::linear(0, Duration::from_millis(1))),
        );
        let time = Instant::now();
        channel.send(()).await.unwrap();
        channel.send(()).await.unwrap();
        channel.send(()).await.unwrap();
        assert!(time.elapsed().as_millis() > 6);
        assert_eq!(channel.msg_count(), 3);
    }

    /// Awaited sends on a full bounded channel wait until the receiver makes room.
    #[tokio::test]
    async fn send_bounded_full() {
        let channel = Arc::new(Channel::<()>::new(1, 1, Capacity::Bounded(1)));
        let channel_clone = channel.clone();

        let sender = tokio::task::spawn(async move {
            let time = Instant::now();
            channel_clone.send(()).await.unwrap();
            channel_clone.send(()).await.unwrap();
            channel_clone.send(()).await.unwrap();
            // The third send can only finish after the second `recv`, and that `recv` is
            // preceded by a 1ms sleep which starts after this task took its timestamp.
            assert!(time.elapsed().as_millis() >= 1);
        });

        channel.recv(&mut false, &mut None).await.unwrap();
        tokio::time::sleep(Duration::from_millis(1)).await;
        channel.recv(&mut false, &mut None).await.unwrap();
        tokio::time::sleep(Duration::from_millis(1)).await;
        channel.recv(&mut false, &mut None).await.unwrap();

        // Join the sender so a panic (failed send or failed assertion) inside the spawned
        // task fails this test instead of being discarded with the dropped handle.
        sender.await.unwrap();
    }
}