1use std::fmt::Debug;
2use std::hash::Hash;
3use std::pin::Pin;
4use std::sync::Arc;
5use std::task::{Context, Poll};
6
7use crate::exec::PendingOnce;
8use futures::channel::oneshot;
9use futures::task::AtomicWaker;
10use futures::{Future, Sink, SinkExt};
11use futures_lite::FutureExt;
12
13use crate::TaskType;
14
15use super::{assert_future, Error, ErrorType, TaskExecQueue};
16
17pub struct GroupSpawner<'a, Item, Tx, G> {
18 inner: Spawner<'a, Item, Tx, G, ()>,
19 name: Option<G>,
20}
21
22impl<Item, Tx, G> Unpin for GroupSpawner<'_, Item, Tx, G> {}
23
24impl<'a, Item, Tx, G> GroupSpawner<'a, Item, Tx, G>
25where
26 Tx: Clone + Unpin + Sink<((), TaskType)> + Send + Sync + 'static,
27 G: Hash + Eq + Clone + Debug + Send + Sync + 'static,
28{
29 #[inline]
30 pub(crate) fn new(inner: Spawner<'a, Item, Tx, G, ()>, name: G) -> Self {
31 Self {
32 inner,
33 name: Some(name),
34 }
35 }
36
37 #[inline]
38 pub fn quickly(mut self) -> Self {
39 self.inner.quickly = true;
40 self
41 }
42
43 #[inline]
44 pub async fn result(mut self) -> Result<Item::Output, Error<Item>>
45 where
46 Item: Future + Send + 'static,
47 Item::Output: Send + 'static,
48 {
49 if self.inner.sink.is_closed() {
50 return Err(Error::SendError(ErrorType::Closed(self.inner.item.take())));
51 }
52
53 if !self.inner.quickly && self.inner.sink.is_full() {
54 let w = Arc::new(AtomicWaker::new());
55 self.inner.sink.waiting_wakers.push(w.clone());
56 PendingOnce::new(w).await;
57 }
58
59 let task = match self.inner.item.take() {
60 Some(task) => task,
61 None => {
62 log::error!("polled Feed after completion, task is None!");
63 return Err(Error::SendError(ErrorType::Closed(None)));
64 }
65 };
66
67 let name = match self.name.take() {
68 Some(name) => name,
69 None => {
70 log::error!("polled Feed after completion, name is None!");
71 return Err(Error::SendError(ErrorType::Closed(None)));
72 }
73 };
74
75 let (res_tx, res_rx) = oneshot::channel();
76 let waiting_count = self.inner.sink.waiting_count.clone();
77 let waiting_wakers = self.inner.sink.waiting_wakers.clone();
78 let task = async move {
79 waiting_count.dec();
80 if let Some(w) = waiting_wakers.pop() {
81 w.wake();
82 }
83 let output = task.await;
84 if let Err(_e) = res_tx.send(output) {
85 log::warn!("send result failed");
86 }
87 };
88 self.inner.sink.waiting_count.inc();
89
90 if let Err(_e) = self
91 .inner
92 .sink
93 .group_send(name, Box::new(Box::pin(task)))
94 .await
95 {
96 self.inner.sink.waiting_count.dec();
97 Err(Error::SendError(ErrorType::Closed(None)))
98 } else {
99 res_rx.await.map_err(|_| {
100 self.inner.sink.waiting_count.dec();
101 Error::RecvResultError
102 })
103 }
104 }
105}
106
107impl<Item, Tx, G> Future for GroupSpawner<'_, Item, Tx, G>
108where
109 Item: Future + Send + 'static,
110 Item::Output: Send + 'static,
111 Tx: Clone + Unpin + Sink<((), TaskType)> + Send + Sync + 'static,
112 G: Hash + Eq + Clone + Debug + Send + Sync + 'static,
113{
114 type Output = Result<(), Error<Item>>;
115
116 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
117 let this = self.get_mut();
118 if this.inner.sink.is_closed() && !this.inner.is_pending {
119 return Poll::Ready(Err(Error::SendError(ErrorType::Closed(
120 this.inner.item.take(),
121 ))));
122 }
123
124 if !this.inner.quickly && this.inner.sink.is_full() {
125 let w = Arc::new(AtomicWaker::new());
126 w.register(cx.waker());
127 this.inner.sink.waiting_wakers.push(w);
128 this.inner.is_pending = true;
129 return Poll::Pending;
130 }
131
132 let task = match this.inner.item.take() {
133 Some(task) => task,
134 None => {
135 log::error!("polled Feed after completion, task is None!");
136 return Poll::Ready(Ok(()));
137 }
138 };
139
140 let name = match this.name.take() {
141 Some(name) => name,
142 None => {
143 log::error!("polled Feed after completion, name is None!");
144 return Poll::Ready(Ok(()));
145 }
146 };
147
148 let waiting_count = this.inner.sink.waiting_count.clone();
149 let waiting_wakers = this.inner.sink.waiting_wakers.clone();
150 let task = async move {
151 waiting_count.dec();
152 if let Some(w) = waiting_wakers.pop() {
153 w.wake();
154 }
155 let _ = task.await;
156 };
157 this.inner.sink.waiting_count.inc();
158
159 let mut group_send = this
160 .inner
161 .sink
162 .group_send(name, Box::new(Box::pin(task)))
163 .boxed();
164
165 if (futures::ready!(group_send.poll(cx))).is_err() {
166 this.inner.sink.waiting_count.dec();
167 Poll::Ready(Err(Error::SendError(ErrorType::Closed(None))))
168 } else {
169 Poll::Ready(Ok(()))
170 }
171 }
172}
173
174pub struct TryGroupSpawner<'a, Item, Tx, G> {
175 inner: GroupSpawner<'a, Item, Tx, G>,
176}
177
178impl<Item, Tx, G> Unpin for TryGroupSpawner<'_, Item, Tx, G> {}
179
180impl<'a, Item, Tx, G> TryGroupSpawner<'a, Item, Tx, G>
181where
182 Tx: Clone + Unpin + Sink<((), TaskType)> + Send + Sync + 'static,
183 G: Hash + Eq + Clone + Debug + Send + Sync + 'static,
184{
185 #[inline]
186 pub(crate) fn new(inner: Spawner<'a, Item, Tx, G, ()>, name: G) -> Self {
187 Self {
188 inner: GroupSpawner {
189 inner,
190 name: Some(name),
191 },
192 }
193 }
194
195 #[inline]
196 pub async fn result(mut self) -> Result<Item::Output, Error<Item>>
197 where
198 Item: Future + Send + 'static,
199 Item::Output: Send + 'static,
200 {
201 if self.inner.inner.sink.is_full() {
202 return Err(Error::TrySendError(ErrorType::Full(
203 self.inner.inner.item.take(),
204 )));
205 }
206 self.inner.result().await
207 }
208}
209
210impl<Item, Tx, G> Future for TryGroupSpawner<'_, Item, Tx, G>
211where
212 Item: Future + Send + 'static,
213 Item::Output: Send + 'static,
214 Tx: Clone + Unpin + Sink<((), TaskType)> + Send + Sync + 'static,
215 G: Hash + Eq + Clone + Debug + Send + Sync + 'static,
216{
217 type Output = Result<(), Error<Item>>;
218
219 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
220 let this = self.get_mut();
221
222 if this.inner.inner.sink.is_full() {
223 return Poll::Ready(Err(Error::TrySendError(ErrorType::Full(
224 this.inner.inner.item.take(),
225 ))));
226 }
227
228 this.inner.poll(cx)
229 }
230}
231
232pub struct Spawner<'a, Item, Tx, G, D> {
233 sink: &'a TaskExecQueue<Tx, G, D>,
234 item: Option<Item>,
235 d: Option<D>,
236 quickly: bool,
237 is_pending: bool,
238}
239
240impl<'a, Item, Tx, G, D> Unpin for Spawner<'a, Item, Tx, G, D> {}
241
242impl<'a, Item, Tx, G> Spawner<'a, Item, Tx, G, ()>
243where
244 Tx: Clone + Unpin + Sink<((), TaskType)> + Send + Sync + 'static,
245 G: Hash + Eq + Clone + Debug + Send + Sync + 'static,
246{
247 #[inline]
248 pub fn group(self, name: G) -> GroupSpawner<'a, Item, Tx, G>
249 where
250 Item: Future + Send + 'static,
251 Item::Output: Send + 'static,
252 {
253 let fut = GroupSpawner::new(self, name);
254 assert_future::<Result<(), _>, _>(fut)
255 }
256}
257
258impl<'a, Item, Tx, G, D> Spawner<'a, Item, Tx, G, D>
259where
260 Tx: Clone + Unpin + Sink<(D, TaskType)> + Send + Sync + 'static,
261 G: Hash + Eq + Clone + Debug + Send + Sync + 'static,
262{
263 #[inline]
264 pub(crate) fn new(sink: &'a TaskExecQueue<Tx, G, D>, item: Item, d: D) -> Self {
265 Self {
266 sink,
267 item: Some(item),
268 d: Some(d),
269 quickly: false,
270 is_pending: false,
271 }
272 }
273
274 #[inline]
275 pub fn quickly(mut self) -> Self {
276 self.quickly = true;
277 self
278 }
279
280 #[inline]
281 pub async fn result(mut self) -> Result<Item::Output, Error<Item>>
282 where
283 Item: Future + Send + 'static,
284 Item::Output: Send + 'static,
285 {
286 if self.sink.is_closed() {
287 return Err(Error::SendError(ErrorType::Closed(self.item.take())));
288 }
289
290 if !self.quickly && self.sink.is_full() {
291 let w = Arc::new(AtomicWaker::new());
292 self.sink.waiting_wakers.push(w.clone());
293 PendingOnce::new(w).await;
294 }
295
296 let task = self
297 .item
298 .take()
299 .expect("polled Feed after completion, task is None!");
300 let d = self
301 .d
302 .take()
303 .expect("polled Feed after completion, d is None!");
304
305 let (res_tx, res_rx) = oneshot::channel();
306 let waiting_count = self.sink.waiting_count.clone();
307 let waiting_wakers = self.sink.waiting_wakers.clone();
308 let task = async move {
309 waiting_count.dec();
310 if let Some(w) = waiting_wakers.pop() {
311 w.wake();
312 }
313 let output = task.await;
314 if let Err(_e) = res_tx.send(output) {
315 log::warn!("send result failed");
316 }
317 };
318
319 self.sink.waiting_count.inc();
320 if self
321 .sink
322 .tx
323 .clone()
324 .send((d, Box::new(Box::pin(task))))
325 .await
326 .is_err()
327 {
328 self.sink.waiting_count.dec();
329 return Err(Error::SendError(ErrorType::Closed(None)));
330 }
331 res_rx.await.map_err(|_| {
332 self.sink.waiting_count.dec();
333 Error::RecvResultError
334 })
335 }
336}
337
338impl<Item, Tx, G, D> Future for Spawner<'_, Item, Tx, G, D>
339where
340 Item: Future + Send + 'static,
341 Item::Output: Send + 'static,
342 Tx: Clone + Unpin + Sink<(D, TaskType)> + Send + Sync + 'static,
343 G: Hash + Eq + Clone + Debug + Send + Sync + 'static,
344{
345 type Output = Result<(), Error<Item>>;
346
347 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
348 let this = self.get_mut();
349
350 if this.sink.is_closed() && !this.is_pending {
351 return Poll::Ready(Err(Error::SendError(ErrorType::Closed(this.item.take()))));
352 }
353
354 if !this.quickly && this.sink.is_full() {
355 let w = Arc::new(AtomicWaker::new());
356 w.register(cx.waker());
357 this.sink.waiting_wakers.push(w);
358 this.is_pending = true;
359 return Poll::Pending;
360 }
361
362 let task = match this.item.take() {
363 Some(task) => task,
364 None => {
365 log::error!("polled Feed after completion, task is None!");
366 return Poll::Ready(Ok(()));
367 }
368 };
369
370 let d = match this.d.take() {
371 Some(d) => d,
372 None => {
373 log::error!("polled Feed after completion, d is None!");
374 return Poll::Ready(Ok(()));
375 }
376 };
377
378 let mut tx = this.sink.tx.clone();
379 let mut sink = Pin::new(&mut tx);
380
381 let waiting_count = this.sink.waiting_count.clone();
382 let waiting_wakers = this.sink.waiting_wakers.clone();
383 let task = async move {
384 waiting_count.dec();
385 if let Some(w) = waiting_wakers.pop() {
386 w.wake();
387 }
388 let _ = task.await;
389 };
390 this.sink.waiting_count.inc();
391 sink.as_mut()
392 .start_send((d, Box::new(Box::pin(task))))
393 .map_err(|_e| {
394 this.sink.waiting_count.dec();
395 Error::SendError(ErrorType::Closed(None))
396 })?;
397 Poll::Ready(Ok(()))
398 }
399}
400
401pub struct TrySpawner<'a, Item, Tx, G, D> {
402 inner: Spawner<'a, Item, Tx, G, D>,
403}
404
405impl<'a, Item, Tx, G, D> Unpin for TrySpawner<'a, Item, Tx, G, D> {}
406
407impl<'a, Item, Tx, G> TrySpawner<'a, Item, Tx, G, ()>
408where
409 Tx: Clone + Unpin + Sink<((), TaskType)> + Send + Sync + 'static,
410 G: Hash + Eq + Clone + Debug + Send + Sync + 'static,
411{
412 #[inline]
413 pub fn group(self, name: G) -> TryGroupSpawner<'a, Item, Tx, G>
414 where
415 Item: Future + Send + 'static,
416 Item::Output: Send + 'static,
417 {
418 let fut = TryGroupSpawner::new(self.inner, name);
419 assert_future::<Result<(), _>, _>(fut)
420 }
421}
422
423impl<'a, Item, Tx, G, D> TrySpawner<'a, Item, Tx, G, D>
424where
425 Tx: Clone + Unpin + Sink<(D, TaskType)> + Send + Sync + 'static,
426 G: Hash + Eq + Clone + Debug + Send + Sync + 'static,
427{
428 #[inline]
429 pub(crate) fn new(sink: &'a TaskExecQueue<Tx, G, D>, item: Item, d: D) -> Self {
430 Self {
431 inner: Spawner {
432 sink,
433 item: Some(item),
434 d: Some(d),
435 quickly: false,
436 is_pending: false,
437 },
438 }
439 }
440
441 #[inline]
442 pub fn quickly(mut self) -> Self {
443 self.inner.quickly = true;
444 self
445 }
446
447 #[inline]
448 pub async fn result(mut self) -> Result<Item::Output, Error<Item>>
449 where
450 Item: Future + Send + 'static,
451 Item::Output: Send + 'static,
452 {
453 if self.inner.sink.is_full() {
454 return Err(Error::TrySendError(ErrorType::Full(self.inner.item.take())));
455 }
456 self.inner.result().await
457 }
458}
459
460impl<Item, Tx, G, D> Future for TrySpawner<'_, Item, Tx, G, D>
461where
462 Item: Future + Send + 'static,
463 Item::Output: Send + 'static,
464 Tx: Clone + Unpin + Sink<(D, TaskType)> + Send + Sync + 'static,
465 G: Hash + Eq + Clone + Debug + Send + Sync + 'static,
466{
467 type Output = Result<(), Error<Item>>;
468
469 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
470 let this = self.get_mut();
471 if this.inner.sink.is_full() {
472 return Poll::Ready(Err(Error::TrySendError(ErrorType::Full(
473 this.inner.item.take(),
474 ))));
475 }
476 this.inner.poll(cx)
477 }
478}