1use std::{
9 collections::{HashMap, HashSet},
10 hash::Hash,
11 sync::{atomic::AtomicU32, Arc},
12};
13
14use futures::{
15 channel::mpsc, future::BoxFuture, stream::FuturesUnordered, Future, FutureExt, StreamExt,
16};
17use tracing::Instrument;
18
19use crate::{signal::StopListener, StopBroadcaster};
20
21pub struct TaskManager<GroupKey, Outcome> {
23 span: Option<tracing::Span>,
24 groups: HashMap<GroupKey, TaskGroup>,
25 children: HashMap<GroupKey, HashSet<GroupKey>>,
26 parent_map: Box<dyn 'static + Send + Sync + Fn(&GroupKey) -> Option<GroupKey>>,
27 outcome_rx: mpsc::Sender<(GroupKey, Outcome)>,
28 stopping_group_counts: Vec<Arc<AtomicU32>>,
30}
31
32impl<GroupKey, Outcome> TaskManager<GroupKey, Outcome>
33where
34 GroupKey: Clone + Eq + Hash + Send + std::fmt::Debug + 'static,
35 Outcome: Send + 'static,
36{
37 pub fn new(
38 outcome_rx: mpsc::Sender<(GroupKey, Outcome)>,
39 parent_map: impl 'static + Send + Sync + Fn(&GroupKey) -> Option<GroupKey> + 'static,
40 ) -> Self {
41 Self {
42 span: None,
43 groups: Default::default(),
44 children: Default::default(),
45 parent_map: Box::new(parent_map),
46 outcome_rx,
47 stopping_group_counts: Default::default(),
48 }
49 }
50
51 pub fn new_instrumented(
52 span: tracing::Span,
53 outcome_rx: mpsc::Sender<(GroupKey, Outcome)>,
54 parent_map: impl 'static + Send + Sync + Fn(&GroupKey) -> Option<GroupKey> + 'static,
55 ) -> Self {
56 Self {
57 span: Some(span),
58 groups: Default::default(),
59 children: Default::default(),
60 parent_map: Box::new(parent_map),
61 outcome_rx,
62 stopping_group_counts: Default::default(),
63 }
64 }
65
66 pub fn add_task<Fut: Future<Output = Outcome> + Send + 'static>(
68 &mut self,
69 key: GroupKey,
70 f: impl FnOnce(StopListener) -> Fut + Send + 'static,
71 ) {
72 let span = self.span.clone();
73 let mut tx = self.outcome_rx.clone();
74 let group = self.group(key.clone());
75 let listener = group.stopper.listener();
76 let task = async move {
77 let outcome = if let Some(span) = span {
78 f(listener).instrument(span).await
79 } else {
80 f(listener).await
81 };
82 tx.try_send((key, outcome)).ok();
83 }
84 .boxed();
85 group.tasks.spawn(task);
86 }
87
88 pub fn stop_group(&mut self, key: &GroupKey) -> GroupStop {
91 let mut js = tokio::task::JoinSet::new();
92 for key in self.descendants(key) {
93 if let Some(mut group) = self.groups.remove(&key) {
94 group.stopper.emit();
96 let num = group.stopper.num;
97 self.stopping_group_counts.push(num);
98
99 js.spawn(finish_joinset(group.tasks));
100 }
101 }
102
103 async move { finish_joinset(js).await }.boxed()
104 }
105
106 pub(crate) fn descendants(&self, key: &GroupKey) -> HashSet<GroupKey> {
107 let mut all = HashSet::new();
108 all.insert(key.clone());
109
110 let this = &self;
111
112 if let Some(children) = this.children.get(&key) {
113 for child in children {
114 all.extend(this.descendants(child));
115 }
116 }
117
118 all
119 }
120
121 fn group(&mut self, key: GroupKey) -> &mut TaskGroup {
122 self.groups.entry(key.clone()).or_insert_with(|| {
123 if let Some(parent) = (self.parent_map)(&key) {
124 self.children
125 .entry(parent)
126 .or_insert_with(HashSet::new)
127 .insert(key);
128 }
129 TaskGroup::new()
130 })
131 }
132
133 #[cfg(test)]
137 fn num_tasks(&self, key: &GroupKey) -> usize {
138 let current = self
139 .groups
140 .get(key)
141 .map(|group| group.tasks.len())
142 .unwrap_or_default();
143
144 let pending = self
145 .stopping_group_counts
146 .iter()
147 .map(|c| c.load(std::sync::atomic::Ordering::SeqCst))
148 .sum::<u32>() as usize;
149
150 current + pending
152 }
153}
154
155pub type GroupStop = BoxFuture<'static, ()>;
156
157struct TaskGroup {
158 pub(crate) tasks: tokio::task::JoinSet<()>,
159 pub(crate) stopper: StopBroadcaster,
160}
161
162impl TaskGroup {
163 pub fn new() -> Self {
164 Self {
165 tasks: tokio::task::JoinSet::new(),
166 stopper: StopBroadcaster::new(),
167 }
168 }
169}
170
171pub type TaskStream<GroupKey, Outcome> =
172 futures::stream::SelectAll<FuturesUnordered<BoxFuture<'static, (GroupKey, Outcome)>>>;
173
174async fn finish_joinset(mut js: tokio::task::JoinSet<()>) {
175 futures::stream::unfold(&mut js, |tasks| async move {
176 if let Err(err) = tasks.join_next().await? {
177 tracing::error!("task_motel: Error while joining task: {:?}", err);
178 }
179 Some(((), tasks))
180 })
181 .collect::<Vec<_>>()
182 .await;
183 js.detach_all();
184}
185#[cfg(test)]
186mod tests {
187 use futures::{channel::mpsc, SinkExt};
188 use maplit::hashset;
189 use rand::seq::SliceRandom;
190
191 use crate::test_util::*;
192
193 use super::*;
194
195 #[derive(Debug, Clone, Hash, PartialEq, Eq)]
196 enum GroupKey {
197 A,
198 B,
199 C,
200 D,
201 E,
202 F,
203 G,
204 }
205
206 #[tokio::test(start_paused = true)]
207 async fn test_task_completion() {
208 use GroupKey::*;
209 let (outcome_tx, mut outcome_rx) = mpsc::channel(1);
210 let mut tm: TaskManager<GroupKey, String> = TaskManager::new(outcome_tx, |g| match g {
211 B => Some(A),
212 _ => None,
213 });
214
215 let sec = tokio::time::Duration::from_secs(1);
216
217 tm.add_task(A, move |stop| {
218 async move {
219 let _stop = stop;
220 tokio::time::sleep(sec).await;
221 tokio::time::sleep(sec).await;
222 tokio::time::sleep(sec).await;
223 "done".to_string()
224 }
225 .boxed()
226 });
227
228 tokio::time::advance(sec).await;
229
230 assert_eq!(tm.num_tasks(&A), 1);
231
232 tokio::time::advance(sec).await;
233
234 let stopping = tm.stop_group(&A);
235
236 assert_eq!(tm.num_tasks(&A), 1);
237
238 stopping.await;
239
240 assert_eq!(tm.num_tasks(&A), 0);
241
242 assert_eq!(outcome_rx.next().await.unwrap(), (A, "done".to_string()));
244 assert_eq!(tm.num_tasks(&A), 0);
245 }
246
247 #[tokio::test]
248 async fn test_descendants() {
249 use GroupKey::*;
250 let (outcome_tx, outcome_rx) = mpsc::channel(1);
251 let mut tm: TaskManager<GroupKey, String> = TaskManager::new(outcome_tx, |g| match g {
252 A => None,
253 B => Some(A),
254 C => Some(B),
255 D => Some(B),
256 E => Some(D),
257 F => Some(E),
258 G => Some(C),
259 });
260
261 let mut keys = vec![A, B, C, D, E, F, G];
262 keys.shuffle(&mut rand::thread_rng());
263
264 for key in keys.clone() {
266 tm.add_task(key.clone(), |_| async move { format!("{:?}", key) })
267 }
268
269 assert_eq!(tm.descendants(&A), hashset! {A, B, C, D, E, F, G});
270 assert_eq!(tm.descendants(&B), hashset! {B, C, D, E, F, G});
271 assert_eq!(tm.descendants(&C), hashset! {C, G});
272 assert_eq!(tm.descendants(&D), hashset! {D, E, F});
273 assert_eq!(tm.descendants(&E), hashset! {E, F});
274 assert_eq!(tm.descendants(&F), hashset! {F});
275 assert_eq!(tm.descendants(&G), hashset! {G});
276
277 tm.stop_group(&A).await;
278
279 assert_eq!(
280 outcome_rx.take(keys.len()).collect::<HashSet<_>>().await,
281 hashset! {
282 (A, "A".to_string()),
283 (B, "B".to_string()),
284 (C, "C".to_string()),
285 (D, "D".to_string()),
286 (E, "E".to_string()),
287 (F, "F".to_string()),
288 (G, "G".to_string()),
289 }
290 );
291 }
292
293 #[tokio::test]
294 async fn test_group_nesting() {
295 use GroupKey::*;
296 let (outcome_tx, mut outcome_rx) = mpsc::channel(1);
297 let (mut trigger_tx, trigger_rx) = mpsc::channel(1);
298 let mut tm: TaskManager<GroupKey, String> = TaskManager::new(outcome_tx, |g| match g {
299 A => None,
300 B => Some(A),
301 C => Some(B),
302 D => Some(B),
303 _ => None,
304 });
305
306 tm.add_task(A, |stop| blocker("a1", stop));
307 tm.add_task(A, |stop| blocker("a2", stop));
308 tm.add_task(B, |stop| blocker("b1", stop));
309 tm.add_task(C, |stop| blocker("c1", stop));
310 tm.add_task(D, |stop| blocker("d1", stop));
311 tm.add_task(E, |stop| fused("e1", stop.fuse_with(trigger_rx.take(1))));
312
313 assert_eq!(tm.num_tasks(&A), 2);
314 assert_eq!(tm.num_tasks(&B), 1);
315 assert_eq!(tm.num_tasks(&C), 1);
316 assert_eq!(tm.num_tasks(&D), 1);
317 assert_eq!(tm.num_tasks(&E), 1);
318
319 trigger_tx.send(()).await.unwrap();
320 assert_eq!(outcome_rx.next().await.unwrap(), (E, "e1".to_string()));
321 assert_eq!(tm.num_tasks(&E), 1);
323
324 let stopping = tm.stop_group(&D);
325 assert_eq!(tm.num_tasks(&D), 1);
326 stopping.await;
327 assert_eq!(tm.num_tasks(&D), 0);
328 assert_eq!(
329 hashset![outcome_rx.next().await.unwrap(),],
330 hashset![(D, "d1".to_string())]
331 );
332
333 assert_eq!(tm.num_tasks(&A), 2);
334 assert_eq!(tm.num_tasks(&B), 1);
335 assert_eq!(tm.num_tasks(&C), 1);
336 assert_eq!(tm.num_tasks(&D), 0);
337
338 tm.add_task(D, |stop| blocker("dx", stop));
339 assert_eq!(tm.num_tasks(&D), 1);
340
341 tm.stop_group(&B).await;
342 assert_eq!(
343 hashset![
344 outcome_rx.next().await.unwrap(),
345 outcome_rx.next().await.unwrap(),
346 outcome_rx.next().await.unwrap(),
347 ],
348 hashset![
349 (B, "b1".to_string()),
350 (C, "c1".to_string()),
351 (D, "dx".to_string())
352 ]
353 );
354
355 assert_eq!(tm.num_tasks(&A), 2);
356 assert_eq!(tm.num_tasks(&B), 0);
357 assert_eq!(tm.num_tasks(&C), 0);
358 assert_eq!(tm.num_tasks(&D), 0);
359
360 tm.add_task(D, |stop| blocker("dy", stop));
361 assert_eq!(tm.num_tasks(&D), 1);
362
363 tm.stop_group(&A).await;
364 assert_eq!(
365 hashset![
366 outcome_rx.next().await.unwrap(),
367 outcome_rx.next().await.unwrap(),
368 outcome_rx.next().await.unwrap(),
369 ],
370 hashset![
371 (A, "a1".to_string()),
372 (A, "a2".to_string()),
373 (D, "dy".to_string())
374 ]
375 );
376 }
377}