1use crate::*;
2use futures::{Future, FutureExt, Stream};
3use std::{fmt::Debug, mem::ManuallyDrop, sync::Arc, task::Poll, time::Duration};
4use tokio::task::JoinHandle;
5
6#[derive(Debug)]
16pub struct ChildPool<E, C = dyn AnyChannel>
17where
18 E: Send + 'static,
19 C: DynChannel + ?Sized,
20{
21 pub(super) channel: Arc<C>,
22 pub(super) handles: Option<Vec<JoinHandle<E>>>,
23 pub(super) link: Link,
24 pub(super) is_aborted: bool,
25}
26
27impl<E, C> ChildPool<E, C>
28where
29 E: Send + 'static,
30 C: DynChannel + ?Sized,
31{
32 pub(crate) fn new(channel: Arc<C>, handles: Vec<JoinHandle<E>>, link: Link) -> Self {
33 Self {
34 channel,
35 handles: Some(handles),
36 link,
37 is_aborted: false,
38 }
39 }
40
41 fn into_parts(self) -> (Arc<C>, Vec<JoinHandle<E>>, Link, bool) {
42 let no_drop = ManuallyDrop::new(self);
43 unsafe {
44 let mut handle = std::ptr::read(&no_drop.handles);
45 let channel = std::ptr::read(&no_drop.channel);
46 let link = std::ptr::read(&no_drop.link);
47 let is_aborted = std::ptr::read(&no_drop.is_aborted);
48 (channel, handle.take().unwrap(), link, is_aborted)
49 }
50 }
51
52 pub fn into_joinhandles(self) -> Vec<JoinHandle<E>> {
57 self.into_parts().1
58 }
59
60 pub fn abort(&mut self) -> bool {
64 self.channel.close();
65 let was_aborted = self.is_aborted;
66 self.is_aborted = true;
67 for handle in self.handles.as_ref().unwrap() {
68 handle.abort()
69 }
70 !was_aborted
71 }
72
73 pub fn is_finished(&self) -> bool {
75 self.handles
76 .as_ref()
77 .unwrap()
78 .iter()
79 .all(|handle| handle.is_finished())
80 }
81
82 pub fn task_count(&self) -> usize {
87 self.handles
88 .as_ref()
89 .unwrap()
90 .iter()
91 .filter(|handle| !handle.is_finished())
92 .collect::<Vec<_>>()
93 .len()
94 }
95
96 pub fn handle_count(&self) -> usize {
99 self.handles.as_ref().unwrap().len()
100 }
101
102 pub fn try_spawn<M, Fun, Fut>(&mut self, fun: Fun) -> Result<(), TrySpawnError<Fun>>
108 where
109 Fun: FnOnce(Inbox<M>) -> Fut + Send + 'static,
110 Fut: Future<Output = E> + Send + 'static,
111 M: Send + 'static,
112 E: Send + 'static,
113 C: AnyChannel,
114 {
115 let channel = match Arc::downcast::<Channel<M>>(self.channel.clone().into_any()) {
116 Ok(channel) => channel,
117 Err(_) => return Err(TrySpawnError::IncorrectType(fun)),
118 };
119
120 match channel.try_add_inbox() {
121 Ok(_) => {
122 let inbox = Inbox::from_channel(channel);
123 let handle = tokio::task::spawn(async move { fun(inbox).await });
124 self.handles.as_mut().unwrap().push(handle);
125 Ok(())
126 }
127 Err(_) => Err(TrySpawnError::Exited(fun)),
128 }
129 }
130
131 pub fn downcast<M>(self) -> Result<ChildPool<E, Channel<M>>, Self>
133 where
134 M: Send + 'static,
135 C: AnyChannel,
136 {
137 let (channel, handles, link, is_aborted) = self.into_parts();
138 match channel.clone().into_any().downcast::<Channel<M>>() {
139 Ok(channel) => Ok(ChildPool {
140 handles: Some(handles),
141 channel,
142 link,
143 is_aborted,
144 }),
145 Err(_) => Err(ChildPool {
146 handles: Some(handles),
147 channel,
148 link,
149 is_aborted,
150 }),
151 }
152 }
153
154 pub fn shutdown(&mut self, timeout: Duration) -> ShutdownStream<'_, E, C> {
171 ShutdownStream::new(self, timeout)
172 }
173
174 gen::dyn_channel_methods!();
175 gen::child_methods!();
176}
177
178impl<E, M> ChildPool<E, Channel<M>>
179where
180 E: Send + 'static,
181{
182 pub fn into_dyn(self) -> ChildPool<E>
184 where
185 M: Send + 'static,
186 {
187 let parts = self.into_parts();
188 ChildPool {
189 handles: Some(parts.1),
190 channel: parts.0,
191 link: parts.2,
192 is_aborted: parts.3,
193 }
194 }
195
196 pub fn spawn<Fun, Fut>(&mut self, fun: Fun) -> Result<(), SpawnError<Fun>>
200 where
201 Fun: FnOnce(Inbox<M>) -> Fut + Send + 'static,
202 Fut: Future<Output = E> + Send + 'static,
203 E: Send + 'static,
204 M: Send + 'static,
205 {
206 match self.channel.try_add_inbox() {
207 Ok(_) => {
208 let inbox = Inbox::from_channel(self.channel.clone());
209 let handle = tokio::task::spawn(async move { fun(inbox).await });
210 self.handles.as_mut().unwrap().push(handle);
211 Ok(())
212 }
213 Err(_) => Err(SpawnError(fun)),
214 }
215 }
216
217 gen::send_methods!();
218}
219
#[cfg(feature = "internals")]
impl<E, C> ChildPool<E, C>
where
    E: Send + 'static,
    C: DynChannel + ?Sized,
{
    /// Rebuilds the pool around the channel produced by `func`, keeping the handles, the
    /// link and the aborted flag as they were.
    pub fn transform_channel<C2: DynChannel + ?Sized>(
        self,
        func: fn(Arc<C>) -> Arc<C2>,
    ) -> ChildPool<E, C2> {
        let (old_channel, handles, link, is_aborted) = self.into_parts();
        let channel = func(old_channel);
        ChildPool {
            channel,
            handles: Some(handles),
            link,
            is_aborted,
        }
    }

    /// Borrows the channel this pool is attached to.
    pub fn channel_ref(&self) -> &C {
        &*self.channel
    }
}
243
244impl<E: Send + 'static, C: DynChannel + ?Sized> Stream for ChildPool<E, C> {
245 type Item = Result<E, ExitError>;
246
247 fn poll_next(
248 mut self: std::pin::Pin<&mut Self>,
249 cx: &mut std::task::Context<'_>,
250 ) -> std::task::Poll<Option<Self::Item>> {
251 if self.handles.as_ref().unwrap().len() == 0 {
252 return Poll::Ready(None);
253 }
254
255 for (i, handle) in self.handles.as_mut().unwrap().iter_mut().enumerate() {
256 if let Poll::Ready(res) = handle.poll_unpin(cx) {
257 self.handles.as_mut().unwrap().swap_remove(i);
258 return Poll::Ready(Some(res.map_err(Into::into)));
259 }
260 }
261
262 Poll::Pending
263 }
264}
265
266impl<E: Send + 'static, C: DynChannel + ?Sized> Drop for ChildPool<E, C> {
267 fn drop(&mut self) {
268 if let Link::Attached(abort_timer) = self.link {
269 if !self.is_aborted && !self.is_finished() {
270 if abort_timer.is_zero() {
271 self.abort();
272 } else {
273 self.halt();
274 let handles = self.handles.take().unwrap();
275 tokio::task::spawn(async move {
276 tokio::time::sleep(abort_timer).await;
277 for handle in handles {
278 handle.abort()
279 }
280 });
281 }
282 }
283 }
284 }
285}
286
#[cfg(test)]
mod test {
    use crate::*;
    use futures::future::pending;
    use std::sync::atomic::{AtomicU8, Ordering};
    use std::time::Duration;

