1use crate::*;
2use concurrent_queue::PushError;
3use event_listener::EventListener;
4use futures::{Future, FutureExt};
5use std::{
6 pin::Pin,
7 task::{Context, Poll},
8};
9use tokio::time::Sleep;
10
11impl<M> Channel<M> {
12 pub fn send(&self, msg: M) -> SendFut<'_, M> {
13 SendFut::new(self, msg)
14 }
15
16 pub fn send_now(&self, msg: M) -> Result<(), TrySendError<M>> {
17 Ok(self.push_msg(msg)?)
18 }
19
20 pub fn try_send(&self, msg: M) -> Result<(), TrySendError<M>> {
21 match self.capacity() {
22 Capacity::Bounded(_) => Ok(self.push_msg(msg)?),
23 Capacity::Unbounded(backoff) => match backoff.get_timeout(self.msg_count()) {
24 Some(_) => Err(TrySendError::Full(msg)),
25 None => Ok(self.push_msg(msg)?),
26 },
27 }
28 }
29
30 pub fn send_blocking(&self, mut msg: M) -> Result<(), SendError<M>> {
31 match self.capacity() {
32 Capacity::Bounded(_) => loop {
33 msg = match self.push_msg(msg) {
34 Ok(()) => {
35 return Ok(());
36 }
37 Err(PushError::Closed(msg)) => {
38 return Err(SendError(msg));
39 }
40 Err(PushError::Full(msg)) => msg,
41 };
42
43 self.get_send_listener().wait();
44 },
45 Capacity::Unbounded(backoff) => {
46 let timeout = backoff.get_timeout(self.msg_count());
47 if let Some(timeout) = timeout {
48 std::thread::sleep(timeout);
49 }
50 self.push_msg(msg).map_err(|e| match e {
51 PushError::Full(_) => unreachable!("unbounded"),
52 PushError::Closed(msg) => SendError(msg),
53 })
54 }
55 }
56 }
57}
58
59#[derive(Debug)]
61pub struct SendFut<'a, M> {
62 channel: &'a Channel<M>,
63 msg: Option<M>,
64 fut: Option<InnerSendFut>,
65}
66
/// The inner future a [`SendFut`] waits on before (re)trying its push.
#[derive(Debug)]
enum InnerSendFut {
    /// Bounded channels: a listener from `Channel::get_send_listener`
    /// (presumably notified when space frees up; confirm in `Channel`).
    Listener(EventListener),
    /// Unbounded channels: the back-pressure delay computed in `SendFut::new`.
    Sleep(Pin<Box<Sleep>>),
}
73
74impl Unpin for InnerSendFut {}
75impl Future for InnerSendFut {
76 type Output = ();
77
78 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
79 match &mut *self {
80 InnerSendFut::Listener(listener) => listener.poll_unpin(cx),
81 InnerSendFut::Sleep(sleep) => sleep.poll_unpin(cx),
82 }
83 }
84}
85
86impl<'a, M> SendFut<'a, M> {
87 pub(crate) fn new(channel: &'a Channel<M>, msg: M) -> Self {
88 match &channel.capacity {
89 Capacity::Bounded(_) => SendFut {
90 channel,
91 msg: Some(msg),
92 fut: None,
93 },
94 Capacity::Unbounded(back_pressure) => SendFut {
95 channel,
96 msg: Some(msg),
97 fut: back_pressure
98 .get_timeout(channel.msg_count())
99 .map(|timeout| InnerSendFut::Sleep(Box::pin(tokio::time::sleep(timeout)))),
100 },
101 }
102 }
103
104 fn poll_bounded_send(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), SendError<M>>> {
105 macro_rules! try_send {
106 ($msg:ident) => {
107 match self.channel.try_send($msg) {
108 Ok(()) => return Poll::Ready(Ok(())),
109 Err(e) => match e {
110 TrySendError::Closed(msg) => return Poll::Ready(Err(SendError(msg))),
111 TrySendError::Full(msg_new) => $msg = msg_new,
112 },
113 }
114 };
115 }
116
117 let mut msg = self.msg.take().unwrap();
118
119 try_send!(msg);
120
121 loop {
122 if self.fut.is_none() {
124 self.fut = Some(InnerSendFut::Listener(self.channel.get_send_listener()))
125 }
126
127 try_send!(msg);
128
129 match self.fut.as_mut().unwrap().poll_unpin(cx) {
131 Poll::Ready(()) => {
132 try_send!(msg);
133 self.fut = None
134 }
135 Poll::Pending => {
136 self.msg = Some(msg);
137 return Poll::Pending;
138 }
139 }
140 }
141 }
142
143 fn poll_unbounded_send(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), SendError<M>>> {
144 if let Some(fut) = &mut self.fut {
145 match fut.poll_unpin(cx) {
146 Poll::Ready(()) => self.poll_push_unbounded(),
147 Poll::Pending => Poll::Pending,
148 }
149 } else {
150 self.poll_push_unbounded()
151 }
152 }
153
154 fn poll_push_unbounded(&mut self) -> Poll<Result<(), SendError<M>>> {
155 let msg = self.msg.take().unwrap();
156 match self.channel.push_msg(msg) {
157 Ok(()) => Poll::Ready(Ok(())),
158 Err(PushError::Closed(msg)) => Poll::Ready(Err(SendError(msg))),
159 Err(PushError::Full(_msg)) => unreachable!(),
160 }
161 }
162}
163
164impl<'a, M> Unpin for SendFut<'a, M> {}
165
166impl<'a, M> Future for SendFut<'a, M> {
167 type Output = Result<(), SendError<M>>;
168
169 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
170 match self.channel.capacity() {
171 Capacity::Bounded(_) => self.poll_bounded_send(cx),
172 Capacity::Unbounded(_) => self.poll_unbounded_send(cx),
173 }
174 }
175}
176
#[cfg(test)]
mod test {
    use std::{sync::Arc, time::Duration};

