tiny_actor/actor/child_pool.rs

1use crate::*;
2use futures::{Future, FutureExt, Stream};
3use std::{fmt::Debug, mem::ManuallyDrop, sync::Arc, task::Poll, time::Duration};
4use tokio::task::JoinHandle;
5
6/// A child-pool is the non clone-able reference to an actor with a multiple processes.
7///
8/// child-pools can be of two forms:
9/// * `ChildPool<E, Channel<M>>`: This is the default form, it can be transformed into a `ChildPool<E>` using
10/// [ChildPool::into_dyn]. Additional processes can be spawned using [ChildPool::spawn].
11/// * `ChildPool<E>`: This form is a dynamic child-pool, it can be transformed back into a `ChildPool<E, Channel<M>>`
12/// using [ChildPool::downcast::<M>]. Additional processes can be spawned using [ChildPool::try_spawn].
13///
14/// A child-pool can be streamed which returns values of `E` when the processes exit.
15#[derive(Debug)]
16pub struct ChildPool<E, C = dyn AnyChannel>
17where
18    E: Send + 'static,
19    C: DynChannel + ?Sized,
20{
21    pub(super) channel: Arc<C>,
22    pub(super) handles: Option<Vec<JoinHandle<E>>>,
23    pub(super) link: Link,
24    pub(super) is_aborted: bool,
25}
26
27impl<E, C> ChildPool<E, C>
28where
29    E: Send + 'static,
30    C: DynChannel + ?Sized,
31{
32    pub(crate) fn new(channel: Arc<C>, handles: Vec<JoinHandle<E>>, link: Link) -> Self {
33        Self {
34            channel,
35            handles: Some(handles),
36            link,
37            is_aborted: false,
38        }
39    }
40
41    fn into_parts(self) -> (Arc<C>, Vec<JoinHandle<E>>, Link, bool) {
42        let no_drop = ManuallyDrop::new(self);
43        unsafe {
44            let mut handle = std::ptr::read(&no_drop.handles);
45            let channel = std::ptr::read(&no_drop.channel);
46            let link = std::ptr::read(&no_drop.link);
47            let is_aborted = std::ptr::read(&no_drop.is_aborted);
48            (channel, handle.take().unwrap(), link, is_aborted)
49        }
50    }
51
52    /// Get the underlying [JoinHandles](JoinHandle). The order does not necessarily reflect
53    /// the order in which processes were spawned.
54    ///
55    /// This will not run drop, and therefore the `Actor` will not be halted/aborted.
56    pub fn into_joinhandles(self) -> Vec<JoinHandle<E>> {
57        self.into_parts().1
58    }
59
60    /// Abort the actor.
61    ///
62    /// Returns `true` if this is the first abort.
63    pub fn abort(&mut self) -> bool {
64        self.channel.close();
65        let was_aborted = self.is_aborted;
66        self.is_aborted = true;
67        for handle in self.handles.as_ref().unwrap() {
68            handle.abort()
69        }
70        !was_aborted
71    }
72
73    /// Whether all tasks have finished.
74    pub fn is_finished(&self) -> bool {
75        self.handles
76            .as_ref()
77            .unwrap()
78            .iter()
79            .all(|handle| handle.is_finished())
80    }
81
82    /// The amount of tasks that are alive.
83    ///
84    /// This should give the same result as [ChildPool::process_count], as long as
85    /// an inbox is only dropped whenever it's task finishes.
86    pub fn task_count(&self) -> usize {
87        self.handles
88            .as_ref()
89            .unwrap()
90            .iter()
91            .filter(|handle| !handle.is_finished())
92            .collect::<Vec<_>>()
93            .len()
94    }
95
96    /// The amount of handles to processes that this pool contains. This can be bigger
97    /// than the `process_count` or `task_count` if processes have exited.
98    pub fn handle_count(&self) -> usize {
99        self.handles.as_ref().unwrap().len()
100    }
101
102    /// Attempt to spawn an additional process on the channel.
103    ///
104    /// This method can fail if
105    /// * the message-type does not match that of the channel.
106    /// * the channel has already exited.
107    pub fn try_spawn<M, Fun, Fut>(&mut self, fun: Fun) -> Result<(), TrySpawnError<Fun>>
108    where
109        Fun: FnOnce(Inbox<M>) -> Fut + Send + 'static,
110        Fut: Future<Output = E> + Send + 'static,
111        M: Send + 'static,
112        E: Send + 'static,
113        C: AnyChannel,
114    {
115        let channel = match Arc::downcast::<Channel<M>>(self.channel.clone().into_any()) {
116            Ok(channel) => channel,
117            Err(_) => return Err(TrySpawnError::IncorrectType(fun)),
118        };
119
120        match channel.try_add_inbox() {
121            Ok(_) => {
122                let inbox = Inbox::from_channel(channel);
123                let handle = tokio::task::spawn(async move { fun(inbox).await });
124                self.handles.as_mut().unwrap().push(handle);
125                Ok(())
126            }
127            Err(_) => Err(TrySpawnError::Exited(fun)),
128        }
129    }
130
131    /// Downcast the `ChildPool<E>` to a `ChildPool<E, Channel<M>>`
132    pub fn downcast<M>(self) -> Result<ChildPool<E, Channel<M>>, Self>
133    where
134        M: Send + 'static,
135        C: AnyChannel,
136    {
137        let (channel, handles, link, is_aborted) = self.into_parts();
138        match channel.clone().into_any().downcast::<Channel<M>>() {
139            Ok(channel) => Ok(ChildPool {
140                handles: Some(handles),
141                channel,
142                link,
143                is_aborted,
144            }),
145            Err(_) => Err(ChildPool {
146                handles: Some(handles),
147                channel,
148                link,
149                is_aborted,
150            }),
151        }
152    }
153
154    /// Halts the actor, and then returns a stream that waits for exits.
155    ///
156    /// If the timeout expires before all processes have exited, the actor will be aborted.
157    ///
158    /// # Examples
159    /// ```no_run
160    /// # use tiny_actor::*;
161    /// # use std::time::Duration;
162    /// # #[tokio::main]
163    /// # async fn main() {
164    /// use futures::StreamExt;
165    ///
166    /// let mut pool: ChildPool<()> = todo!();
167    /// let exits: Vec<_> = pool.shutdown(Duration::from_secs(1)).collect().await;
168    /// # }
169    /// ```
170    pub fn shutdown(&mut self, timeout: Duration) -> ShutdownStream<'_, E, C> {
171        ShutdownStream::new(self, timeout)
172    }
173
174    gen::dyn_channel_methods!();
175    gen::child_methods!();
176}
177
178impl<E, M> ChildPool<E, Channel<M>>
179where
180    E: Send + 'static,
181{
182    /// Convert the `ChildPool<E, Channel<M>` into a `ChildPool<E>`.
183    pub fn into_dyn(self) -> ChildPool<E>
184    where
185        M: Send + 'static,
186    {
187        let parts = self.into_parts();
188        ChildPool {
189            handles: Some(parts.1),
190            channel: parts.0,
191            link: parts.2,
192            is_aborted: parts.3,
193        }
194    }
195
196    /// Attempt to spawn an additional process on the channel.
197    ///
198    /// This method fails if the channel has already exited.
199    pub fn spawn<Fun, Fut>(&mut self, fun: Fun) -> Result<(), SpawnError<Fun>>
200    where
201        Fun: FnOnce(Inbox<M>) -> Fut + Send + 'static,
202        Fut: Future<Output = E> + Send + 'static,
203        E: Send + 'static,
204        M: Send + 'static,
205    {
206        match self.channel.try_add_inbox() {
207            Ok(_) => {
208                let inbox = Inbox::from_channel(self.channel.clone());
209                let handle = tokio::task::spawn(async move { fun(inbox).await });
210                self.handles.as_mut().unwrap().push(handle);
211                Ok(())
212            }
213            Err(_) => Err(SpawnError(fun)),
214        }
215    }
216
217    gen::send_methods!();
218}
219
#[cfg(feature = "internals")]
impl<E, C> ChildPool<E, C>
where
    E: Send + 'static,
    C: DynChannel + ?Sized,
{
    /// Consume the pool and map its channel through `func`.
    ///
    /// The handles, link and abort-state are moved over untouched, and the old pool's `Drop`
    /// does not run. Only available with the `internals` feature.
    pub fn transform_channel<C2: DynChannel + ?Sized>(
        self,
        func: fn(Arc<C>) -> Arc<C2>,
    ) -> ChildPool<E, C2> {
        let parts = self.into_parts();
        let mapped = func(parts.0);
        ChildPool {
            handles: Some(parts.1),
            link: parts.2,
            is_aborted: parts.3,
            channel: mapped,
        }
    }

    /// Borrow the underlying channel. Only available with the `internals` feature.
    pub fn channel_ref(&self) -> &C {
        &*self.channel
    }
}
243
244impl<E: Send + 'static, C: DynChannel + ?Sized> Stream for ChildPool<E, C> {
245    type Item = Result<E, ExitError>;
246
247    fn poll_next(
248        mut self: std::pin::Pin<&mut Self>,
249        cx: &mut std::task::Context<'_>,
250    ) -> std::task::Poll<Option<Self::Item>> {
251        if self.handles.as_ref().unwrap().len() == 0 {
252            return Poll::Ready(None);
253        }
254
255        for (i, handle) in self.handles.as_mut().unwrap().iter_mut().enumerate() {
256            if let Poll::Ready(res) = handle.poll_unpin(cx) {
257                self.handles.as_mut().unwrap().swap_remove(i);
258                return Poll::Ready(Some(res.map_err(Into::into)));
259            }
260        }
261
262        Poll::Pending
263    }
264}
265
266impl<E: Send + 'static, C: DynChannel + ?Sized> Drop for ChildPool<E, C> {
267    fn drop(&mut self) {
268        if let Link::Attached(abort_timer) = self.link {
269            if !self.is_aborted && !self.is_finished() {
270                if abort_timer.is_zero() {
271                    self.abort();
272                } else {
273                    self.halt();
274                    let handles = self.handles.take().unwrap();
275                    tokio::task::spawn(async move {
276                        tokio::time::sleep(abort_timer).await;
277                        for handle in handles {
278                            handle.abort()
279                        }
280                    });
281                }
282            }
283        }
284    }
285}
286
#[cfg(test)]
mod test {
    use crate::*;
    use futures::future::pending;
    use std::sync::atomic::{AtomicU8, Ordering};
    use std::time::Duration;

