task_executor/
spawner.rs

1use std::fmt::Debug;
2use std::hash::Hash;
3use std::pin::Pin;
4use std::task::{Context, Poll};
5
6use futures::{Future, Sink, SinkExt};
7use futures::channel::oneshot;
8
9use crate::TaskType;
10
11use super::{assert_future, Error, ErrorType, Executor};
12
13pub struct GroupSpawner<'a, Item, Tx, G> {
14    inner: Spawner<'a, Item, Tx, G, ()>,
15    name: Option<G>,
16}
17
18impl<Item, Tx, G> Unpin for GroupSpawner<'_, Item, Tx, G> {}
19
20impl<'a, Item, Tx, G> GroupSpawner<'a, Item, Tx, G>
21    where
22        Tx: Clone + Unpin + Sink<((), TaskType)> + Send + Sync + 'static,
23        G: Hash + Eq + Clone + Debug + Send + Sync + 'static,
24{
25    #[inline]
26    pub(crate) fn new(inner: Spawner<'a, Item, Tx, G, ()>, name: G) -> Self {
27        Self {
28            inner,
29            name: Some(name),
30        }
31    }
32
33    #[inline]
34    pub async fn result(mut self) -> Result<Item::Output, Error<Item>>
35        where
36            Item: Future + Send + 'static,
37            Item::Output: Send + 'static,
38    {
39        let task = match self.inner.item.take() {
40            Some(task) => task,
41            None => {
42                log::error!("polled Feed after completion, task is None!");
43                return Err(Error::SendError(ErrorType::Closed(None)));
44            }
45        };
46
47        let name = match self.name.take() {
48            Some(name) => name,
49            None => {
50                log::error!("polled Feed after completion, name is None!");
51                return Err(Error::SendError(ErrorType::Closed(None)));
52            }
53        };
54
55        if self.inner.sink.is_closed() {
56            return Err(Error::SendError(ErrorType::Closed(Some(task))));
57        }
58
59        let (res_tx, res_rx) = oneshot::channel();
60        let waiting_count = self.inner.sink.waiting_count.clone();
61        let task = async move {
62            waiting_count.dec();
63            let output = task.await;
64            if let Err(_e) = res_tx.send(output) {
65                log::warn!("send result failed");
66            }
67        };
68        self.inner.sink.waiting_count.inc();
69
70        if let Err(_e) = self
71            .inner
72            .sink
73            .group_send(name, Box::new(Box::pin(task)))
74            .await
75        {
76            self.inner.sink.waiting_count.dec();
77            Err(Error::SendError(ErrorType::Closed(None)))
78        } else {
79            res_rx.await.map_err(|_| {
80                self.inner.sink.waiting_count.dec();
81                Error::RecvResultError
82            })
83        }
84    }
85}
86
87impl<Item, Tx, G> Future for GroupSpawner<'_, Item, Tx, G>
88    where
89        Item: Future + Send + 'static,
90        Item::Output: Send + 'static,
91        Tx: Clone + Unpin + Sink<((), TaskType)> + Send + Sync + 'static,
92        G: Hash + Eq + Clone + Debug + Send + Sync + 'static,
93{
94    type Output = Result<(), Error<Item>>;
95
96    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
97        let this = self.get_mut();
98        let task = match this.inner.item.take() {
99            Some(task) => task,
100            None => {
101                log::error!("polled Feed after completion, task is None!");
102                return Poll::Ready(Ok(()));
103            }
104        };
105
106        let name = match this.name.take() {
107            Some(name) => name,
108            None => {
109                log::error!("polled Feed after completion, name is None!");
110                return Poll::Ready(Ok(()));
111            }
112        };
113
114        if this.inner.sink.is_closed() {
115            return Poll::Ready(Err(Error::SendError(ErrorType::Closed(Some(task)))));
116        }
117        let waiting_count = this.inner.sink.waiting_count.clone();
118        let task = async move {
119            waiting_count.dec();
120            let _ = task.await;
121        };
122        this.inner.sink.waiting_count.inc();
123        let mut group_send = this
124            .inner
125            .sink
126            .group_send(name, Box::new(Box::pin(task)))
127            .boxed();
128        use futures_lite::FutureExt;
129        if (futures::ready!(group_send.poll(cx))).is_err() {
130            this.inner.sink.waiting_count.dec();
131            Poll::Ready(Err(Error::SendError(ErrorType::Closed(None))))
132        } else {
133            Poll::Ready(Ok(()))
134        }
135    }
136}
137
138pub struct Spawner<'a, Item, Tx, G, D> {
139    sink: &'a Executor<Tx, G, D>,
140    item: Option<Item>,
141    d: Option<D>,
142}
143
144impl<'a, Item, Tx, G, D> Unpin for Spawner<'a, Item, Tx, G, D> {}
145
146impl<'a, Item, Tx, G> Spawner<'a, Item, Tx, G, ()>
147    where
148        Tx: Clone + Unpin + Sink<((), TaskType)> + Send + Sync + 'static,
149        G: Hash + Eq + Clone + Debug + Send + Sync + 'static,
150{
151    #[inline]
152    pub fn group(self, name: G) -> GroupSpawner<'a, Item, Tx, G>
153        where
154            Item: Future + Send + 'static,
155            Item::Output: Send + 'static,
156    {
157        let fut = GroupSpawner::new(self, name);
158        assert_future::<Result<(), _>, _>(fut)
159    }
160}
161
162impl<'a, Item, Tx, G, D> Spawner<'a, Item, Tx, G, D>
163    where
164        Tx: Clone + Unpin + Sink<(D, TaskType)> + Send + Sync + 'static,
165        G: Hash + Eq + Clone + Debug + Send + Sync + 'static,
166{
167    #[inline]
168    pub(crate) fn new(sink: &'a Executor<Tx, G, D>, item: Item, d: D) -> Self {
169        Self {
170            sink,
171            item: Some(item),
172            d: Some(d),
173        }
174    }
175
176    #[inline]
177    pub async fn result(mut self) -> Result<Item::Output, Error<Item>>
178        where
179            Item: Future + Send + 'static,
180            Item::Output: Send + 'static,
181    {
182        let task = self
183            .item
184            .take()
185            .expect("polled Feed after completion, task is None!");
186        let d = self
187            .d
188            .take()
189            .expect("polled Feed after completion, d is None!");
190
191        if self.sink.is_closed() {
192            return Err(Error::SendError(ErrorType::Closed(Some(task))));
193        }
194
195        let (res_tx, res_rx) = oneshot::channel();
196        let waiting_count = self.sink.waiting_count.clone();
197        let task = async move {
198            waiting_count.dec();
199            let output = task.await;
200            if let Err(_e) = res_tx.send(output) {
201                log::warn!("send result failed");
202            }
203        };
204        self.sink.waiting_count.inc();
205
206        if self
207            .sink
208            .tx
209            .clone()
210            .send((d, Box::new(Box::pin(task))))
211            .await
212            .is_err()
213        {
214            self.sink.waiting_count.dec();
215            return Err(Error::SendError(ErrorType::Closed(None)));
216        }
217        res_rx.await.map_err(|_| {
218            self.sink.waiting_count.dec();
219            Error::RecvResultError
220        })
221    }
222}
223
224impl<Item, Tx, G, D> Future for Spawner<'_, Item, Tx, G, D>
225    where
226        Item: Future + Send + 'static,
227        Item::Output: Send + 'static,
228        Tx: Clone + Unpin + Sink<(D, TaskType)> + Send + Sync + 'static,
229        G: Hash + Eq + Clone + Debug + Send + Sync + 'static,
230{
231    type Output = Result<(), Error<Item>>;
232
233    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
234        let this = self.get_mut();
235        let task = match this.item.take() {
236            Some(task) => task,
237            None => {
238                log::error!("polled Feed after completion, task is None!");
239                return Poll::Ready(Ok(()));
240            }
241        };
242
243        let d = match this.d.take() {
244            Some(d) => d,
245            None => {
246                log::error!("polled Feed after completion, d is None!");
247                return Poll::Ready(Ok(()));
248            }
249        };
250
251        if this.sink.is_closed() {
252            return Poll::Ready(Err(Error::SendError(ErrorType::Closed(Some(task)))));
253        }
254
255        let mut tx = this.sink.tx.clone();
256        let mut sink = Pin::new(&mut tx);
257        futures::ready!(sink.as_mut().poll_ready(cx))
258            .map_err(|_| Error::SendError(ErrorType::Closed(None)))?;
259        let waiting_count = this.sink.waiting_count.clone();
260        let task = async move {
261            waiting_count.dec();
262            let _ = task.await;
263        };
264        this.sink.waiting_count.inc();
265        sink.as_mut()
266            .start_send((d, Box::new(Box::pin(task))))
267            .map_err(|_e| {
268                this.sink.waiting_count.dec();
269                Error::SendError(ErrorType::Closed(None))
270            })?;
271        Poll::Ready(Ok(()))
272    }
273}