task_motel/
manager.rs

//! Manage tasks arranged in nested groups.
//!
//! Groups can be added and removed dynamically. When a group is removed,
//! all of its tasks are signalled to stop, and all of its descendant groups are also
//! removed and their contained tasks signalled to stop as well. Removing a group yields
//! a future that resolves only once all of its descendant tasks have actually stopped.
7
8use std::{
9    collections::{HashMap, HashSet},
10    hash::Hash,
11    sync::{atomic::AtomicU32, Arc},
12};
13
14use futures::{
15    channel::mpsc, future::BoxFuture, stream::FuturesUnordered, Future, FutureExt, StreamExt,
16};
17use tracing::Instrument;
18
19use crate::{signal::StopListener, StopBroadcaster};
20
21/// Tracks tasks at the global conductor level, as well as each individual cell level.
22pub struct TaskManager<GroupKey, Outcome> {
23    span: Option<tracing::Span>,
24    groups: HashMap<GroupKey, TaskGroup>,
25    children: HashMap<GroupKey, HashSet<GroupKey>>,
26    parent_map: Box<dyn 'static + Send + Sync + Fn(&GroupKey) -> Option<GroupKey>>,
27    outcome_rx: mpsc::Sender<(GroupKey, Outcome)>,
28    // used to keep track of the number of tasks still running
29    stopping_group_counts: Vec<Arc<AtomicU32>>,
30}
31
32impl<GroupKey, Outcome> TaskManager<GroupKey, Outcome>
33where
34    GroupKey: Clone + Eq + Hash + Send + std::fmt::Debug + 'static,
35    Outcome: Send + 'static,
36{
37    pub fn new(
38        outcome_rx: mpsc::Sender<(GroupKey, Outcome)>,
39        parent_map: impl 'static + Send + Sync + Fn(&GroupKey) -> Option<GroupKey> + 'static,
40    ) -> Self {
41        Self {
42            span: None,
43            groups: Default::default(),
44            children: Default::default(),
45            parent_map: Box::new(parent_map),
46            outcome_rx,
47            stopping_group_counts: Default::default(),
48        }
49    }
50
51    pub fn new_instrumented(
52        span: tracing::Span,
53        outcome_rx: mpsc::Sender<(GroupKey, Outcome)>,
54        parent_map: impl 'static + Send + Sync + Fn(&GroupKey) -> Option<GroupKey> + 'static,
55    ) -> Self {
56        Self {
57            span: Some(span),
58            groups: Default::default(),
59            children: Default::default(),
60            parent_map: Box::new(parent_map),
61            outcome_rx,
62            stopping_group_counts: Default::default(),
63        }
64    }
65
66    /// Add a task to a group
67    pub fn add_task<Fut: Future<Output = Outcome> + Send + 'static>(
68        &mut self,
69        key: GroupKey,
70        f: impl FnOnce(StopListener) -> Fut + Send + 'static,
71    ) {
72        let span = self.span.clone();
73        let mut tx = self.outcome_rx.clone();
74        let group = self.group(key.clone());
75        let listener = group.stopper.listener();
76        let task = async move {
77            let outcome = if let Some(span) = span {
78                f(listener).instrument(span).await
79            } else {
80                f(listener).await
81            };
82            tx.try_send((key, outcome)).ok();
83        }
84        .boxed();
85        group.tasks.spawn(task);
86    }
87
88    /// Remove a group, returning a future that can be waited upon for all tasks
89    /// in that group to complete.
90    pub fn stop_group(&mut self, key: &GroupKey) -> GroupStop {
91        let mut js = tokio::task::JoinSet::new();
92        for key in self.descendants(key) {
93            if let Some(mut group) = self.groups.remove(&key) {
94                // Signal all tasks to stop.
95                group.stopper.emit();
96                let num = group.stopper.num;
97                self.stopping_group_counts.push(num);
98
99                js.spawn(finish_joinset(group.tasks));
100            }
101        }
102
103        async move { finish_joinset(js).await }.boxed()
104    }
105
106    pub(crate) fn descendants(&self, key: &GroupKey) -> HashSet<GroupKey> {
107        let mut all = HashSet::new();
108        all.insert(key.clone());
109
110        let this = &self;
111
112        if let Some(children) = this.children.get(&key) {
113            for child in children {
114                all.extend(this.descendants(child));
115            }
116        }
117
118        all
119    }
120
121    fn group(&mut self, key: GroupKey) -> &mut TaskGroup {
122        self.groups.entry(key.clone()).or_insert_with(|| {
123            if let Some(parent) = (self.parent_map)(&key) {
124                self.children
125                    .entry(parent)
126                    .or_insert_with(HashSet::new)
127                    .insert(key);
128            }
129            TaskGroup::new()
130        })
131    }
132
133    /// For testing purposes only, this is not a reliable indicator, since the JoinSet
134    /// is not polled except when a group is ending, so the count never actually
135    /// decreases.
136    #[cfg(test)]
137    fn num_tasks(&self, key: &GroupKey) -> usize {
138        let current = self
139            .groups
140            .get(key)
141            .map(|group| group.tasks.len())
142            .unwrap_or_default();
143
144        let pending = self
145            .stopping_group_counts
146            .iter()
147            .map(|c| c.load(std::sync::atomic::Ordering::SeqCst))
148            .sum::<u32>() as usize;
149
150        // dbg!(current) + dbg!(pending)
151        current + pending
152    }
153}
154
155pub type GroupStop = BoxFuture<'static, ()>;
156
157struct TaskGroup {
158    pub(crate) tasks: tokio::task::JoinSet<()>,
159    pub(crate) stopper: StopBroadcaster,
160}
161
162impl TaskGroup {
163    pub fn new() -> Self {
164        Self {
165            tasks: tokio::task::JoinSet::new(),
166            stopper: StopBroadcaster::new(),
167        }
168    }
169}
170
171pub type TaskStream<GroupKey, Outcome> =
172    futures::stream::SelectAll<FuturesUnordered<BoxFuture<'static, (GroupKey, Outcome)>>>;
173
174async fn finish_joinset(mut js: tokio::task::JoinSet<()>) {
175    futures::stream::unfold(&mut js, |tasks| async move {
176        if let Err(err) = tasks.join_next().await? {
177            tracing::error!("task_motel: Error while joining task: {:?}", err);
178        }
179        Some(((), tasks))
180    })
181    .collect::<Vec<_>>()
182    .await;
183    js.detach_all();
184}
#[cfg(test)]
mod tests {
    use futures::{channel::mpsc, SinkExt};
    use maplit::hashset;
    use rand::seq::SliceRandom;