    /// Dropping a pool spawned with the default config halts all of its processes.
    #[tokio::test]
    async fn dropping() {
        static HALT_COUNT: AtomicU8 = AtomicU8::new(0);
        let (pool, address) = spawn_many(0..3, Config::default(), |_, mut inbox: Inbox<()>| {
            async move {
                if matches!(inbox.recv().await, Err(RecvError::Halted)) {
                    HALT_COUNT.fetch_add(1, Ordering::AcqRel);
                }
            }
        });
        drop(pool);
        address.await;

        assert_eq!(HALT_COUNT.load(Ordering::Acquire), 3);
    }

    /// Dropping an attached pool halts every process, then aborts the ones that keep running.
    #[tokio::test]
    async fn dropping_halts_then_aborts() {
        static HALT_COUNT: AtomicU8 = AtomicU8::new(0);
        let config = Config::attached(Duration::from_millis(1));
        let (pool, address) = spawn_many(0..3, config, |_, mut inbox: Inbox<()>| async move {
            if matches!(inbox.recv().await, Err(RecvError::Halted)) {
                HALT_COUNT.fetch_add(1, Ordering::AcqRel);
            }
            // Never exit on our own: only the delayed abort can stop this process.
            pending::<()>().await;
        });
        drop(pool);
        address.await;

        assert_eq!(HALT_COUNT.load(Ordering::Acquire), 3);
    }

