zeph_orchestration/cascade.rs

// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
// SPDX-License-Identifier: MIT OR Apache-2.0

//! Cascade-aware routing for DAG execution (arXiv:2603.17112).
//!
//! Tracks failure propagation across DAG regions. When a subtree's failure rate
//! exceeds the configured threshold, tasks in that subtree are deprioritized in
//! `DagScheduler::tick()` so that healthy independent branches run first.
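//!
//! A minimal usage sketch; construction of the `TaskGraph` is elided (see the tests below):
//!
//! ```ignore
//! let mut detector = CascadeDetector::new(CascadeConfig { failure_threshold: 0.5 });
//! detector.record_outcome(task_id, /* succeeded */ false, &graph);
//! if detector.is_cascading(task_id, &graph) {
//!     // `DagScheduler::tick()` will run healthy branches before this task's region.
//! }
//! ```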

use std::collections::{HashMap, HashSet};

use super::graph::{TaskGraph, TaskId};

/// Per-region failure health snapshot.
#[derive(Debug, Clone)]
pub struct RegionHealth {
    pub total_tasks: usize,
    pub failed_tasks: usize,
    /// `failed_tasks / total_tasks`; `0.0` until the first outcome is recorded.
    pub failure_rate: f32,
}

impl RegionHealth {
    fn new() -> Self {
        Self {
            total_tasks: 0,
            failed_tasks: 0,
            failure_rate: 0.0,
        }
    }

    fn record(&mut self, failed: bool) {
        self.total_tasks += 1;
        if failed {
            self.failed_tasks += 1;
        }
        #[allow(clippy::cast_precision_loss)]
        {
            self.failure_rate = self.failed_tasks as f32 / self.total_tasks as f32;
        }
    }
}

/// Configuration for cascade detection. Extracted from `OrchestrationConfig` at construction.
#[derive(Debug, Clone)]
pub struct CascadeConfig {
    /// Failure rate threshold above which a region is considered "cascading".
    pub failure_threshold: f32,
}

/// Tracks failure propagation across DAG regions for cascade-aware routing.
///
/// A "region" is the set of tasks reachable from the "heaviest root" of each task — the
/// root ancestor that is the source of the most downstream tasks. For tasks with multiple
/// roots (diamond patterns), we pick the single root that covers the most descendants,
/// preventing over-aggressive deprioritization (C8 fix).
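///
/// For example, with edges `A -> C`, `A -> D`, `C -> E`, and `B -> E`, task `E` has two
/// root ancestors: `A` reaches 4 tasks (`A`, `C`, `D`, `E`) while `B` reaches 2 (`B`, `E`),
/// so `E` is assigned to `A`'s region.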
#[derive(Debug)]
pub struct CascadeDetector {
    config: CascadeConfig,
    /// Per-root failure health. Key = root `TaskId`.
    region_health: HashMap<TaskId, RegionHealth>,
}

impl CascadeDetector {
    /// Create a new detector.
    #[must_use]
    pub fn new(config: CascadeConfig) -> Self {
        Self {
            config,
            region_health: HashMap::new(),
        }
    }

    /// Record a task outcome and update the health of its primary region.
    pub fn record_outcome(&mut self, task_id: TaskId, succeeded: bool, graph: &TaskGraph) {
        let root = primary_root(task_id, graph);
        self.region_health
            .entry(root)
            .or_insert_with(RegionHealth::new)
            .record(!succeeded);
    }

    /// Returns `true` when the primary region of `task_id` is in cascade failure.
    #[must_use]
    pub fn is_cascading(&self, task_id: TaskId, graph: &TaskGraph) -> bool {
        let root = primary_root(task_id, graph);
        self.region_health
            .get(&root)
            .is_some_and(|h| h.failure_rate > self.config.failure_threshold)
    }

    /// Returns the set of task IDs that should be deprioritized due to cascade failure.
    ///
    /// Returns an empty set when no region is cascading, avoiding unnecessary reordering.
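    ///
    /// A scheduler-side sketch (the `ready` list and its ordering are hypothetical, not part
    /// of this module):
    ///
    /// ```ignore
    /// let demoted = detector.deprioritized_tasks(&graph);
    /// // Stable sort: tasks outside the cascading regions keep their order and run first.
    /// ready.sort_by_key(|t| demoted.contains(&t.id));
    /// ```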
    #[must_use]
    pub fn deprioritized_tasks(&self, graph: &TaskGraph) -> HashSet<TaskId> {
        // Collect cascading roots first to avoid calling is_cascading per-task.
        let cascading_roots: HashSet<TaskId> = self
            .region_health
            .iter()
            .filter(|(_, h)| h.failure_rate > self.config.failure_threshold)
            .map(|(&root, _)| root)
            .collect();

        if cascading_roots.is_empty() {
            return HashSet::new();
        }

        // Log degenerate case: all known regions are cascading.
        let total_regions = self.region_health.len();
        if cascading_roots.len() == total_regions && total_regions > 0 {
            tracing::warn!(
                cascading_regions = total_regions,
                "all DAG regions are in cascade failure state; \
                 deprioritization has no effect — falling back to default ordering"
            );
            return HashSet::new();
        }

        graph
            .tasks
            .iter()
            .filter(|t| cascading_roots.contains(&primary_root(t.id, graph)))
            .map(|t| t.id)
            .collect()
    }

    /// Reset all region health counters.
    ///
    /// Called by `DagScheduler::inject_tasks()` because graph topology has fundamentally
    /// changed — old failure counts no longer reflect the new task set (C13 fix).
    pub fn reset(&mut self) {
        self.region_health.clear();
    }

    /// Expose region health for testing.
    #[cfg(test)]
    #[must_use]
    pub fn region_health(&self) -> &HashMap<TaskId, RegionHealth> {
        &self.region_health
    }
}

/// Compute the "heaviest" root for `task_id`: the root ancestor that reaches the most
/// downstream tasks (largest subtree). For tasks that have no ancestors (roots themselves)
/// `task_id` is returned directly.
///
/// "Heaviest root" prevents over-aggressive cascade deprioritization on diamond DAGs: if
/// task C is reachable from both A and B, we assign it to whichever root's subtree is
/// larger. Ties are broken by smaller `TaskId` value for determinism.
fn primary_root(task_id: TaskId, graph: &TaskGraph) -> TaskId {
    let roots = ancestor_roots(task_id, graph);
    if roots.is_empty() {
        return task_id;
    }
    if roots.len() == 1 {
        return roots[0];
    }

    // Count descendants for each root candidate.
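    // Tiebreak on equal counts: `u32::MAX - id` makes the smaller `TaskId` win under `max_by_key`.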
    roots
        .into_iter()
        .max_by_key(|&r| (descendant_count(r, graph), u32::MAX - r.as_u32()))
        .unwrap_or(task_id)
}

/// Collect all root (in-degree 0) ancestors of `task_id` via BFS.
fn ancestor_roots(task_id: TaskId, graph: &TaskGraph) -> Vec<TaskId> {
    let mut visited = HashSet::new();
    let mut queue = std::collections::VecDeque::new();
    queue.push_back(task_id);
    visited.insert(task_id);

    let mut roots = Vec::new();

    while let Some(id) = queue.pop_front() {
        let task = &graph.tasks[id.index()];
        if task.depends_on.is_empty() {
            roots.push(id);
        } else {
            for &dep in &task.depends_on {
                if visited.insert(dep) {
                    queue.push_back(dep);
                }
            }
        }
    }

    roots
}

/// Count the number of tasks reachable from `root` (inclusive) via BFS.
fn descendant_count(root: TaskId, graph: &TaskGraph) -> usize {
    let mut visited = HashSet::new();
    let mut queue = std::collections::VecDeque::new();
    queue.push_back(root);
    visited.insert(root);

    // Build forward adjacency on the fly.
    // Tasks store `depends_on` (reverse edges). We need forward edges.
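    // Note: the forward map is rebuilt on every call; that is fine for small graphs, but it
    // could be cached per-graph if cascade routing ever becomes a hot path.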
    let mut forward: HashMap<TaskId, Vec<TaskId>> = HashMap::new();
    for task in &graph.tasks {
        for &dep in &task.depends_on {
            forward.entry(dep).or_default().push(task.id);
        }
    }

    while let Some(id) = queue.pop_front() {
        if let Some(children) = forward.get(&id) {
            for &child in children {
                if visited.insert(child) {
                    queue.push_back(child);
                }
            }
        }
    }

    visited.len()
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::{TaskGraph, TaskId, TaskNode};

    fn make_node(id: u32, deps: &[u32]) -> TaskNode {
        let mut n = TaskNode::new(id, format!("t{id}"), "desc");
        n.depends_on = deps.iter().map(|&d| TaskId(d)).collect();
        n
    }

    fn graph_from(nodes: Vec<TaskNode>) -> TaskGraph {
        let mut g = TaskGraph::new("test");
        g.tasks = nodes;
        g
    }

    fn cfg(threshold: f32) -> CascadeConfig {
        CascadeConfig {
            failure_threshold: threshold,
        }
    }

    // --- ancestor_roots ---

    #[test]
    fn root_task_returns_self() {
        let g = graph_from(vec![make_node(0, &[])]);
        let roots = ancestor_roots(TaskId(0), &g);
        assert_eq!(roots, vec![TaskId(0)]);
    }

    #[test]
    fn linear_chain_root_is_task_zero() {
        // 0 -> 1 -> 2
        let g = graph_from(vec![
            make_node(0, &[]),
            make_node(1, &[0]),
            make_node(2, &[1]),
        ]);
        let roots = ancestor_roots(TaskId(2), &g);
        assert_eq!(roots, vec![TaskId(0)]);
    }

    #[test]
    fn diamond_collapses_to_single_root() {
        // A(0) -> {B(1), C(2)} -> D(3)
        let g = graph_from(vec![
            make_node(0, &[]),
            make_node(1, &[0]),
            make_node(2, &[0]),
            make_node(3, &[1, 2]),
        ]);
        let mut roots = ancestor_roots(TaskId(3), &g);
        roots.sort_by_key(|r| r.as_u32());
        // Only one root (0) because 1 and 2 are not roots themselves.
        assert_eq!(roots, vec![TaskId(0)]);
    }

    #[test]
    fn fan_in_has_multiple_roots() {
        // A(0), B(1), C(2) -> D(3)
        let g = graph_from(vec![
            make_node(0, &[]),
            make_node(1, &[]),
            make_node(2, &[]),
            make_node(3, &[0, 1, 2]),
        ]);
        let mut roots = ancestor_roots(TaskId(3), &g);
        roots.sort_by_key(|r| r.as_u32());
        assert_eq!(roots, vec![TaskId(0), TaskId(1), TaskId(2)]);
    }

    // --- record_outcome + is_cascading ---

    #[test]
    fn no_failures_not_cascading() {
        let g = graph_from(vec![make_node(0, &[]), make_node(1, &[0])]);
        let mut det = CascadeDetector::new(cfg(0.5));
        det.record_outcome(TaskId(1), true, &g);
        assert!(!det.is_cascading(TaskId(1), &g));
    }

    #[test]
    fn failure_rate_exceeds_threshold() {
        // 0 -> 1, 0 -> 2, 0 -> 3 (fan-out). Two of three fail.
        let g = graph_from(vec![
            make_node(0, &[]),
            make_node(1, &[0]),
            make_node(2, &[0]),
            make_node(3, &[0]),
        ]);
        let mut det = CascadeDetector::new(cfg(0.5));
        det.record_outcome(TaskId(1), false, &g);
        det.record_outcome(TaskId(2), false, &g);
        det.record_outcome(TaskId(3), true, &g);
        // 2 failures / 3 total ≈ 0.67 > 0.5 threshold
        assert!(det.is_cascading(TaskId(1), &g));
        assert!(det.is_cascading(TaskId(2), &g));
        assert!(det.is_cascading(TaskId(3), &g));
    }

    #[test]
    fn reset_clears_all_regions() {
        let g = graph_from(vec![make_node(0, &[]), make_node(1, &[0])]);
        let mut det = CascadeDetector::new(cfg(0.3));
        det.record_outcome(TaskId(1), false, &g);
        det.reset();
        assert!(!det.is_cascading(TaskId(1), &g));
        assert!(det.region_health().is_empty());
    }

    // --- deprioritized_tasks ---

    #[test]
    fn deprioritized_tasks_empty_when_healthy() {
        let g = graph_from(vec![make_node(0, &[]), make_node(1, &[0])]);
        let mut det = CascadeDetector::new(cfg(0.5));
        det.record_outcome(TaskId(1), true, &g);
        assert!(det.deprioritized_tasks(&g).is_empty());
    }

    #[test]
    fn deprioritized_tasks_returns_failing_subtree() {
        // Root 0 -> {1, 2}; Root 3 -> 4. Fail 1 and 2 (region of root 0).
        // Root 3 stays healthy.
        let g = graph_from(vec![
            make_node(0, &[]),
            make_node(1, &[0]),
            make_node(2, &[0]),
            make_node(3, &[]),
            make_node(4, &[3]),
        ]);
        let mut det = CascadeDetector::new(cfg(0.4));
        det.record_outcome(TaskId(1), false, &g);
        det.record_outcome(TaskId(2), false, &g);
        det.record_outcome(TaskId(4), true, &g);
        let dp = det.deprioritized_tasks(&g);
        // Tasks 0, 1, 2 belong to root 0, which is cascading.
        assert!(dp.contains(&TaskId(0)));
        assert!(dp.contains(&TaskId(1)));
        assert!(dp.contains(&TaskId(2)));
        // Tasks 3 and 4 are in the healthy region.
        assert!(!dp.contains(&TaskId(3)));
        assert!(!dp.contains(&TaskId(4)));
    }

    #[test]
    fn all_regions_cascading_returns_empty_for_safe_fallback() {
        let g = graph_from(vec![make_node(0, &[]), make_node(1, &[0])]);
        let mut det = CascadeDetector::new(cfg(0.3));
        // Only one region (root 0); make it cascade.
        det.record_outcome(TaskId(1), false, &g);
        // With the single region cascading, deprioritized_tasks returns an empty set
        // to prevent complete deadlock (C9 fix).
        let dp = det.deprioritized_tasks(&g);
        assert!(
            dp.is_empty(),
            "all-regions-cascading should return empty to allow forward progress"
        );
    }
}