    /// Dropping a pool spawned with the default config lets all three tasks observe a halt.
    #[tokio::test]
    async fn dropping() {
        static HALTED: AtomicU8 = AtomicU8::new(0);

        let (pool, address) = spawn_many(
            0..3,
            Config::default(),
            |_, mut inbox: Inbox<()>| async move {
                if let Err(RecvError::Halted) = inbox.recv().await {
                    HALTED.fetch_add(1, Ordering::AcqRel);
                };
            },
        );

        drop(pool);
        address.await;

        assert_eq!(HALTED.load(Ordering::Acquire), 3);
    }

    /// Tasks that never return after being halted must still let the address resolve,
    /// which relies on the abort that follows the 1 ms timer.
    #[tokio::test]
    async fn dropping_halts_then_aborts() {
        static HALTED: AtomicU8 = AtomicU8::new(0);

        let (pool, address) = spawn_many(
            0..3,
            Config::attached(Duration::from_millis(1)),
            |_, mut inbox: Inbox<()>| async move {
                if let Err(RecvError::Halted) = inbox.recv().await {
                    HALTED.fetch_add(1, Ordering::AcqRel);
                };
                // Never finish voluntarily.
                pending::<()>().await;
            },
        );

        drop(pool);
        address.await;

        assert_eq!(HALTED.load(Ordering::Acquire), 3);
    }

