// task_exec_queue/spawner.rs

1use std::fmt::Debug;
2use std::hash::Hash;
3use std::pin::Pin;
4use std::sync::Arc;
5use std::task::{Context, Poll};
6
7use crate::exec::PendingOnce;
8use futures::channel::oneshot;
9use futures::task::AtomicWaker;
10use futures::{Future, Sink, SinkExt};
11use futures_lite::FutureExt;
12
13use crate::TaskType;
14
15use super::{assert_future, Error, ErrorType, TaskExecQueue};
16
17pub struct GroupSpawner<'a, Item, Tx, G> {
18    inner: Spawner<'a, Item, Tx, G, ()>,
19    name: Option<G>,
20}
21
22impl<Item, Tx, G> Unpin for GroupSpawner<'_, Item, Tx, G> {}
23
24impl<'a, Item, Tx, G> GroupSpawner<'a, Item, Tx, G>
25where
26    Tx: Clone + Unpin + Sink<((), TaskType)> + Send + Sync + 'static,
27    G: Hash + Eq + Clone + Debug + Send + Sync + 'static,
28{
29    #[inline]
30    pub(crate) fn new(inner: Spawner<'a, Item, Tx, G, ()>, name: G) -> Self {
31        Self {
32            inner,
33            name: Some(name),
34        }
35    }
36
37    #[inline]
38    pub fn quickly(mut self) -> Self {
39        self.inner.quickly = true;
40        self
41    }
42
43    #[inline]
44    pub async fn result(mut self) -> Result<Item::Output, Error<Item>>
45    where
46        Item: Future + Send + 'static,
47        Item::Output: Send + 'static,
48    {
49        if self.inner.sink.is_closed() {
50            return Err(Error::SendError(ErrorType::Closed(self.inner.item.take())));
51        }
52
53        if !self.inner.quickly && self.inner.sink.is_full() {
54            let w = Arc::new(AtomicWaker::new());
55            self.inner.sink.waiting_wakers.push(w.clone());
56            PendingOnce::new(w).await;
57        }
58
59        let task = match self.inner.item.take() {
60            Some(task) => task,
61            None => {
62                log::error!("polled Feed after completion, task is None!");
63                return Err(Error::SendError(ErrorType::Closed(None)));
64            }
65        };
66
67        let name = match self.name.take() {
68            Some(name) => name,
69            None => {
70                log::error!("polled Feed after completion, name is None!");
71                return Err(Error::SendError(ErrorType::Closed(None)));
72            }
73        };
74
75        let (res_tx, res_rx) = oneshot::channel();
76        let waiting_count = self.inner.sink.waiting_count.clone();
77        let waiting_wakers = self.inner.sink.waiting_wakers.clone();
78        let task = async move {
79            waiting_count.dec();
80            if let Some(w) = waiting_wakers.pop() {
81                w.wake();
82            }
83            let output = task.await;
84            if let Err(_e) = res_tx.send(output) {
85                log::warn!("send result failed");
86            }
87        };
88        self.inner.sink.waiting_count.inc();
89
90        if let Err(_e) = self
91            .inner
92            .sink
93            .group_send(name, Box::new(Box::pin(task)))
94            .await
95        {
96            self.inner.sink.waiting_count.dec();
97            Err(Error::SendError(ErrorType::Closed(None)))
98        } else {
99            res_rx.await.map_err(|_| {
100                self.inner.sink.waiting_count.dec();
101                Error::RecvResultError
102            })
103        }
104    }
105}
106
107impl<Item, Tx, G> Future for GroupSpawner<'_, Item, Tx, G>
108where
109    Item: Future + Send + 'static,
110    Item::Output: Send + 'static,
111    Tx: Clone + Unpin + Sink<((), TaskType)> + Send + Sync + 'static,
112    G: Hash + Eq + Clone + Debug + Send + Sync + 'static,
113{
114    type Output = Result<(), Error<Item>>;
115
116    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
117        let this = self.get_mut();
118        if this.inner.sink.is_closed() && !this.inner.is_pending {
119            return Poll::Ready(Err(Error::SendError(ErrorType::Closed(
120                this.inner.item.take(),
121            ))));
122        }
123
124        if !this.inner.quickly && this.inner.sink.is_full() {
125            let w = Arc::new(AtomicWaker::new());
126            w.register(cx.waker());
127            this.inner.sink.waiting_wakers.push(w);
128            this.inner.is_pending = true;
129            return Poll::Pending;
130        }
131
132        let task = match this.inner.item.take() {
133            Some(task) => task,
134            None => {
135                log::error!("polled Feed after completion, task is None!");
136                return Poll::Ready(Ok(()));
137            }
138        };
139
140        let name = match this.name.take() {
141            Some(name) => name,
142            None => {
143                log::error!("polled Feed after completion, name is None!");
144                return Poll::Ready(Ok(()));
145            }
146        };
147
148        let waiting_count = this.inner.sink.waiting_count.clone();
149        let waiting_wakers = this.inner.sink.waiting_wakers.clone();
150        let task = async move {
151            waiting_count.dec();
152            if let Some(w) = waiting_wakers.pop() {
153                w.wake();
154            }
155            let _ = task.await;
156        };
157        this.inner.sink.waiting_count.inc();
158
159        let mut group_send = this
160            .inner
161            .sink
162            .group_send(name, Box::new(Box::pin(task)))
163            .boxed();
164
165        if (futures::ready!(group_send.poll(cx))).is_err() {
166            this.inner.sink.waiting_count.dec();
167            Poll::Ready(Err(Error::SendError(ErrorType::Closed(None))))
168        } else {
169            Poll::Ready(Ok(()))
170        }
171    }
172}
173
174pub struct TryGroupSpawner<'a, Item, Tx, G> {
175    inner: GroupSpawner<'a, Item, Tx, G>,
176}
177
178impl<Item, Tx, G> Unpin for TryGroupSpawner<'_, Item, Tx, G> {}
179
180impl<'a, Item, Tx, G> TryGroupSpawner<'a, Item, Tx, G>
181where
182    Tx: Clone + Unpin + Sink<((), TaskType)> + Send + Sync + 'static,
183    G: Hash + Eq + Clone + Debug + Send + Sync + 'static,
184{
185    #[inline]
186    pub(crate) fn new(inner: Spawner<'a, Item, Tx, G, ()>, name: G) -> Self {
187        Self {
188            inner: GroupSpawner {
189                inner,
190                name: Some(name),
191            },
192        }
193    }
194
195    #[inline]
196    pub async fn result(mut self) -> Result<Item::Output, Error<Item>>
197    where
198        Item: Future + Send + 'static,
199        Item::Output: Send + 'static,
200    {
201        if self.inner.inner.sink.is_full() {
202            return Err(Error::TrySendError(ErrorType::Full(
203                self.inner.inner.item.take(),
204            )));
205        }
206        self.inner.result().await
207    }
208}
209
210impl<Item, Tx, G> Future for TryGroupSpawner<'_, Item, Tx, G>
211where
212    Item: Future + Send + 'static,
213    Item::Output: Send + 'static,
214    Tx: Clone + Unpin + Sink<((), TaskType)> + Send + Sync + 'static,
215    G: Hash + Eq + Clone + Debug + Send + Sync + 'static,
216{
217    type Output = Result<(), Error<Item>>;
218
219    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
220        let this = self.get_mut();
221
222        if this.inner.inner.sink.is_full() {
223            return Poll::Ready(Err(Error::TrySendError(ErrorType::Full(
224                this.inner.inner.item.take(),
225            ))));
226        }
227
228        this.inner.poll(cx)
229    }
230}
231
232pub struct Spawner<'a, Item, Tx, G, D> {
233    sink: &'a TaskExecQueue<Tx, G, D>,
234    item: Option<Item>,
235    d: Option<D>,
236    quickly: bool,
237    is_pending: bool,
238}
239
240impl<'a, Item, Tx, G, D> Unpin for Spawner<'a, Item, Tx, G, D> {}
241
242impl<'a, Item, Tx, G> Spawner<'a, Item, Tx, G, ()>
243where
244    Tx: Clone + Unpin + Sink<((), TaskType)> + Send + Sync + 'static,
245    G: Hash + Eq + Clone + Debug + Send + Sync + 'static,
246{
247    #[inline]
248    pub fn group(self, name: G) -> GroupSpawner<'a, Item, Tx, G>
249    where
250        Item: Future + Send + 'static,
251        Item::Output: Send + 'static,
252    {
253        let fut = GroupSpawner::new(self, name);
254        assert_future::<Result<(), _>, _>(fut)
255    }
256}
257
258impl<'a, Item, Tx, G, D> Spawner<'a, Item, Tx, G, D>
259where
260    Tx: Clone + Unpin + Sink<(D, TaskType)> + Send + Sync + 'static,
261    G: Hash + Eq + Clone + Debug + Send + Sync + 'static,
262{
263    #[inline]
264    pub(crate) fn new(sink: &'a TaskExecQueue<Tx, G, D>, item: Item, d: D) -> Self {
265        Self {
266            sink,
267            item: Some(item),
268            d: Some(d),
269            quickly: false,
270            is_pending: false,
271        }
272    }
273
274    #[inline]
275    pub fn quickly(mut self) -> Self {
276        self.quickly = true;
277        self
278    }
279
280    #[inline]
281    pub async fn result(mut self) -> Result<Item::Output, Error<Item>>
282    where
283        Item: Future + Send + 'static,
284        Item::Output: Send + 'static,
285    {
286        if self.sink.is_closed() {
287            return Err(Error::SendError(ErrorType::Closed(self.item.take())));
288        }
289
290        if !self.quickly && self.sink.is_full() {
291            let w = Arc::new(AtomicWaker::new());
292            self.sink.waiting_wakers.push(w.clone());
293            PendingOnce::new(w).await;
294        }
295
296        let task = self
297            .item
298            .take()
299            .expect("polled Feed after completion, task is None!");
300        let d = self
301            .d
302            .take()
303            .expect("polled Feed after completion, d is None!");
304
305        let (res_tx, res_rx) = oneshot::channel();
306        let waiting_count = self.sink.waiting_count.clone();
307        let waiting_wakers = self.sink.waiting_wakers.clone();
308        let task = async move {
309            waiting_count.dec();
310            if let Some(w) = waiting_wakers.pop() {
311                w.wake();
312            }
313            let output = task.await;
314            if let Err(_e) = res_tx.send(output) {
315                log::warn!("send result failed");
316            }
317        };
318
319        self.sink.waiting_count.inc();
320        if self
321            .sink
322            .tx
323            .clone()
324            .send((d, Box::new(Box::pin(task))))
325            .await
326            .is_err()
327        {
328            self.sink.waiting_count.dec();
329            return Err(Error::SendError(ErrorType::Closed(None)));
330        }
331        res_rx.await.map_err(|_| {
332            self.sink.waiting_count.dec();
333            Error::RecvResultError
334        })
335    }
336}
337
338impl<Item, Tx, G, D> Future for Spawner<'_, Item, Tx, G, D>
339where
340    Item: Future + Send + 'static,
341    Item::Output: Send + 'static,
342    Tx: Clone + Unpin + Sink<(D, TaskType)> + Send + Sync + 'static,
343    G: Hash + Eq + Clone + Debug + Send + Sync + 'static,
344{
345    type Output = Result<(), Error<Item>>;
346
347    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
348        let this = self.get_mut();
349
350        if this.sink.is_closed() && !this.is_pending {
351            return Poll::Ready(Err(Error::SendError(ErrorType::Closed(this.item.take()))));
352        }
353
354        if !this.quickly && this.sink.is_full() {
355            let w = Arc::new(AtomicWaker::new());
356            w.register(cx.waker());
357            this.sink.waiting_wakers.push(w);
358            this.is_pending = true;
359            return Poll::Pending;
360        }
361
362        let task = match this.item.take() {
363            Some(task) => task,
364            None => {
365                log::error!("polled Feed after completion, task is None!");
366                return Poll::Ready(Ok(()));
367            }
368        };
369
370        let d = match this.d.take() {
371            Some(d) => d,
372            None => {
373                log::error!("polled Feed after completion, d is None!");
374                return Poll::Ready(Ok(()));
375            }
376        };
377
378        let mut tx = this.sink.tx.clone();
379        let mut sink = Pin::new(&mut tx);
380
381        let waiting_count = this.sink.waiting_count.clone();
382        let waiting_wakers = this.sink.waiting_wakers.clone();
383        let task = async move {
384            waiting_count.dec();
385            if let Some(w) = waiting_wakers.pop() {
386                w.wake();
387            }
388            let _ = task.await;
389        };
390        this.sink.waiting_count.inc();
391        sink.as_mut()
392            .start_send((d, Box::new(Box::pin(task))))
393            .map_err(|_e| {
394                this.sink.waiting_count.dec();
395                Error::SendError(ErrorType::Closed(None))
396            })?;
397        Poll::Ready(Ok(()))
398    }
399}
400
401pub struct TrySpawner<'a, Item, Tx, G, D> {
402    inner: Spawner<'a, Item, Tx, G, D>,
403}
404
405impl<'a, Item, Tx, G, D> Unpin for TrySpawner<'a, Item, Tx, G, D> {}
406
407impl<'a, Item, Tx, G> TrySpawner<'a, Item, Tx, G, ()>
408where
409    Tx: Clone + Unpin + Sink<((), TaskType)> + Send + Sync + 'static,
410    G: Hash + Eq + Clone + Debug + Send + Sync + 'static,
411{
412    #[inline]
413    pub fn group(self, name: G) -> TryGroupSpawner<'a, Item, Tx, G>
414    where
415        Item: Future + Send + 'static,
416        Item::Output: Send + 'static,
417    {
418        let fut = TryGroupSpawner::new(self.inner, name);
419        assert_future::<Result<(), _>, _>(fut)
420    }
421}
422
423impl<'a, Item, Tx, G, D> TrySpawner<'a, Item, Tx, G, D>
424where
425    Tx: Clone + Unpin + Sink<(D, TaskType)> + Send + Sync + 'static,
426    G: Hash + Eq + Clone + Debug + Send + Sync + 'static,
427{
428    #[inline]
429    pub(crate) fn new(sink: &'a TaskExecQueue<Tx, G, D>, item: Item, d: D) -> Self {
430        Self {
431            inner: Spawner {
432                sink,
433                item: Some(item),
434                d: Some(d),
435                quickly: false,
436                is_pending: false,
437            },
438        }
439    }
440
441    #[inline]
442    pub fn quickly(mut self) -> Self {
443        self.inner.quickly = true;
444        self
445    }
446
447    #[inline]
448    pub async fn result(mut self) -> Result<Item::Output, Error<Item>>
449    where
450        Item: Future + Send + 'static,
451        Item::Output: Send + 'static,
452    {
453        if self.inner.sink.is_full() {
454            return Err(Error::TrySendError(ErrorType::Full(self.inner.item.take())));
455        }
456        self.inner.result().await
457    }
458}
459
460impl<Item, Tx, G, D> Future for TrySpawner<'_, Item, Tx, G, D>
461where
462    Item: Future + Send + 'static,
463    Item::Output: Send + 'static,
464    Tx: Clone + Unpin + Sink<(D, TaskType)> + Send + Sync + 'static,
465    G: Hash + Eq + Clone + Debug + Send + Sync + 'static,
466{
467    type Output = Result<(), Error<Item>>;
468
469    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
470        let this = self.get_mut();
471        if this.inner.sink.is_full() {
472            return Poll::Ready(Err(Error::TrySendError(ErrorType::Full(
473                this.inner.item.take(),
474            ))));
475        }
476        this.inner.poll(cx)
477    }
478}