    // Provides the `blocker` and `fused` task helpers used below.
    use crate::test_util::*;

    use super::*;

    /// Test-local group keys; the nesting between them is set per test via `parent_map`.
    #[derive(Debug, Clone, Hash, PartialEq, Eq)]
    enum GroupKey {
        A,
        B,
        C,
        D,
        E,
        F,
        G,
    }

    /// A task that never checks its stop signal still runs to completion, and its
    /// stopped group is counted until it does.
    #[tokio::test(start_paused = true)]
    async fn test_task_completion() {
        use GroupKey::*;
        let (outcome_tx, mut outcome_rx) = mpsc::channel(1);
        let mut tm: TaskManager<GroupKey, String> = TaskManager::new(outcome_tx, |g| match g {
            B => Some(A),
            _ => None,
        });

        let sec = tokio::time::Duration::from_secs(1);

        // The task holds on to its StopListener but never awaits it, so it ignores the
        // stop signal and finishes only after all three sleeps. Holding the listener
        // presumably keeps the group's stop counter non-zero until then — TODO confirm
        // against `StopBroadcaster`.
        tm.add_task(A, move |stop| {
            async move {
                let _stop = stop;
                tokio::time::sleep(sec).await;
                tokio::time::sleep(sec).await;
                tokio::time::sleep(sec).await;
                "done".to_string()
            }
            .boxed()
        });

        tokio::time::advance(sec).await;

        assert_eq!(tm.num_tasks(&A), 1);

        tokio::time::advance(sec).await;

        let stopping = tm.stop_group(&A);

        // A has been removed from `groups`, so this 1 comes from the stopping counters.
        assert_eq!(tm.num_tasks(&A), 1);

        // Time is paused, so the runtime auto-advances through the remaining sleep.
        stopping.await;

        assert_eq!(tm.num_tasks(&A), 0);

        // The outcome was sent when the task finished, so it is already queued here.
        assert_eq!(outcome_rx.next().await.unwrap(), (A, "done".to_string()));
        assert_eq!(tm.num_tasks(&A), 0);
    }

    /// `descendants` follows the registered hierarchy regardless of group creation
    /// order, and stopping the root stops every group.
    #[tokio::test]
    async fn test_descendants() {
        use GroupKey::*;
        let (outcome_tx, outcome_rx) = mpsc::channel(1);
        // Hierarchy: A -> B -> {C -> G, D -> E -> F}.
        let mut tm: TaskManager<GroupKey, String> = TaskManager::new(outcome_tx, |g| match g {
            A => None,
            B => Some(A),
            C => Some(B),
            D => Some(B),
            E => Some(D),
            F => Some(E),
            G => Some(C),
        });

        let mut keys = vec![A, B, C, D, E, F, G];
        keys.shuffle(&mut rand::thread_rng());

        // Create the groups (and so register their parent links) in random order.
        for key in keys.clone() {
            tm.add_task(key.clone(), |_| async move { format!("{:?}", key) })
        }

        assert_eq!(tm.descendants(&A), hashset! {A, B, C, D, E, F, G});
        assert_eq!(tm.descendants(&B), hashset! {B, C, D, E, F, G});
        assert_eq!(tm.descendants(&C), hashset! {C, G});
        assert_eq!(tm.descendants(&D), hashset! {D, E, F});
        assert_eq!(tm.descendants(&E), hashset! {E, F});
        assert_eq!(tm.descendants(&F), hashset! {F});
        assert_eq!(tm.descendants(&G), hashset! {G});

        // Every task completes without awaiting anything, so all seven outcomes are sent.
        tm.stop_group(&A).await;

        assert_eq!(
            outcome_rx.take(keys.len()).collect::<HashSet<_>>().await,
            hashset! {
                (A, "A".to_string()),
                (B, "B".to_string()),
                (C, "C".to_string()),
                (D, "D".to_string()),
                (E, "E".to_string()),
                (F, "F".to_string()),
                (G, "G".to_string()),
            }
        );
    }