    /// Dropping a detached pool must not halt its tasks: they stay alive until each one
    /// has received a regular message.
    #[tokio::test]
    async fn dropping_detached() {
        static HALTED: AtomicU8 = AtomicU8::new(0);

        let (pool, address) = spawn_many(
            0..3,
            Config::detached(),
            |_, mut inbox: Inbox<()>| async move {
                if let Err(RecvError::Halted) = inbox.recv().await {
                    HALTED.fetch_add(1, Ordering::AcqRel);
                };
            },
        );

        drop(pool);
        tokio::time::sleep(Duration::from_millis(1)).await;

        // One message per task, so that every task returns from `recv` and exits.
        address.try_send(()).unwrap();
        address.try_send(()).unwrap();
        address.try_send(()).unwrap();
        address.await;

        assert_eq!(HALTED.load(Ordering::Acquire), 0);
    }

    /// A pool that was type-erased can be downcast back to its real message type.
    #[tokio::test]
    async fn downcast() {
        let (pool, _address) = spawn_many(0..5, Config::default(), pooled_basic_actor!());
        assert!(pool.into_dyn().downcast::<()>().is_ok());
    }

    /// Spawning onto a live pool works both through the typed and the dynamic API.
    #[tokio::test]
    async fn spawn_ok() {
        let (mut pool, _address) = spawn_one(Config::default(), basic_actor!());
        assert!(child_spawned(pool.spawn(basic_actor!()).is_ok()));
        assert!(child_spawned(pool.into_dyn().try_spawn(basic_actor!()).is_ok()));

        /// Identity helper that only exists to name what the boolean means.
        fn child_spawned(ok: bool) -> bool {
            ok
        }
    }

    /// Once the actor has exited, both spawn flavours report the exit.
    #[tokio::test]
    async fn spawn_err_exit() {
        let (mut pool, address) = spawn_one(Config::default(), basic_actor!());
        address.halt();
        address.await;

        assert!(matches!(pool.spawn(basic_actor!()), Err(SpawnError(_))));
        assert!(matches!(
            pool.into_dyn().try_spawn(basic_actor!()),
            Err(TrySpawnError::Exited(_))
        ));
    }

    /// Spawning a task with a different message type on a dynamic pool is rejected.
    #[tokio::test]
    async fn spawn_err_incorrect_type() {
        let (pool, _address) = spawn_one(Config::default(), basic_actor!(u32));
        assert!(matches!(
            pool.into_dyn().try_spawn(basic_actor!(u64)),
            Err(TrySpawnError::IncorrectType(_))
        ));
    }
}