    /// Dropping a detached pool leaves its processes running and un-halted.
    #[tokio::test]
    async fn dropping_detached() {
        static HALT_COUNT: AtomicU8 = AtomicU8::new(0);

        let (pool, address) = spawn_many(0..3, Config::detached(), |_, mut inbox: Inbox<()>| {
            async move {
                if matches!(inbox.recv().await, Err(RecvError::Halted)) {
                    HALT_COUNT.fetch_add(1, Ordering::AcqRel);
                }
            }
        });
        drop(pool);
        tokio::time::sleep(Duration::from_millis(1)).await;
        // One message per process so that each of them exits normally.
        for _ in 0..3 {
            address.try_send(()).unwrap();
        }
        address.await;

        assert_eq!(HALT_COUNT.load(Ordering::Acquire), 0);
    }

    /// A pool converted to its dynamic form can be downcast back to its message type.
    #[tokio::test]
    async fn downcast() {
        let (pool, _address) = spawn_many(0..5, Config::default(), pooled_basic_actor!());
        assert!(pool.into_dyn().downcast::<()>().is_ok());
    }

    /// Spawning extra processes works on both the typed and the dynamic form.
    #[tokio::test]
    async fn spawn_ok() {
        let (mut pool, _address) = spawn_one(Config::default(), basic_actor!());
        assert!(pool.spawn(basic_actor!()).is_ok());
        assert!(pool.into_dyn().try_spawn(basic_actor!()).is_ok());
    }

    /// Spawning onto an exited actor fails on both forms.
    #[tokio::test]
    async fn spawn_err_exit() {
        let (mut pool, address) = spawn_one(Config::default(), basic_actor!());
        address.halt();
        address.await;
        assert!(matches!(pool.spawn(basic_actor!()), Err(SpawnError(_))));
        let dyn_result = pool.into_dyn().try_spawn(basic_actor!());
        assert!(matches!(dyn_result, Err(TrySpawnError::Exited(_))));
    }

    /// `try_spawn` rejects a process whose message-type differs from the channel's.
    #[tokio::test]
    async fn spawn_err_incorrect_type() {
        let (pool, _address) = spawn_one(Config::default(), basic_actor!(u32));
        let result = pool.into_dyn().try_spawn(basic_actor!(u64));
        assert!(matches!(result, Err(TrySpawnError::IncorrectType(_))));
    }
}