    /// Stopping a group stops its descendants but not its ancestors, and stopped keys
    /// can be reused for fresh groups.
    #[tokio::test]
    async fn test_group_nesting() {
        use GroupKey::*;
        let (outcome_tx, mut outcome_rx) = mpsc::channel(1);
        // Used to make the `E` task finish on demand.
        let (mut trigger_tx, trigger_rx) = mpsc::channel(1);
        // Hierarchy: A -> B -> {C, D}; every other key (including E) is a root.
        let mut tm: TaskManager<GroupKey, String> = TaskManager::new(outcome_tx, |g| match g {
            A => None,
            B => Some(A),
            C => Some(B),
            D => Some(B),
            _ => None,
        });

        // NOTE(review): presumably `blocker` runs until its stop signal fires, and
        // `fused` also completes when the trigger fires — confirm against test_util.
        tm.add_task(A, |stop| blocker("a1", stop));
        tm.add_task(A, |stop| blocker("a2", stop));
        tm.add_task(B, |stop| blocker("b1", stop));
        tm.add_task(C, |stop| blocker("c1", stop));
        tm.add_task(D, |stop| blocker("d1", stop));
        tm.add_task(E, |stop| fused("e1", stop.fuse_with(trigger_rx.take(1))));

        assert_eq!(tm.num_tasks(&A), 2);
        assert_eq!(tm.num_tasks(&B), 1);
        assert_eq!(tm.num_tasks(&C), 1);
        assert_eq!(tm.num_tasks(&D), 1);
        assert_eq!(tm.num_tasks(&E), 1);

        trigger_tx.send(()).await.unwrap();
        assert_eq!(outcome_rx.next().await.unwrap(), (E, "e1".to_string()));
        // The actual task number will not decrease until the group is officially stopped.
        assert_eq!(tm.num_tasks(&E), 1);

        // Stopping the leaf group D affects only D.
        let stopping = tm.stop_group(&D);
        assert_eq!(tm.num_tasks(&D), 1);
        stopping.await;
        assert_eq!(tm.num_tasks(&D), 0);
        assert_eq!(
            hashset![outcome_rx.next().await.unwrap(),],
            hashset![(D, "d1".to_string())]
        );

        assert_eq!(tm.num_tasks(&A), 2);
        assert_eq!(tm.num_tasks(&B), 1);
        assert_eq!(tm.num_tasks(&C), 1);
        assert_eq!(tm.num_tasks(&D), 0);

        // Reusing a stopped key creates a fresh D group.
        tm.add_task(D, |stop| blocker("dx", stop));
        assert_eq!(tm.num_tasks(&D), 1);

        // Stopping B also stops its descendants C and D, but not its parent A.
        tm.stop_group(&B).await;
        assert_eq!(
            hashset![
                outcome_rx.next().await.unwrap(),
                outcome_rx.next().await.unwrap(),
                outcome_rx.next().await.unwrap(),
            ],
            hashset![
                (B, "b1".to_string()),
                (C, "c1".to_string()),
                (D, "dx".to_string())
            ]
        );

        assert_eq!(tm.num_tasks(&A), 2);
        assert_eq!(tm.num_tasks(&B), 0);
        assert_eq!(tm.num_tasks(&C), 0);
        assert_eq!(tm.num_tasks(&D), 0);

        // D is recreated while B has no group. `stop_group` never touches `children`,
        // so the A -> B -> D links recorded earlier still let A reach D.
        tm.add_task(D, |stop| blocker("dy", stop));
        assert_eq!(tm.num_tasks(&D), 1);

        tm.stop_group(&A).await;
        assert_eq!(
            hashset![
                outcome_rx.next().await.unwrap(),
                outcome_rx.next().await.unwrap(),
                outcome_rx.next().await.unwrap(),
            ],
            hashset![
                (A, "a1".to_string()),
                (A, "a2".to_string()),
                (D, "dy".to_string())
            ]
        );
    }
}