1use std::collections::{HashMap, HashSet};
11
12use super::graph::{TaskGraph, TaskId};
13
/// Running success/failure tally for one DAG region.
#[derive(Debug, Clone)]
pub struct RegionHealth {
    /// Number of outcomes recorded for this region.
    pub total_tasks: usize,
    /// How many of the recorded outcomes were failures.
    pub failed_tasks: usize,
    /// `failed_tasks / total_tasks`, recomputed on every `record` call;
    /// `0.0` until the first outcome arrives.
    pub failure_rate: f32,
}

impl RegionHealth {
    /// An empty tally: no outcomes and a zero failure rate.
    fn new() -> Self {
        Self {
            total_tasks: 0,
            failed_tasks: 0,
            failure_rate: 0.0,
        }
    }

    /// Adds one outcome (`failed == true` counts as a failure) and refreshes
    /// the cached failure rate.
    fn record(&mut self, failed: bool) {
        self.total_tasks += 1;
        self.failed_tasks += usize::from(failed);
        // The usize -> f32 casts only lose precision past 2^24 outcomes, where an
        // approximate rate is still good enough for threshold comparisons.
        #[allow(clippy::cast_precision_loss)]
        let (failures, outcomes) = (self.failed_tasks as f32, self.total_tasks as f32);
        self.failure_rate = failures / outcomes;
    }
}
43
/// Tuning knobs for [`CascadeDetector`].
#[derive(Clone, Debug)]
pub struct CascadeConfig {
    /// A region counts as cascading once its failure rate is strictly
    /// greater than this value.
    pub failure_threshold: f32,
}
50
51#[derive(Debug)]
58pub struct CascadeDetector {
59 config: CascadeConfig,
60 region_health: HashMap<TaskId, RegionHealth>,
62}
63
64impl CascadeDetector {
65 #[must_use]
67 pub fn new(config: CascadeConfig) -> Self {
68 Self {
69 config,
70 region_health: HashMap::new(),
71 }
72 }
73
74 pub fn record_outcome(&mut self, task_id: TaskId, succeeded: bool, graph: &TaskGraph) {
76 let root = primary_root(task_id, graph);
77 self.region_health
78 .entry(root)
79 .or_insert_with(RegionHealth::new)
80 .record(!succeeded);
81 }
82
83 #[must_use]
85 pub fn is_cascading(&self, task_id: TaskId, graph: &TaskGraph) -> bool {
86 let root = primary_root(task_id, graph);
87 self.region_health
88 .get(&root)
89 .is_some_and(|h| h.failure_rate > self.config.failure_threshold)
90 }
91
92 #[must_use]
96 pub fn deprioritized_tasks(&self, graph: &TaskGraph) -> HashSet<TaskId> {
97 let cascading_roots: HashSet<TaskId> = self
99 .region_health
100 .iter()
101 .filter(|(_, h)| h.failure_rate > self.config.failure_threshold)
102 .map(|(&root, _)| root)
103 .collect();
104
105 if cascading_roots.is_empty() {
106 return HashSet::new();
107 }
108
109 let total_regions = self.region_health.len();
111 if cascading_roots.len() == total_regions && total_regions > 0 {
112 tracing::warn!(
113 cascading_regions = total_regions,
114 "all DAG regions are in cascade failure state; \
115 deprioritisation has no effect — falling back to default ordering"
116 );
117 return HashSet::new();
118 }
119
120 graph
121 .tasks
122 .iter()
123 .filter(|t| cascading_roots.contains(&primary_root(t.id, graph)))
124 .map(|t| t.id)
125 .collect()
126 }
127
128 pub fn reset(&mut self) {
133 self.region_health.clear();
134 }
135
136 #[cfg(test)]
138 #[must_use]
139 pub fn region_health(&self) -> &HashMap<TaskId, RegionHealth> {
140 &self.region_health
141 }
142}
143
144fn primary_root(task_id: TaskId, graph: &TaskGraph) -> TaskId {
152 let roots = ancestor_roots(task_id, graph);
153 if roots.is_empty() {
154 return task_id;
155 }
156 if roots.len() == 1 {
157 return roots[0];
158 }
159
160 roots
162 .into_iter()
163 .max_by_key(|&r| (descendant_count(r, graph), u32::MAX - r.as_u32()))
164 .unwrap_or(task_id)
165}
166
167fn ancestor_roots(task_id: TaskId, graph: &TaskGraph) -> Vec<TaskId> {
169 let mut visited = HashSet::new();
170 let mut queue = std::collections::VecDeque::new();
171 queue.push_back(task_id);
172 visited.insert(task_id);
173
174 let mut roots = Vec::new();
175
176 while let Some(id) = queue.pop_front() {
177 let task = &graph.tasks[id.index()];
178 if task.depends_on.is_empty() {
179 roots.push(id);
180 } else {
181 for &dep in &task.depends_on {
182 if visited.insert(dep) {
183 queue.push_back(dep);
184 }
185 }
186 }
187 }
188
189 roots
190}
191
192fn descendant_count(root: TaskId, graph: &TaskGraph) -> usize {
194 let mut visited = HashSet::new();
195 let mut queue = std::collections::VecDeque::new();
196 queue.push_back(root);
197 visited.insert(root);
198
199 let mut forward: HashMap<TaskId, Vec<TaskId>> = HashMap::new();
202 for task in &graph.tasks {
203 for &dep in &task.depends_on {
204 forward.entry(dep).or_default().push(task.id);
205 }
206 }
207
208 while let Some(id) = queue.pop_front() {
209 if let Some(children) = forward.get(&id) {
210 for &child in children {
211 if visited.insert(child) {
212 queue.push_back(child);
213 }
214 }
215 }
216 }
217
218 visited.len()
219}
220
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::{TaskGraph, TaskId, TaskNode};

    /// Builds a node with id `id` depending on the given task ids.
    fn make_node(id: u32, deps: &[u32]) -> TaskNode {
        let mut n = TaskNode::new(id, format!("t{id}"), "desc");
        n.depends_on = deps.iter().map(|&d| TaskId(d)).collect();
        n
    }

    /// Wraps `nodes` in a graph; node ids are expected to match their index.
    fn graph_from(nodes: Vec<TaskNode>) -> TaskGraph {
        let mut g = TaskGraph::new("test");
        g.tasks = nodes;
        g
    }

    fn cfg(threshold: f32) -> CascadeConfig {
        CascadeConfig {
            failure_threshold: threshold,
        }
    }

    #[test]
    fn root_task_returns_self() {
        let g = graph_from(vec![make_node(0, &[])]);
        let roots = ancestor_roots(TaskId(0), &g);
        assert_eq!(roots, vec![TaskId(0)]);
    }

    #[test]
    fn linear_chain_root_is_task_zero() {
        // 0 -> 1 -> 2
        let g = graph_from(vec![
            make_node(0, &[]),
            make_node(1, &[0]),
            make_node(2, &[1]),
        ]);
        let roots = ancestor_roots(TaskId(2), &g);
        assert_eq!(roots, vec![TaskId(0)]);
    }

    #[test]
    fn diamond_has_single_shared_root() {
        // 0 -> {1, 2} -> 3: both paths converge on the same root.
        let g = graph_from(vec![
            make_node(0, &[]),
            make_node(1, &[0]),
            make_node(2, &[0]),
            make_node(3, &[1, 2]),
        ]);
        let roots = ancestor_roots(TaskId(3), &g);
        assert_eq!(roots, vec![TaskId(0)]);
    }

    #[test]
    fn fan_in_has_multiple_roots() {
        // {0, 1, 2} -> 3
        let g = graph_from(vec![
            make_node(0, &[]),
            make_node(1, &[]),
            make_node(2, &[]),
            make_node(3, &[0, 1, 2]),
        ]);
        let mut roots = ancestor_roots(TaskId(3), &g);
        roots.sort_by_key(|r| r.as_u32());
        assert_eq!(roots, vec![TaskId(0), TaskId(1), TaskId(2)]);
    }

    #[test]
    fn descendant_count_includes_root() {
        // 0 -> 1 -> 2
        let g = graph_from(vec![
            make_node(0, &[]),
            make_node(1, &[0]),
            make_node(2, &[1]),
        ]);
        assert_eq!(descendant_count(TaskId(0), &g), 3);
        assert_eq!(descendant_count(TaskId(2), &g), 1);
    }

    #[test]
    fn primary_root_prefers_larger_subtree() {
        // Task 4 depends on roots 0 and 1. Root 1 reaches {1, 2, 3, 4} and
        // root 0 reaches {0, 4}, so root 1 wins despite its higher id.
        let g = graph_from(vec![
            make_node(0, &[]),
            make_node(1, &[]),
            make_node(2, &[1]),
            make_node(3, &[1]),
            make_node(4, &[0, 1]),
        ]);
        assert_eq!(descendant_count(TaskId(0), &g), 2);
        assert_eq!(descendant_count(TaskId(1), &g), 4);
        assert_eq!(primary_root(TaskId(4), &g), TaskId(1));
    }

    #[test]
    fn primary_root_ties_break_toward_lowest_id() {
        let g = graph_from(vec![
            make_node(0, &[]),
            make_node(1, &[]),
            make_node(2, &[]),
            make_node(3, &[0, 1, 2]),
        ]);
        assert_eq!(primary_root(TaskId(3), &g), TaskId(0));
    }

    #[test]
    fn primary_root_falls_back_to_self_on_rootless_cycle() {
        // 0 <-> 1: no ancestor has an empty dependency list.
        let g = graph_from(vec![make_node(0, &[1]), make_node(1, &[0])]);
        assert!(ancestor_roots(TaskId(0), &g).is_empty());
        assert_eq!(primary_root(TaskId(0), &g), TaskId(0));
    }

    #[test]
    fn no_failures_not_cascading() {
        let g = graph_from(vec![make_node(0, &[]), make_node(1, &[0])]);
        let mut det = CascadeDetector::new(cfg(0.5));
        det.record_outcome(TaskId(1), true, &g);
        assert!(!det.is_cascading(TaskId(1), &g));
    }

    #[test]
    fn failure_rate_exceeds_threshold() {
        // 0 -> {1, 2, 3}; 2 of 3 outcomes fail (0.67 > 0.5).
        let g = graph_from(vec![
            make_node(0, &[]),
            make_node(1, &[0]),
            make_node(2, &[0]),
            make_node(3, &[0]),
        ]);
        let mut det = CascadeDetector::new(cfg(0.5));
        det.record_outcome(TaskId(1), false, &g);
        det.record_outcome(TaskId(2), false, &g);
        det.record_outcome(TaskId(3), true, &g);
        // Every task in the region is reported as cascading, including the
        // one that succeeded.
        assert!(det.is_cascading(TaskId(1), &g));
        assert!(det.is_cascading(TaskId(2), &g));
        assert!(det.is_cascading(TaskId(3), &g));
    }

    #[test]
    fn reset_clears_all_regions() {
        let g = graph_from(vec![make_node(0, &[]), make_node(1, &[0])]);
        let mut det = CascadeDetector::new(cfg(0.3));
        det.record_outcome(TaskId(1), false, &g);
        det.reset();
        assert!(!det.is_cascading(TaskId(1), &g));
        assert!(det.region_health().is_empty());
    }

    #[test]
    fn deprioritized_tasks_empty_when_healthy() {
        let g = graph_from(vec![make_node(0, &[]), make_node(1, &[0])]);
        let mut det = CascadeDetector::new(cfg(0.5));
        det.record_outcome(TaskId(1), true, &g);
        assert!(det.deprioritized_tasks(&g).is_empty());
    }

    #[test]
    fn deprioritized_tasks_returns_failing_subtree() {
        // Two independent regions: 0 -> {1, 2} (failing) and 3 -> 4 (healthy).
        let g = graph_from(vec![
            make_node(0, &[]),
            make_node(1, &[0]),
            make_node(2, &[0]),
            make_node(3, &[]),
            make_node(4, &[3]),
        ]);
        let mut det = CascadeDetector::new(cfg(0.4));
        det.record_outcome(TaskId(1), false, &g);
        det.record_outcome(TaskId(2), false, &g);
        det.record_outcome(TaskId(4), true, &g);
        let dp = det.deprioritized_tasks(&g);
        // The failing region, root included, is deprioritised.
        assert!(dp.contains(&TaskId(0)));
        assert!(dp.contains(&TaskId(1)));
        assert!(dp.contains(&TaskId(2)));
        // The healthy region is untouched.
        assert!(!dp.contains(&TaskId(3)));
        assert!(!dp.contains(&TaskId(4)));
    }

    #[test]
    fn all_regions_cascading_returns_empty_for_safe_fallback() {
        let g = graph_from(vec![make_node(0, &[]), make_node(1, &[0])]);
        let mut det = CascadeDetector::new(cfg(0.3));
        det.record_outcome(TaskId(1), false, &g);
        // The only region is cascading, so deprioritising it would not change
        // the ordering at all.
        let dp = det.deprioritized_tasks(&g);
        assert!(
            dp.is_empty(),
            "all-regions-cascading should return empty to allow forward progress"
        );
    }
}