1use std::fmt::Debug;
2use std::hash::Hash;
3use std::pin::Pin;
4use std::task::{Context, Poll};
5
6use futures::{Future, Sink, SinkExt};
7use futures::channel::oneshot;
8
9use crate::TaskType;
10
11use super::{assert_future, Error, ErrorType, Executor};
12
13pub struct GroupSpawner<'a, Item, Tx, G> {
14 inner: Spawner<'a, Item, Tx, G, ()>,
15 name: Option<G>,
16}
17
18impl<Item, Tx, G> Unpin for GroupSpawner<'_, Item, Tx, G> {}
19
20impl<'a, Item, Tx, G> GroupSpawner<'a, Item, Tx, G>
21 where
22 Tx: Clone + Unpin + Sink<((), TaskType)> + Send + Sync + 'static,
23 G: Hash + Eq + Clone + Debug + Send + Sync + 'static,
24{
25 #[inline]
26 pub(crate) fn new(inner: Spawner<'a, Item, Tx, G, ()>, name: G) -> Self {
27 Self {
28 inner,
29 name: Some(name),
30 }
31 }
32
33 #[inline]
34 pub async fn result(mut self) -> Result<Item::Output, Error<Item>>
35 where
36 Item: Future + Send + 'static,
37 Item::Output: Send + 'static,
38 {
39 let task = match self.inner.item.take() {
40 Some(task) => task,
41 None => {
42 log::error!("polled Feed after completion, task is None!");
43 return Err(Error::SendError(ErrorType::Closed(None)));
44 }
45 };
46
47 let name = match self.name.take() {
48 Some(name) => name,
49 None => {
50 log::error!("polled Feed after completion, name is None!");
51 return Err(Error::SendError(ErrorType::Closed(None)));
52 }
53 };
54
55 if self.inner.sink.is_closed() {
56 return Err(Error::SendError(ErrorType::Closed(Some(task))));
57 }
58
59 let (res_tx, res_rx) = oneshot::channel();
60 let waiting_count = self.inner.sink.waiting_count.clone();
61 let task = async move {
62 waiting_count.dec();
63 let output = task.await;
64 if let Err(_e) = res_tx.send(output) {
65 log::warn!("send result failed");
66 }
67 };
68 self.inner.sink.waiting_count.inc();
69
70 if let Err(_e) = self
71 .inner
72 .sink
73 .group_send(name, Box::new(Box::pin(task)))
74 .await
75 {
76 self.inner.sink.waiting_count.dec();
77 Err(Error::SendError(ErrorType::Closed(None)))
78 } else {
79 res_rx.await.map_err(|_| {
80 self.inner.sink.waiting_count.dec();
81 Error::RecvResultError
82 })
83 }
84 }
85}
86
87impl<Item, Tx, G> Future for GroupSpawner<'_, Item, Tx, G>
88 where
89 Item: Future + Send + 'static,
90 Item::Output: Send + 'static,
91 Tx: Clone + Unpin + Sink<((), TaskType)> + Send + Sync + 'static,
92 G: Hash + Eq + Clone + Debug + Send + Sync + 'static,
93{
94 type Output = Result<(), Error<Item>>;
95
96 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
97 let this = self.get_mut();
98 let task = match this.inner.item.take() {
99 Some(task) => task,
100 None => {
101 log::error!("polled Feed after completion, task is None!");
102 return Poll::Ready(Ok(()));
103 }
104 };
105
106 let name = match this.name.take() {
107 Some(name) => name,
108 None => {
109 log::error!("polled Feed after completion, name is None!");
110 return Poll::Ready(Ok(()));
111 }
112 };
113
114 if this.inner.sink.is_closed() {
115 return Poll::Ready(Err(Error::SendError(ErrorType::Closed(Some(task)))));
116 }
117 let waiting_count = this.inner.sink.waiting_count.clone();
118 let task = async move {
119 waiting_count.dec();
120 let _ = task.await;
121 };
122 this.inner.sink.waiting_count.inc();
123 let mut group_send = this
124 .inner
125 .sink
126 .group_send(name, Box::new(Box::pin(task)))
127 .boxed();
128 use futures_lite::FutureExt;
129 if (futures::ready!(group_send.poll(cx))).is_err() {
130 this.inner.sink.waiting_count.dec();
131 Poll::Ready(Err(Error::SendError(ErrorType::Closed(None))))
132 } else {
133 Poll::Ready(Ok(()))
134 }
135 }
136}
137
/// A pending task submission bound to an executor.
///
/// Holds the task (`item`) and an extra value (`d`) that is sent together
/// with the boxed task through the executor's `tx` sink.
pub struct Spawner<'a, Item, Tx, G, D> {
    /// Executor the task will be submitted to.
    sink: &'a Executor<Tx, G, D>,
    /// The task to submit; taken (left as `None`) when submission starts.
    item: Option<Item>,
    /// Value sent alongside the task; taken together with `item`.
    d: Option<D>,
}
143
// `Future::poll` for `Spawner` calls `Pin::get_mut`, which requires `Self: Unpin`.
impl<'a, Item, Tx, G, D> Unpin for Spawner<'a, Item, Tx, G, D> {}
145
146impl<'a, Item, Tx, G> Spawner<'a, Item, Tx, G, ()>
147 where
148 Tx: Clone + Unpin + Sink<((), TaskType)> + Send + Sync + 'static,
149 G: Hash + Eq + Clone + Debug + Send + Sync + 'static,
150{
151 #[inline]
152 pub fn group(self, name: G) -> GroupSpawner<'a, Item, Tx, G>
153 where
154 Item: Future + Send + 'static,
155 Item::Output: Send + 'static,
156 {
157 let fut = GroupSpawner::new(self, name);
158 assert_future::<Result<(), _>, _>(fut)
159 }
160}
161
162impl<'a, Item, Tx, G, D> Spawner<'a, Item, Tx, G, D>
163 where
164 Tx: Clone + Unpin + Sink<(D, TaskType)> + Send + Sync + 'static,
165 G: Hash + Eq + Clone + Debug + Send + Sync + 'static,
166{
167 #[inline]
168 pub(crate) fn new(sink: &'a Executor<Tx, G, D>, item: Item, d: D) -> Self {
169 Self {
170 sink,
171 item: Some(item),
172 d: Some(d),
173 }
174 }
175
176 #[inline]
177 pub async fn result(mut self) -> Result<Item::Output, Error<Item>>
178 where
179 Item: Future + Send + 'static,
180 Item::Output: Send + 'static,
181 {
182 let task = self
183 .item
184 .take()
185 .expect("polled Feed after completion, task is None!");
186 let d = self
187 .d
188 .take()
189 .expect("polled Feed after completion, d is None!");
190
191 if self.sink.is_closed() {
192 return Err(Error::SendError(ErrorType::Closed(Some(task))));
193 }
194
195 let (res_tx, res_rx) = oneshot::channel();
196 let waiting_count = self.sink.waiting_count.clone();
197 let task = async move {
198 waiting_count.dec();
199 let output = task.await;
200 if let Err(_e) = res_tx.send(output) {
201 log::warn!("send result failed");
202 }
203 };
204 self.sink.waiting_count.inc();
205
206 if self
207 .sink
208 .tx
209 .clone()
210 .send((d, Box::new(Box::pin(task))))
211 .await
212 .is_err()
213 {
214 self.sink.waiting_count.dec();
215 return Err(Error::SendError(ErrorType::Closed(None)));
216 }
217 res_rx.await.map_err(|_| {
218 self.sink.waiting_count.dec();
219 Error::RecvResultError
220 })
221 }
222}
223
224impl<Item, Tx, G, D> Future for Spawner<'_, Item, Tx, G, D>
225 where
226 Item: Future + Send + 'static,
227 Item::Output: Send + 'static,
228 Tx: Clone + Unpin + Sink<(D, TaskType)> + Send + Sync + 'static,
229 G: Hash + Eq + Clone + Debug + Send + Sync + 'static,
230{
231 type Output = Result<(), Error<Item>>;
232
233 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
234 let this = self.get_mut();
235 let task = match this.item.take() {
236 Some(task) => task,
237 None => {
238 log::error!("polled Feed after completion, task is None!");
239 return Poll::Ready(Ok(()));
240 }
241 };
242
243 let d = match this.d.take() {
244 Some(d) => d,
245 None => {
246 log::error!("polled Feed after completion, d is None!");
247 return Poll::Ready(Ok(()));
248 }
249 };
250
251 if this.sink.is_closed() {
252 return Poll::Ready(Err(Error::SendError(ErrorType::Closed(Some(task)))));
253 }
254
255 let mut tx = this.sink.tx.clone();
256 let mut sink = Pin::new(&mut tx);
257 futures::ready!(sink.as_mut().poll_ready(cx))
258 .map_err(|_| Error::SendError(ErrorType::Closed(None)))?;
259 let waiting_count = this.sink.waiting_count.clone();
260 let task = async move {
261 waiting_count.dec();
262 let _ = task.await;
263 };
264 this.sink.waiting_count.inc();
265 sink.as_mut()
266 .start_send((d, Box::new(Box::pin(task))))
267 .map_err(|_e| {
268 this.sink.waiting_count.dec();
269 Error::SendError(ErrorType::Closed(None))
270 })?;
271 Poll::Ready(Ok(()))
272 }
273}