1use std::fmt::Debug;
2use std::hash::Hash;
3use std::pin::Pin;
4use std::rc::Rc;
5use std::task::{Context, Poll};
6
7use crate::local::LocalPendingOnce;
8use futures::channel::oneshot;
9use futures::task::AtomicWaker;
10use futures::{Future, Sink, SinkExt};
11use futures_lite::FutureExt;
12
13use crate::LocalTaskType;
14
15use super::{assert_future, Error, ErrorType, LocalTaskExecQueue};
16
17pub struct LocalGroupSpawner<'a, Item, Tx, G> {
18 inner: LocalSpawner<'a, Item, Tx, G, ()>,
19 name: Option<G>,
20}
21
22impl<Item, Tx, G> Unpin for LocalGroupSpawner<'_, Item, Tx, G> {}
23
24impl<'a, Item, Tx, G> LocalGroupSpawner<'a, Item, Tx, G>
25where
26 Tx: Clone + Unpin + Sink<((), LocalTaskType)> + 'static,
27 G: Hash + Eq + Clone + Debug + 'static,
28{
29 #[inline]
30 pub(crate) fn new(inner: LocalSpawner<'a, Item, Tx, G, ()>, name: G) -> Self {
31 Self {
32 inner,
33 name: Some(name),
34 }
35 }
36
37 #[inline]
38 pub fn quickly(mut self) -> Self {
39 self.inner.quickly = true;
40 self
41 }
42
43 #[inline]
44 pub async fn result(mut self) -> Result<Item::Output, Error<Item>>
45 where
46 Item: Future + 'static,
47 Item::Output: 'static,
48 {
49 if self.inner.sink.is_closed() {
50 return Err(Error::SendError(ErrorType::Closed(self.inner.item.take())));
51 }
52
53 if !self.inner.quickly && self.inner.sink.is_full() {
54 let w = Rc::new(AtomicWaker::new());
55 self.inner.sink.waiting_wakers.push(w.clone());
56 LocalPendingOnce::new(w).await;
57 }
58
59 let task = match self.inner.item.take() {
60 Some(task) => task,
61 None => {
62 log::error!("polled Feed after completion, task is None!");
63 return Err(Error::SendError(ErrorType::Closed(None)));
64 }
65 };
66
67 let name = match self.name.take() {
68 Some(name) => name,
69 None => {
70 log::error!("polled Feed after completion, name is None!");
71 return Err(Error::SendError(ErrorType::Closed(None)));
72 }
73 };
74
75 let (res_tx, res_rx) = oneshot::channel();
76 let waiting_count = self.inner.sink.waiting_count.clone();
77 let waiting_wakers = self.inner.sink.waiting_wakers.clone();
78 let task = async move {
79 waiting_count.dec();
80 if let Some(w) = waiting_wakers.pop() {
81 w.wake();
82 }
83 let output = task.await;
84 if let Err(_e) = res_tx.send(output) {
85 log::warn!("send result failed");
86 }
87 };
88 self.inner.sink.waiting_count.inc();
89
90 if let Err(_e) = self
91 .inner
92 .sink
93 .group_send(name, Box::new(Box::pin(task)))
94 .await
95 {
96 self.inner.sink.waiting_count.dec();
97 Err(Error::SendError(ErrorType::Closed(None)))
98 } else {
99 res_rx.await.map_err(|_| {
100 self.inner.sink.waiting_count.dec();
101 Error::RecvResultError
102 })
103 }
104 }
105}
106
107impl<Item, Tx, G> Future for LocalGroupSpawner<'_, Item, Tx, G>
108where
109 Item: Future + 'static,
110 Item::Output: 'static,
111 Tx: Clone + Unpin + Sink<((), LocalTaskType)> + 'static,
112 G: Hash + Eq + Clone + Debug + 'static,
113{
114 type Output = Result<(), Error<Item>>;
115
116 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
117 let this = self.get_mut();
118 if this.inner.sink.is_closed() && !this.inner.is_pending {
119 return Poll::Ready(Err(Error::SendError(ErrorType::Closed(
120 this.inner.item.take(),
121 ))));
122 }
123
124 if !this.inner.quickly && this.inner.sink.is_full() {
125 let w = Rc::new(AtomicWaker::new());
126 w.register(cx.waker());
127 this.inner.sink.waiting_wakers.push(w);
128 this.inner.is_pending = true;
129 return Poll::Pending;
130 }
131
132 let task = match this.inner.item.take() {
133 Some(task) => task,
134 None => {
135 log::error!("polled Feed after completion, task is None!");
136 return Poll::Ready(Ok(()));
137 }
138 };
139
140 let name = match this.name.take() {
141 Some(name) => name,
142 None => {
143 log::error!("polled Feed after completion, name is None!");
144 return Poll::Ready(Ok(()));
145 }
146 };
147
148 if this.inner.sink.is_closed() {
149 return Poll::Ready(Err(Error::SendError(ErrorType::Closed(Some(task)))));
150 }
151 let waiting_count = this.inner.sink.waiting_count.clone();
152 let waiting_wakers = this.inner.sink.waiting_wakers.clone();
153 let task = async move {
154 waiting_count.dec();
155 if let Some(w) = waiting_wakers.pop() {
156 w.wake();
157 }
158 let _ = task.await;
159 };
160 this.inner.sink.waiting_count.inc();
161
162 let mut group_send = this
163 .inner
164 .sink
165 .group_send(name, Box::new(Box::pin(task)))
166 .boxed_local();
167
168 if (futures::ready!(group_send.poll(cx))).is_err() {
169 this.inner.sink.waiting_count.dec();
170 Poll::Ready(Err(Error::SendError(ErrorType::Closed(None))))
171 } else {
172 Poll::Ready(Ok(()))
173 }
174 }
175}
176
177pub struct TryLocalGroupSpawner<'a, Item, Tx, G> {
178 inner: LocalGroupSpawner<'a, Item, Tx, G>,
179}
180
181impl<Item, Tx, G> Unpin for TryLocalGroupSpawner<'_, Item, Tx, G> {}
182
183impl<'a, Item, Tx, G> TryLocalGroupSpawner<'a, Item, Tx, G>
184where
185 Tx: Clone + Unpin + Sink<((), LocalTaskType)> + 'static,
186 G: Hash + Eq + Clone + Debug + 'static,
187{
188 #[inline]
189 pub(crate) fn new(inner: LocalSpawner<'a, Item, Tx, G, ()>, name: G) -> Self {
190 Self {
191 inner: LocalGroupSpawner {
192 inner,
193 name: Some(name),
194 },
195 }
196 }
197
198 #[inline]
199 pub async fn result(mut self) -> Result<Item::Output, Error<Item>>
200 where
201 Item: Future + 'static,
202 Item::Output: 'static,
203 {
204 if self.inner.inner.sink.is_full() {
205 return Err(Error::TrySendError(ErrorType::Full(
206 self.inner.inner.item.take(),
207 )));
208 }
209 self.inner.result().await
210 }
211}
212
213impl<Item, Tx, G> Future for TryLocalGroupSpawner<'_, Item, Tx, G>
214where
215 Item: Future + 'static,
216 Item::Output: 'static,
217 Tx: Clone + Unpin + Sink<((), LocalTaskType)> + 'static,
218 G: Hash + Eq + Clone + Debug + 'static,
219{
220 type Output = Result<(), Error<Item>>;
221
222 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
223 let this = self.get_mut();
224
225 if this.inner.inner.sink.is_full() {
226 return Poll::Ready(Err(Error::TrySendError(ErrorType::Full(
227 this.inner.inner.item.take(),
228 ))));
229 }
230
231 this.inner.poll(cx)
232 }
233}
234
235pub struct LocalSpawner<'a, Item, Tx, G, D> {
236 sink: &'a LocalTaskExecQueue<Tx, G, D>,
237 item: Option<Item>,
238 d: Option<D>,
239 quickly: bool,
240 is_pending: bool,
241}
242
243impl<'a, Item, Tx, G, D> Unpin for LocalSpawner<'a, Item, Tx, G, D> {}
244
245impl<'a, Item, Tx, G> LocalSpawner<'a, Item, Tx, G, ()>
246where
247 Tx: Clone + Unpin + Sink<((), LocalTaskType)> + 'static,
248 G: Hash + Eq + Clone + Debug + 'static,
249{
250 #[inline]
251 pub fn group(self, name: G) -> LocalGroupSpawner<'a, Item, Tx, G>
252 where
253 Item: Future + 'static,
254 Item::Output: 'static,
255 {
256 let fut = LocalGroupSpawner::new(self, name);
257 assert_future::<Result<(), _>, _>(fut)
258 }
259}
260
261impl<'a, Item, Tx, G, D> LocalSpawner<'a, Item, Tx, G, D>
262where
263 Tx: Clone + Unpin + Sink<(D, LocalTaskType)> + 'static,
264 G: Hash + Eq + Clone + Debug + 'static,
265{
266 #[inline]
267 pub(crate) fn new(sink: &'a LocalTaskExecQueue<Tx, G, D>, item: Item, d: D) -> Self {
268 Self {
269 sink,
270 item: Some(item),
271 d: Some(d),
272 quickly: false,
273 is_pending: false,
274 }
275 }
276
277 #[inline]
278 pub fn quickly(mut self) -> Self {
279 self.quickly = true;
280 self
281 }
282
283 #[inline]
284 pub async fn result(mut self) -> Result<Item::Output, Error<Item>>
285 where
286 Item: Future + 'static,
287 Item::Output: 'static,
288 {
289 if self.sink.is_closed() {
290 return Err(Error::SendError(ErrorType::Closed(self.item.take())));
291 }
292
293 if !self.quickly && self.sink.is_full() {
294 let w = Rc::new(AtomicWaker::new());
295 self.sink.waiting_wakers.push(w.clone());
296 LocalPendingOnce::new(w).await;
297 }
298
299 let task = self
300 .item
301 .take()
302 .expect("polled Feed after completion, task is None!");
303 let d = self
304 .d
305 .take()
306 .expect("polled Feed after completion, d is None!");
307
308 let (res_tx, res_rx) = oneshot::channel();
309 let waiting_count = self.sink.waiting_count.clone();
310 let waiting_wakers = self.sink.waiting_wakers.clone();
311 let task = async move {
312 waiting_count.dec();
313 if let Some(w) = waiting_wakers.pop() {
314 w.wake();
315 }
316 let output = task.await;
317 if let Err(_e) = res_tx.send(output) {
318 log::warn!("send result failed");
319 }
320 };
321 self.sink.waiting_count.inc();
322
323 if self
324 .sink
325 .tx
326 .clone()
327 .send((d, Box::new(Box::pin(task))))
328 .await
329 .is_err()
330 {
331 self.sink.waiting_count.dec();
332 return Err(Error::SendError(ErrorType::Closed(None)));
333 }
334 res_rx.await.map_err(|_| {
335 self.sink.waiting_count.dec();
336 Error::RecvResultError
337 })
338 }
339}
340
341impl<Item, Tx, G, D> Future for LocalSpawner<'_, Item, Tx, G, D>
342where
343 Item: Future + 'static,
344 Item::Output: 'static,
345 Tx: Clone + Unpin + Sink<(D, LocalTaskType)> + 'static,
346 G: Hash + Eq + Clone + Debug + 'static,
347{
348 type Output = Result<(), Error<Item>>;
349
350 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
351 let this = self.get_mut();
352
353 if this.sink.is_closed() && !this.is_pending {
354 return Poll::Ready(Err(Error::SendError(ErrorType::Closed(this.item.take()))));
355 }
356
357 if !this.quickly && this.sink.is_full() {
358 let w = Rc::new(AtomicWaker::new());
359 w.register(cx.waker());
360 this.sink.waiting_wakers.push(w);
361 this.is_pending = true;
362 return Poll::Pending;
363 }
364
365 let task = match this.item.take() {
366 Some(task) => task,
367 None => {
368 log::error!("polled Feed after completion, task is None!");
369 return Poll::Ready(Ok(()));
370 }
371 };
372
373 let d = match this.d.take() {
374 Some(d) => d,
375 None => {
376 log::error!("polled Feed after completion, d is None!");
377 return Poll::Ready(Ok(()));
378 }
379 };
380
381 let mut tx = this.sink.tx.clone();
382 let mut sink = Pin::new(&mut tx);
383 let waiting_count = this.sink.waiting_count.clone();
386 let waiting_wakers = this.sink.waiting_wakers.clone();
387 let task = async move {
388 waiting_count.dec();
389 if let Some(w) = waiting_wakers.pop() {
390 w.wake();
391 }
392 let _ = task.await;
393 };
394 this.sink.waiting_count.inc();
395 sink.as_mut()
396 .start_send((d, Box::new(Box::pin(task))))
397 .map_err(|_e| {
398 this.sink.waiting_count.dec();
399 Error::SendError(ErrorType::Closed(None))
400 })?;
401 Poll::Ready(Ok(()))
402 }
403}
404
405pub struct TryLocalSpawner<'a, Item, Tx, G, D> {
406 inner: LocalSpawner<'a, Item, Tx, G, D>,
407}
408
409impl<'a, Item, Tx, G, D> Unpin for TryLocalSpawner<'a, Item, Tx, G, D> {}
410
411impl<'a, Item, Tx, G> TryLocalSpawner<'a, Item, Tx, G, ()>
412where
413 Tx: Clone + Unpin + Sink<((), LocalTaskType)> + 'static,
414 G: Hash + Eq + Clone + Debug + 'static,
415{
416 #[inline]
417 pub fn group(self, name: G) -> TryLocalGroupSpawner<'a, Item, Tx, G>
418 where
419 Item: Future + 'static,
420 Item::Output: 'static,
421 {
422 let fut = TryLocalGroupSpawner::new(self.inner, name);
423 assert_future::<Result<(), _>, _>(fut)
424 }
425}
426
427impl<'a, Item, Tx, G, D> TryLocalSpawner<'a, Item, Tx, G, D>
428where
429 Tx: Clone + Unpin + Sink<(D, LocalTaskType)> + 'static,
430 G: Hash + Eq + Clone + Debug + 'static,
431{
432 #[inline]
433 pub(crate) fn new(sink: &'a LocalTaskExecQueue<Tx, G, D>, item: Item, d: D) -> Self {
434 Self {
435 inner: LocalSpawner {
436 sink,
437 item: Some(item),
438 d: Some(d),
439 quickly: false,
440 is_pending: false,
441 },
442 }
443 }
444
445 #[inline]
446 pub fn quickly(mut self) -> Self {
447 self.inner.quickly = true;
448 self
449 }
450
451 #[inline]
452 pub async fn result(mut self) -> Result<Item::Output, Error<Item>>
453 where
454 Item: Future + 'static,
455 Item::Output: 'static,
456 {
457 if self.inner.sink.is_full() {
458 return Err(Error::TrySendError(ErrorType::Full(self.inner.item.take())));
459 }
460 self.inner.result().await
461 }
462}
463
464impl<Item, Tx, G, D> Future for TryLocalSpawner<'_, Item, Tx, G, D>
465where
466 Item: Future + 'static,
467 Item::Output: 'static,
468 Tx: Clone + Unpin + Sink<(D, LocalTaskType)> + 'static,
469 G: Hash + Eq + Clone + Debug + 'static,
470{
471 type Output = Result<(), Error<Item>>;
472
473 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
474 let this = self.get_mut();
475 if this.inner.sink.is_full() {
476 return Poll::Ready(Err(Error::TrySendError(ErrorType::Full(
477 this.inner.item.take(),
478 ))));
479 }
480 this.inner.poll(cx)
481 }
482}