task_exec_queue/
local_spawner.rs

1use std::fmt::Debug;
2use std::hash::Hash;
3use std::pin::Pin;
4use std::rc::Rc;
5use std::task::{Context, Poll};
6
7use crate::local::LocalPendingOnce;
8use futures::channel::oneshot;
9use futures::task::AtomicWaker;
10use futures::{Future, Sink, SinkExt};
11use futures_lite::FutureExt;
12
13use crate::LocalTaskType;
14
15use super::{assert_future, Error, ErrorType, LocalTaskExecQueue};
16
17pub struct LocalGroupSpawner<'a, Item, Tx, G> {
18    inner: LocalSpawner<'a, Item, Tx, G, ()>,
19    name: Option<G>,
20}
21
22impl<Item, Tx, G> Unpin for LocalGroupSpawner<'_, Item, Tx, G> {}
23
24impl<'a, Item, Tx, G> LocalGroupSpawner<'a, Item, Tx, G>
25where
26    Tx: Clone + Unpin + Sink<((), LocalTaskType)> + 'static,
27    G: Hash + Eq + Clone + Debug + 'static,
28{
29    #[inline]
30    pub(crate) fn new(inner: LocalSpawner<'a, Item, Tx, G, ()>, name: G) -> Self {
31        Self {
32            inner,
33            name: Some(name),
34        }
35    }
36
37    #[inline]
38    pub fn quickly(mut self) -> Self {
39        self.inner.quickly = true;
40        self
41    }
42
43    #[inline]
44    pub async fn result(mut self) -> Result<Item::Output, Error<Item>>
45    where
46        Item: Future + 'static,
47        Item::Output: 'static,
48    {
49        if self.inner.sink.is_closed() {
50            return Err(Error::SendError(ErrorType::Closed(self.inner.item.take())));
51        }
52
53        if !self.inner.quickly && self.inner.sink.is_full() {
54            let w = Rc::new(AtomicWaker::new());
55            self.inner.sink.waiting_wakers.push(w.clone());
56            LocalPendingOnce::new(w).await;
57        }
58
59        let task = match self.inner.item.take() {
60            Some(task) => task,
61            None => {
62                log::error!("polled Feed after completion, task is None!");
63                return Err(Error::SendError(ErrorType::Closed(None)));
64            }
65        };
66
67        let name = match self.name.take() {
68            Some(name) => name,
69            None => {
70                log::error!("polled Feed after completion, name is None!");
71                return Err(Error::SendError(ErrorType::Closed(None)));
72            }
73        };
74
75        let (res_tx, res_rx) = oneshot::channel();
76        let waiting_count = self.inner.sink.waiting_count.clone();
77        let waiting_wakers = self.inner.sink.waiting_wakers.clone();
78        let task = async move {
79            waiting_count.dec();
80            if let Some(w) = waiting_wakers.pop() {
81                w.wake();
82            }
83            let output = task.await;
84            if let Err(_e) = res_tx.send(output) {
85                log::warn!("send result failed");
86            }
87        };
88        self.inner.sink.waiting_count.inc();
89
90        if let Err(_e) = self
91            .inner
92            .sink
93            .group_send(name, Box::new(Box::pin(task)))
94            .await
95        {
96            self.inner.sink.waiting_count.dec();
97            Err(Error::SendError(ErrorType::Closed(None)))
98        } else {
99            res_rx.await.map_err(|_| {
100                self.inner.sink.waiting_count.dec();
101                Error::RecvResultError
102            })
103        }
104    }
105}
106
107impl<Item, Tx, G> Future for LocalGroupSpawner<'_, Item, Tx, G>
108where
109    Item: Future + 'static,
110    Item::Output: 'static,
111    Tx: Clone + Unpin + Sink<((), LocalTaskType)> + 'static,
112    G: Hash + Eq + Clone + Debug + 'static,
113{
114    type Output = Result<(), Error<Item>>;
115
116    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
117        let this = self.get_mut();
118        if this.inner.sink.is_closed() && !this.inner.is_pending {
119            return Poll::Ready(Err(Error::SendError(ErrorType::Closed(
120                this.inner.item.take(),
121            ))));
122        }
123
124        if !this.inner.quickly && this.inner.sink.is_full() {
125            let w = Rc::new(AtomicWaker::new());
126            w.register(cx.waker());
127            this.inner.sink.waiting_wakers.push(w);
128            this.inner.is_pending = true;
129            return Poll::Pending;
130        }
131
132        let task = match this.inner.item.take() {
133            Some(task) => task,
134            None => {
135                log::error!("polled Feed after completion, task is None!");
136                return Poll::Ready(Ok(()));
137            }
138        };
139
140        let name = match this.name.take() {
141            Some(name) => name,
142            None => {
143                log::error!("polled Feed after completion, name is None!");
144                return Poll::Ready(Ok(()));
145            }
146        };
147
148        if this.inner.sink.is_closed() {
149            return Poll::Ready(Err(Error::SendError(ErrorType::Closed(Some(task)))));
150        }
151        let waiting_count = this.inner.sink.waiting_count.clone();
152        let waiting_wakers = this.inner.sink.waiting_wakers.clone();
153        let task = async move {
154            waiting_count.dec();
155            if let Some(w) = waiting_wakers.pop() {
156                w.wake();
157            }
158            let _ = task.await;
159        };
160        this.inner.sink.waiting_count.inc();
161
162        let mut group_send = this
163            .inner
164            .sink
165            .group_send(name, Box::new(Box::pin(task)))
166            .boxed_local();
167
168        if (futures::ready!(group_send.poll(cx))).is_err() {
169            this.inner.sink.waiting_count.dec();
170            Poll::Ready(Err(Error::SendError(ErrorType::Closed(None))))
171        } else {
172            Poll::Ready(Ok(()))
173        }
174    }
175}
176
177pub struct TryLocalGroupSpawner<'a, Item, Tx, G> {
178    inner: LocalGroupSpawner<'a, Item, Tx, G>,
179}
180
181impl<Item, Tx, G> Unpin for TryLocalGroupSpawner<'_, Item, Tx, G> {}
182
183impl<'a, Item, Tx, G> TryLocalGroupSpawner<'a, Item, Tx, G>
184where
185    Tx: Clone + Unpin + Sink<((), LocalTaskType)> + 'static,
186    G: Hash + Eq + Clone + Debug + 'static,
187{
188    #[inline]
189    pub(crate) fn new(inner: LocalSpawner<'a, Item, Tx, G, ()>, name: G) -> Self {
190        Self {
191            inner: LocalGroupSpawner {
192                inner,
193                name: Some(name),
194            },
195        }
196    }
197
198    #[inline]
199    pub async fn result(mut self) -> Result<Item::Output, Error<Item>>
200    where
201        Item: Future + 'static,
202        Item::Output: 'static,
203    {
204        if self.inner.inner.sink.is_full() {
205            return Err(Error::TrySendError(ErrorType::Full(
206                self.inner.inner.item.take(),
207            )));
208        }
209        self.inner.result().await
210    }
211}
212
213impl<Item, Tx, G> Future for TryLocalGroupSpawner<'_, Item, Tx, G>
214where
215    Item: Future + 'static,
216    Item::Output: 'static,
217    Tx: Clone + Unpin + Sink<((), LocalTaskType)> + 'static,
218    G: Hash + Eq + Clone + Debug + 'static,
219{
220    type Output = Result<(), Error<Item>>;
221
222    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
223        let this = self.get_mut();
224
225        if this.inner.inner.sink.is_full() {
226            return Poll::Ready(Err(Error::TrySendError(ErrorType::Full(
227                this.inner.inner.item.take(),
228            ))));
229        }
230
231        this.inner.poll(cx)
232    }
233}
234
235pub struct LocalSpawner<'a, Item, Tx, G, D> {
236    sink: &'a LocalTaskExecQueue<Tx, G, D>,
237    item: Option<Item>,
238    d: Option<D>,
239    quickly: bool,
240    is_pending: bool,
241}
242
243impl<'a, Item, Tx, G, D> Unpin for LocalSpawner<'a, Item, Tx, G, D> {}
244
245impl<'a, Item, Tx, G> LocalSpawner<'a, Item, Tx, G, ()>
246where
247    Tx: Clone + Unpin + Sink<((), LocalTaskType)> + 'static,
248    G: Hash + Eq + Clone + Debug + 'static,
249{
250    #[inline]
251    pub fn group(self, name: G) -> LocalGroupSpawner<'a, Item, Tx, G>
252    where
253        Item: Future + 'static,
254        Item::Output: 'static,
255    {
256        let fut = LocalGroupSpawner::new(self, name);
257        assert_future::<Result<(), _>, _>(fut)
258    }
259}
260
261impl<'a, Item, Tx, G, D> LocalSpawner<'a, Item, Tx, G, D>
262where
263    Tx: Clone + Unpin + Sink<(D, LocalTaskType)> + 'static,
264    G: Hash + Eq + Clone + Debug + 'static,
265{
266    #[inline]
267    pub(crate) fn new(sink: &'a LocalTaskExecQueue<Tx, G, D>, item: Item, d: D) -> Self {
268        Self {
269            sink,
270            item: Some(item),
271            d: Some(d),
272            quickly: false,
273            is_pending: false,
274        }
275    }
276
277    #[inline]
278    pub fn quickly(mut self) -> Self {
279        self.quickly = true;
280        self
281    }
282
283    #[inline]
284    pub async fn result(mut self) -> Result<Item::Output, Error<Item>>
285    where
286        Item: Future + 'static,
287        Item::Output: 'static,
288    {
289        if self.sink.is_closed() {
290            return Err(Error::SendError(ErrorType::Closed(self.item.take())));
291        }
292
293        if !self.quickly && self.sink.is_full() {
294            let w = Rc::new(AtomicWaker::new());
295            self.sink.waiting_wakers.push(w.clone());
296            LocalPendingOnce::new(w).await;
297        }
298
299        let task = self
300            .item
301            .take()
302            .expect("polled Feed after completion, task is None!");
303        let d = self
304            .d
305            .take()
306            .expect("polled Feed after completion, d is None!");
307
308        let (res_tx, res_rx) = oneshot::channel();
309        let waiting_count = self.sink.waiting_count.clone();
310        let waiting_wakers = self.sink.waiting_wakers.clone();
311        let task = async move {
312            waiting_count.dec();
313            if let Some(w) = waiting_wakers.pop() {
314                w.wake();
315            }
316            let output = task.await;
317            if let Err(_e) = res_tx.send(output) {
318                log::warn!("send result failed");
319            }
320        };
321        self.sink.waiting_count.inc();
322
323        if self
324            .sink
325            .tx
326            .clone()
327            .send((d, Box::new(Box::pin(task))))
328            .await
329            .is_err()
330        {
331            self.sink.waiting_count.dec();
332            return Err(Error::SendError(ErrorType::Closed(None)));
333        }
334        res_rx.await.map_err(|_| {
335            self.sink.waiting_count.dec();
336            Error::RecvResultError
337        })
338    }
339}
340
341impl<Item, Tx, G, D> Future for LocalSpawner<'_, Item, Tx, G, D>
342where
343    Item: Future + 'static,
344    Item::Output: 'static,
345    Tx: Clone + Unpin + Sink<(D, LocalTaskType)> + 'static,
346    G: Hash + Eq + Clone + Debug + 'static,
347{
348    type Output = Result<(), Error<Item>>;
349
350    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
351        let this = self.get_mut();
352
353        if this.sink.is_closed() && !this.is_pending {
354            return Poll::Ready(Err(Error::SendError(ErrorType::Closed(this.item.take()))));
355        }
356
357        if !this.quickly && this.sink.is_full() {
358            let w = Rc::new(AtomicWaker::new());
359            w.register(cx.waker());
360            this.sink.waiting_wakers.push(w);
361            this.is_pending = true;
362            return Poll::Pending;
363        }
364
365        let task = match this.item.take() {
366            Some(task) => task,
367            None => {
368                log::error!("polled Feed after completion, task is None!");
369                return Poll::Ready(Ok(()));
370            }
371        };
372
373        let d = match this.d.take() {
374            Some(d) => d,
375            None => {
376                log::error!("polled Feed after completion, d is None!");
377                return Poll::Ready(Ok(()));
378            }
379        };
380
381        let mut tx = this.sink.tx.clone();
382        let mut sink = Pin::new(&mut tx);
383        //futures::ready!(sink.as_mut().poll_ready(cx))
384        //    .map_err(|_| Error::SendError(ErrorType::Closed(None)))?;
385        let waiting_count = this.sink.waiting_count.clone();
386        let waiting_wakers = this.sink.waiting_wakers.clone();
387        let task = async move {
388            waiting_count.dec();
389            if let Some(w) = waiting_wakers.pop() {
390                w.wake();
391            }
392            let _ = task.await;
393        };
394        this.sink.waiting_count.inc();
395        sink.as_mut()
396            .start_send((d, Box::new(Box::pin(task))))
397            .map_err(|_e| {
398                this.sink.waiting_count.dec();
399                Error::SendError(ErrorType::Closed(None))
400            })?;
401        Poll::Ready(Ok(()))
402    }
403}
404
405pub struct TryLocalSpawner<'a, Item, Tx, G, D> {
406    inner: LocalSpawner<'a, Item, Tx, G, D>,
407}
408
409impl<'a, Item, Tx, G, D> Unpin for TryLocalSpawner<'a, Item, Tx, G, D> {}
410
411impl<'a, Item, Tx, G> TryLocalSpawner<'a, Item, Tx, G, ()>
412where
413    Tx: Clone + Unpin + Sink<((), LocalTaskType)> + 'static,
414    G: Hash + Eq + Clone + Debug + 'static,
415{
416    #[inline]
417    pub fn group(self, name: G) -> TryLocalGroupSpawner<'a, Item, Tx, G>
418    where
419        Item: Future + 'static,
420        Item::Output: 'static,
421    {
422        let fut = TryLocalGroupSpawner::new(self.inner, name);
423        assert_future::<Result<(), _>, _>(fut)
424    }
425}
426
427impl<'a, Item, Tx, G, D> TryLocalSpawner<'a, Item, Tx, G, D>
428where
429    Tx: Clone + Unpin + Sink<(D, LocalTaskType)> + 'static,
430    G: Hash + Eq + Clone + Debug + 'static,
431{
432    #[inline]
433    pub(crate) fn new(sink: &'a LocalTaskExecQueue<Tx, G, D>, item: Item, d: D) -> Self {
434        Self {
435            inner: LocalSpawner {
436                sink,
437                item: Some(item),
438                d: Some(d),
439                quickly: false,
440                is_pending: false,
441            },
442        }
443    }
444
445    #[inline]
446    pub fn quickly(mut self) -> Self {
447        self.inner.quickly = true;
448        self
449    }
450
451    #[inline]
452    pub async fn result(mut self) -> Result<Item::Output, Error<Item>>
453    where
454        Item: Future + 'static,
455        Item::Output: 'static,
456    {
457        if self.inner.sink.is_full() {
458            return Err(Error::TrySendError(ErrorType::Full(self.inner.item.take())));
459        }
460        self.inner.result().await
461    }
462}
463
464impl<Item, Tx, G, D> Future for TryLocalSpawner<'_, Item, Tx, G, D>
465where
466    Item: Future + 'static,
467    Item::Output: 'static,
468    Tx: Clone + Unpin + Sink<(D, LocalTaskType)> + 'static,
469    G: Hash + Eq + Clone + Debug + 'static,
470{
471    type Output = Result<(), Error<Item>>;
472
473    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
474        let this = self.get_mut();
475        if this.inner.sink.is_full() {
476            return Poll::Ready(Err(Error::TrySendError(ErrorType::Full(
477                this.inner.item.take(),
478            ))));
479        }
480        this.inner.poll(cx)
481    }
482}