    use tokio::time::Instant;

    use crate::*;

    #[test]
    fn try_send_with_space() {
        let channel = Channel::<()>::new(1, 1, Capacity::Bounded(10));
        channel.try_send(()).unwrap();
        channel.send_now(()).unwrap();
        assert_eq!(channel.msg_count(), 2);

        let channel = Channel::<()>::new(1, 1, Capacity::Unbounded(BackPressure::disabled()));
        channel.try_send(()).unwrap();
        channel.send_now(()).unwrap();
        assert_eq!(channel.msg_count(), 2);
    }

    #[test]
    fn try_send_unbounded_full() {
        let channel = Channel::<()>::new(
            1,
            1,
            Capacity::Unbounded(BackPressure::linear(0, Duration::from_secs(1))),
        );
        assert_eq!(channel.try_send(()), Err(TrySendError::Full(())));
        assert_eq!(channel.send_now(()), Ok(()));
        assert_eq!(channel.msg_count(), 1);
    }

    #[test]
    fn try_send_bounded_full() {
        let channel = Channel::<()>::new(1, 1, Capacity::Bounded(1));
        channel.try_send(()).unwrap();
        assert_eq!(channel.try_send(()), Err(TrySendError::Full(())));
        assert_eq!(channel.send_now(()), Err(TrySendError::Full(())));
        assert_eq!(channel.msg_count(), 1);
    }

    #[tokio::test]
    async fn send_with_space() {
        let channel = Channel::<()>::new(1, 1, Capacity::Bounded(10));
        channel.send(()).await.unwrap();
        assert_eq!(channel.msg_count(), 1);

        let channel = Channel::<()>::new(1, 1, Capacity::Unbounded(BackPressure::disabled()));
        channel.send(()).await.unwrap();
        assert_eq!(channel.msg_count(), 1);
    }

    #[tokio::test]
    async fn send_unbounded_full() {
        let channel = Channel::<()>::new(
            1,
            1,
            Capacity::Unbounded(BackPressure::linear(0, Duration::from_millis(1))),
        );
        let time = Instant::now();
        channel.send(()).await.unwrap();
        channel.send(()).await.unwrap();
        channel.send(()).await.unwrap();
        assert!(time.elapsed().as_millis() > 6);
        assert_eq!(channel.msg_count(), 3);
    }

    #[tokio::test]
    async fn send_bounded_full() {
        let channel = Arc::new(Channel::<()>::new(1, 1, Capacity::Bounded(1)));
        let channel_clone = channel.clone();

        // Keep the handle and await it at the end. A panic inside a detached
        // task would otherwise be swallowed and never fail this test.
        let sender = tokio::task::spawn(async move {
            let time = Instant::now();
            channel_clone.send(()).await.unwrap();
            channel_clone.send(()).await.unwrap();
            channel_clone.send(()).await.unwrap();
            // With capacity 1, the third send needs the receiver to take the
            // second message. The receiver only does that after its first 1 ms
            // sleep, which starts after this task has started.
            assert!(time.elapsed() >= Duration::from_millis(1));
        });

        channel.recv(&mut false, &mut None).await.unwrap();
        tokio::time::sleep(Duration::from_millis(1)).await;
        channel.recv(&mut false, &mut None).await.unwrap();
        tokio::time::sleep(Duration::from_millis(1)).await;
        channel.recv(&mut false, &mut None).await.unwrap();

        sender.await.unwrap();
    }
}