// task_executor/local_spawner.rs

1use std::fmt::Debug;
2use std::hash::Hash;
3use std::pin::Pin;
4use std::task::{Context, Poll};
5
6use futures::{Future, Sink, SinkExt};
7use futures::channel::oneshot;
8
9use crate::LocalTaskType;
10
11use super::{assert_future, Error, ErrorType, LocalExecutor};
12
13pub struct LocalGroupSpawner<'a, Item, Tx, G> {
14    inner: LocalSpawner<'a, Item, Tx, G, ()>,
15    name: Option<G>,
16}
17
18impl<Item, Tx, G> Unpin for LocalGroupSpawner<'_, Item, Tx, G> {}
19
20impl<'a, Item, Tx, G> LocalGroupSpawner<'a, Item, Tx, G>
21    where
22        Tx: Clone + Unpin + Sink<((), LocalTaskType)> + Sync + 'static,
23        G: Hash + Eq + Clone + Debug + Sync + 'static,
24{
25    #[inline]
26    pub(crate) fn new(inner: LocalSpawner<'a, Item, Tx, G, ()>, name: G) -> Self {
27        Self {
28            inner,
29            name: Some(name),
30        }
31    }
32
33    #[inline]
34    pub async fn result(mut self) -> Result<Item::Output, Error<Item>>
35        where
36            Item: Future + 'static,
37            Item::Output: 'static,
38    {
39        let task = match self.inner.item.take() {
40            Some(task) => task,
41            None => {
42                log::error!("polled Feed after completion, task is None!");
43                return Err(Error::SendError(ErrorType::Closed(None)));
44            }
45        };
46
47        let name = match self.name.take() {
48            Some(name) => name,
49            None => {
50                log::error!("polled Feed after completion, name is None!");
51                return Err(Error::SendError(ErrorType::Closed(None)));
52            }
53        };
54
55        if self.inner.sink.is_closed() {
56            return Err(Error::SendError(ErrorType::Closed(Some(task))));
57        }
58
59        let (res_tx, res_rx) = oneshot::channel();
60        let waiting_count = self.inner.sink.waiting_count.clone();
61        let task = async move {
62            waiting_count.dec();
63            let output = task.await;
64            if let Err(_e) = res_tx.send(output) {
65                log::warn!("send result failed");
66            }
67        };
68        self.inner.sink.waiting_count.inc();
69
70        if let Err(_e) = self
71            .inner
72            .sink
73            .group_send(name, Box::new(Box::pin(task)))
74            .await
75        {
76            self.inner.sink.waiting_count.dec();
77            Err(Error::SendError(ErrorType::Closed(None)))
78        } else {
79            res_rx.await.map_err(|_| {
80                self.inner.sink.waiting_count.dec();
81                Error::RecvResultError
82            })
83        }
84    }
85}
86
87impl<Item, Tx, G> Future for LocalGroupSpawner<'_, Item, Tx, G>
88    where
89        Item: Future + 'static,
90        Item::Output: 'static,
91        Tx: Clone + Unpin + Sink<((), LocalTaskType)> + Sync + 'static,
92        G: Hash + Eq + Clone + Debug + Sync + 'static,
93{
94    type Output = Result<(), Error<Item>>;
95
96    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
97        let this = self.get_mut();
98        let task = match this.inner.item.take() {
99            Some(task) => task,
100            None => {
101                log::error!("polled Feed after completion, task is None!");
102                return Poll::Ready(Ok(()));
103            }
104        };
105
106        let name = match this.name.take() {
107            Some(name) => name,
108            None => {
109                log::error!("polled Feed after completion, name is None!");
110                return Poll::Ready(Ok(()));
111            }
112        };
113
114        if this.inner.sink.is_closed() {
115            return Poll::Ready(Err(Error::SendError(ErrorType::Closed(Some(task)))));
116        }
117        let waiting_count = this.inner.sink.waiting_count.clone();
118        let task = async move {
119            waiting_count.dec();
120            let _ = task.await;
121        };
122        this.inner.sink.waiting_count.inc();
123
124        let mut group_send = this
125            .inner
126            .sink
127            .group_send(name, Box::new(Box::pin(task)))
128            .boxed_local();
129        use futures_lite::FutureExt;
130        if (futures::ready!(group_send.poll(cx))).is_err() {
131            this.inner.sink.waiting_count.dec();
132            Poll::Ready(Err(Error::SendError(ErrorType::Closed(None))))
133        } else {
134            Poll::Ready(Ok(()))
135        }
136    }
137}
138
139pub struct LocalSpawner<'a, Item, Tx, G, D> {
140    sink: &'a LocalExecutor<Tx, G, D>,
141    item: Option<Item>,
142    d: Option<D>,
143}
144
145impl<'a, Item, Tx, G, D> Unpin for LocalSpawner<'a, Item, Tx, G, D> {}
146
147impl<'a, Item, Tx, G> LocalSpawner<'a, Item, Tx, G, ()>
148    where
149        Tx: Clone + Unpin + Sink<((), LocalTaskType)> + Sync + 'static,
150        G: Hash + Eq + Clone + Debug + Sync + 'static,
151{
152    #[inline]
153    pub fn group(self, name: G) -> LocalGroupSpawner<'a, Item, Tx, G>
154        where
155            Item: Future + 'static,
156            Item::Output: 'static,
157    {
158        let fut = LocalGroupSpawner::new(self, name);
159        assert_future::<Result<(), _>, _>(fut)
160    }
161}
162
163impl<'a, Item, Tx, G, D> LocalSpawner<'a, Item, Tx, G, D>
164    where
165        Tx: Clone + Unpin + Sink<(D, LocalTaskType)> + Sync + 'static,
166        G: Hash + Eq + Clone + Debug + Sync + 'static,
167{
168    #[inline]
169    pub(crate) fn new(sink: &'a LocalExecutor<Tx, G, D>, item: Item, d: D) -> Self {
170        Self {
171            sink,
172            item: Some(item),
173            d: Some(d),
174        }
175    }
176
177    #[inline]
178    pub async fn result(mut self) -> Result<Item::Output, Error<Item>>
179        where
180            Item: Future + 'static,
181            Item::Output: 'static,
182    {
183        let task = self
184            .item
185            .take()
186            .expect("polled Feed after completion, task is None!");
187        let d = self
188            .d
189            .take()
190            .expect("polled Feed after completion, d is None!");
191
192        if self.sink.is_closed() {
193            return Err(Error::SendError(ErrorType::Closed(Some(task))));
194        }
195
196        let (res_tx, res_rx) = oneshot::channel();
197        let waiting_count = self.sink.waiting_count.clone();
198        let task = async move {
199            waiting_count.dec();
200            let output = task.await;
201            if let Err(_e) = res_tx.send(output) {
202                log::warn!("send result failed");
203            }
204        };
205        self.sink.waiting_count.inc();
206
207        if self
208            .sink
209            .tx
210            .clone()
211            .send((d, Box::new(Box::pin(task))))
212            .await
213            .is_err()
214        {
215            self.sink.waiting_count.dec();
216            return Err(Error::SendError(ErrorType::Closed(None)));
217        }
218        res_rx.await.map_err(|_| {
219            self.sink.waiting_count.dec();
220            Error::RecvResultError
221        })
222    }
223}
224
225impl<Item, Tx, G, D> Future for LocalSpawner<'_, Item, Tx, G, D>
226    where
227        Item: Future + 'static,
228        Item::Output: 'static,
229        Tx: Clone + Unpin + Sink<(D, LocalTaskType)> + Sync + 'static,
230        G: Hash + Eq + Clone + Debug + Sync + 'static,
231{
232    type Output = Result<(), Error<Item>>;
233
234    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
235        let this = self.get_mut();
236        let task = match this.item.take() {
237            Some(task) => task,
238            None => {
239                log::error!("polled Feed after completion, task is None!");
240                return Poll::Ready(Ok(()));
241            }
242        };
243
244        let d = match this.d.take() {
245            Some(d) => d,
246            None => {
247                log::error!("polled Feed after completion, d is None!");
248                return Poll::Ready(Ok(()));
249            }
250        };
251
252        if this.sink.is_closed() {
253            return Poll::Ready(Err(Error::SendError(ErrorType::Closed(Some(task)))));
254        }
255
256        let mut tx = this.sink.tx.clone();
257        let mut sink = Pin::new(&mut tx);
258        futures::ready!(sink.as_mut().poll_ready(cx))
259            .map_err(|_| Error::SendError(ErrorType::Closed(None)))?;
260        let waiting_count = this.sink.waiting_count.clone();
261        let task = async move {
262            waiting_count.dec();
263            let _ = task.await;
264        };
265        this.sink.waiting_count.inc();
266        sink.as_mut()
267            .start_send((d, Box::new(Box::pin(task))))
268            .map_err(|_e| {
269                this.sink.waiting_count.dec();
270                Error::SendError(ErrorType::Closed(None))
271            })?;
272        Poll::Ready(Ok(()))
273    }
274}