1use std::fmt::Debug;
2use std::hash::Hash;
3use std::pin::Pin;
4use std::task::{Context, Poll};
5
6use futures::{Future, Sink, SinkExt};
7use futures::channel::oneshot;
8
9use crate::LocalTaskType;
10
11use super::{assert_future, Error, ErrorType, LocalExecutor};
12
13pub struct LocalGroupSpawner<'a, Item, Tx, G> {
14 inner: LocalSpawner<'a, Item, Tx, G, ()>,
15 name: Option<G>,
16}
17
18impl<Item, Tx, G> Unpin for LocalGroupSpawner<'_, Item, Tx, G> {}
19
20impl<'a, Item, Tx, G> LocalGroupSpawner<'a, Item, Tx, G>
21 where
22 Tx: Clone + Unpin + Sink<((), LocalTaskType)> + Sync + 'static,
23 G: Hash + Eq + Clone + Debug + Sync + 'static,
24{
25 #[inline]
26 pub(crate) fn new(inner: LocalSpawner<'a, Item, Tx, G, ()>, name: G) -> Self {
27 Self {
28 inner,
29 name: Some(name),
30 }
31 }
32
33 #[inline]
34 pub async fn result(mut self) -> Result<Item::Output, Error<Item>>
35 where
36 Item: Future + 'static,
37 Item::Output: 'static,
38 {
39 let task = match self.inner.item.take() {
40 Some(task) => task,
41 None => {
42 log::error!("polled Feed after completion, task is None!");
43 return Err(Error::SendError(ErrorType::Closed(None)));
44 }
45 };
46
47 let name = match self.name.take() {
48 Some(name) => name,
49 None => {
50 log::error!("polled Feed after completion, name is None!");
51 return Err(Error::SendError(ErrorType::Closed(None)));
52 }
53 };
54
55 if self.inner.sink.is_closed() {
56 return Err(Error::SendError(ErrorType::Closed(Some(task))));
57 }
58
59 let (res_tx, res_rx) = oneshot::channel();
60 let waiting_count = self.inner.sink.waiting_count.clone();
61 let task = async move {
62 waiting_count.dec();
63 let output = task.await;
64 if let Err(_e) = res_tx.send(output) {
65 log::warn!("send result failed");
66 }
67 };
68 self.inner.sink.waiting_count.inc();
69
70 if let Err(_e) = self
71 .inner
72 .sink
73 .group_send(name, Box::new(Box::pin(task)))
74 .await
75 {
76 self.inner.sink.waiting_count.dec();
77 Err(Error::SendError(ErrorType::Closed(None)))
78 } else {
79 res_rx.await.map_err(|_| {
80 self.inner.sink.waiting_count.dec();
81 Error::RecvResultError
82 })
83 }
84 }
85}
86
87impl<Item, Tx, G> Future for LocalGroupSpawner<'_, Item, Tx, G>
88 where
89 Item: Future + 'static,
90 Item::Output: 'static,
91 Tx: Clone + Unpin + Sink<((), LocalTaskType)> + Sync + 'static,
92 G: Hash + Eq + Clone + Debug + Sync + 'static,
93{
94 type Output = Result<(), Error<Item>>;
95
96 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
97 let this = self.get_mut();
98 let task = match this.inner.item.take() {
99 Some(task) => task,
100 None => {
101 log::error!("polled Feed after completion, task is None!");
102 return Poll::Ready(Ok(()));
103 }
104 };
105
106 let name = match this.name.take() {
107 Some(name) => name,
108 None => {
109 log::error!("polled Feed after completion, name is None!");
110 return Poll::Ready(Ok(()));
111 }
112 };
113
114 if this.inner.sink.is_closed() {
115 return Poll::Ready(Err(Error::SendError(ErrorType::Closed(Some(task)))));
116 }
117 let waiting_count = this.inner.sink.waiting_count.clone();
118 let task = async move {
119 waiting_count.dec();
120 let _ = task.await;
121 };
122 this.inner.sink.waiting_count.inc();
123
124 let mut group_send = this
125 .inner
126 .sink
127 .group_send(name, Box::new(Box::pin(task)))
128 .boxed_local();
129 use futures_lite::FutureExt;
130 if (futures::ready!(group_send.poll(cx))).is_err() {
131 this.inner.sink.waiting_count.dec();
132 Poll::Ready(Err(Error::SendError(ErrorType::Closed(None))))
133 } else {
134 Poll::Ready(Ok(()))
135 }
136 }
137}
138
/// Pending submission of one task (`item`) to a [`LocalExecutor`], together
/// with the extra value `d` that is sent paired with it.
///
/// Await it to submit the task without waiting for its output, or call
/// `result` to submit it and wait for the output.
pub struct LocalSpawner<'a, Item, Tx, G, D> {
    /// Executor the task is submitted to.
    sink: &'a LocalExecutor<Tx, G, D>,
    /// Task to submit; `None` once it has been taken for submission.
    item: Option<Item>,
    /// Value paired with the task in the message sent on the executor's `tx`;
    /// `None` once taken.
    d: Option<D>,
}
144
145impl<'a, Item, Tx, G, D> Unpin for LocalSpawner<'a, Item, Tx, G, D> {}
146
147impl<'a, Item, Tx, G> LocalSpawner<'a, Item, Tx, G, ()>
148 where
149 Tx: Clone + Unpin + Sink<((), LocalTaskType)> + Sync + 'static,
150 G: Hash + Eq + Clone + Debug + Sync + 'static,
151{
152 #[inline]
153 pub fn group(self, name: G) -> LocalGroupSpawner<'a, Item, Tx, G>
154 where
155 Item: Future + 'static,
156 Item::Output: 'static,
157 {
158 let fut = LocalGroupSpawner::new(self, name);
159 assert_future::<Result<(), _>, _>(fut)
160 }
161}
162
163impl<'a, Item, Tx, G, D> LocalSpawner<'a, Item, Tx, G, D>
164 where
165 Tx: Clone + Unpin + Sink<(D, LocalTaskType)> + Sync + 'static,
166 G: Hash + Eq + Clone + Debug + Sync + 'static,
167{
168 #[inline]
169 pub(crate) fn new(sink: &'a LocalExecutor<Tx, G, D>, item: Item, d: D) -> Self {
170 Self {
171 sink,
172 item: Some(item),
173 d: Some(d),
174 }
175 }
176
177 #[inline]
178 pub async fn result(mut self) -> Result<Item::Output, Error<Item>>
179 where
180 Item: Future + 'static,
181 Item::Output: 'static,
182 {
183 let task = self
184 .item
185 .take()
186 .expect("polled Feed after completion, task is None!");
187 let d = self
188 .d
189 .take()
190 .expect("polled Feed after completion, d is None!");
191
192 if self.sink.is_closed() {
193 return Err(Error::SendError(ErrorType::Closed(Some(task))));
194 }
195
196 let (res_tx, res_rx) = oneshot::channel();
197 let waiting_count = self.sink.waiting_count.clone();
198 let task = async move {
199 waiting_count.dec();
200 let output = task.await;
201 if let Err(_e) = res_tx.send(output) {
202 log::warn!("send result failed");
203 }
204 };
205 self.sink.waiting_count.inc();
206
207 if self
208 .sink
209 .tx
210 .clone()
211 .send((d, Box::new(Box::pin(task))))
212 .await
213 .is_err()
214 {
215 self.sink.waiting_count.dec();
216 return Err(Error::SendError(ErrorType::Closed(None)));
217 }
218 res_rx.await.map_err(|_| {
219 self.sink.waiting_count.dec();
220 Error::RecvResultError
221 })
222 }
223}
224
225impl<Item, Tx, G, D> Future for LocalSpawner<'_, Item, Tx, G, D>
226 where
227 Item: Future + 'static,
228 Item::Output: 'static,
229 Tx: Clone + Unpin + Sink<(D, LocalTaskType)> + Sync + 'static,
230 G: Hash + Eq + Clone + Debug + Sync + 'static,
231{
232 type Output = Result<(), Error<Item>>;
233
234 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
235 let this = self.get_mut();
236 let task = match this.item.take() {
237 Some(task) => task,
238 None => {
239 log::error!("polled Feed after completion, task is None!");
240 return Poll::Ready(Ok(()));
241 }
242 };
243
244 let d = match this.d.take() {
245 Some(d) => d,
246 None => {
247 log::error!("polled Feed after completion, d is None!");
248 return Poll::Ready(Ok(()));
249 }
250 };
251
252 if this.sink.is_closed() {
253 return Poll::Ready(Err(Error::SendError(ErrorType::Closed(Some(task)))));
254 }
255
256 let mut tx = this.sink.tx.clone();
257 let mut sink = Pin::new(&mut tx);
258 futures::ready!(sink.as_mut().poll_ready(cx))
259 .map_err(|_| Error::SendError(ErrorType::Closed(None)))?;
260 let waiting_count = this.sink.waiting_count.clone();
261 let task = async move {
262 waiting_count.dec();
263 let _ = task.await;
264 };
265 this.sink.waiting_count.inc();
266 sink.as_mut()
267 .start_send((d, Box::new(Box::pin(task))))
268 .map_err(|_e| {
269 this.sink.waiting_count.dec();
270 Error::SendError(ErrorType::Closed(None))
271 })?;
272 Poll::Ready(Ok(()))
273 }
274}