Skip to main content

zeph_orchestration/
topology.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4//! Heuristic topology classification for `TaskGraph` DAGs.
5
6use std::collections::HashMap;
7
8use serde::{Deserialize, Serialize};
9use zeph_config::OrchestrationConfig;
10
11use super::graph::{TaskGraph, TaskId, TaskNode};
12
13/// Structural classification of a `TaskGraph`.
14#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
15#[serde(rename_all = "snake_case")]
16pub enum Topology {
17    /// All tasks are independent (zero edges). Max parallelism applies.
18    AllParallel,
19    /// Strict linear chain: each task depends on exactly the previous one.
20    LinearChain,
21    /// Single root fans out to multiple independent leaves.
22    FanOut,
23    /// Multiple independent roots converge to a single sink node. Dual of `FanOut`.
24    ///
25    /// Detection: single node with in-degree >= 2 that is the sole non-root sink,
26    /// all other nodes are roots (in-degree 0).
27    FanIn,
28    /// Multi-level DAG with fan-out at multiple depths (tree-like structure).
29    ///
30    /// Detection: single root, `longest_path` >= 2, max in-degree == 1 for all non-root nodes.
31    Hierarchical,
32    /// None of the above; mixed dependency patterns.
33    Mixed,
34}
35
36/// How the scheduler should dispatch tasks based on topology analysis.
37#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
38#[serde(rename_all = "snake_case")]
39pub enum DispatchStrategy {
40    /// Dispatch all ready tasks immediately up to `max_parallel`.
41    ///
42    /// Used for: `AllParallel`, `FanOut`, `FanIn`.
43    FullParallel,
44    /// Dispatch tasks one at a time in dependency order.
45    ///
46    /// Used for: `LinearChain`.
47    Sequential,
48    /// Dispatch tasks level-by-level with a barrier between levels.
49    ///
50    /// Used for: `Hierarchical`.
51    LevelBarrier,
52    /// Mix of parallel and sequential based on local subgraph structure.
53    ///
54    /// Scheduler falls back to default ready-task dispatch with conservative parallelism.
55    /// Used for: `Mixed`.
56    Adaptive,
57    /// Priority-queue dispatch ordering tasks by critical path distance (deepest first).
58    ///
59    /// Applied to `FanOut`/`FanIn` topologies when `tree_optimized_dispatch = true`.
60    /// Dispatches tasks closer to sinks first to minimise end-to-end latency.
61    TreeOptimized,
62    /// Monitor failure rates per subgraph region; deprioritize tasks in failing subtrees.
63    ///
64    /// Applied to `Mixed` topology when `cascade_routing = true`.
65    CascadeAware,
66}
67
68/// Complete topology analysis result computed in a single O(|V|+|E|) pass.
69#[derive(Debug, Clone)]
70pub struct TopologyAnalysis {
71    pub topology: Topology,
72    pub strategy: DispatchStrategy,
73    pub max_parallel: usize,
74    /// Longest path in the DAG (critical path length).
75    pub depth: usize,
76    /// Per-task depth from root (BFS level). Used by `LevelBarrier` dispatch.
77    ///
78    /// Uses `HashMap` so new tasks injected via `inject_tasks()` can be
79    /// added without index-out-of-bounds on `Vec` access (critic S3).
80    pub depths: HashMap<TaskId, usize>,
81}
82
/// Classifier for the DAG shape of a `TaskGraph`.
///
/// Carries no state: every operation is an associated function.
pub struct TopologyClassifier;
85
86impl TopologyClassifier {
87    /// Classify the topology of a `TaskGraph`.
88    ///
89    /// Empty graphs return `AllParallel` (no constraints).
90    ///
91    /// Calls `compute_longest_path_and_depths` once and delegates to
92    /// [`classify_with_depths`]. Use [`classify_with_depths`] directly when
93    /// depths have already been computed to avoid redundant work.
94    #[must_use]
95    pub fn classify(graph: &TaskGraph) -> Topology {
96        let tasks = &graph.tasks;
97        if tasks.is_empty() {
98            return Topology::AllParallel;
99        }
100        // Early exit for no-edge graphs avoids the toposort in classify_with_depths.
101        let edge_count: usize = tasks.iter().map(|t| t.depends_on.len()).sum();
102        if edge_count == 0 {
103            return Topology::AllParallel;
104        }
105        let (longest, depths) = compute_longest_path_and_depths(tasks);
106        Self::classify_with_depths(graph, longest, &depths)
107    }
108
109    /// Classify the topology of a `TaskGraph` using pre-computed depth values.
110    ///
111    /// Accepts the `longest_path` and per-task `depths` map produced by a prior
112    /// call to `compute_longest_path_and_depths` (or `compute_depths_for_scheduler`),
113    /// avoiding a redundant toposort pass when those values are already available.
114    #[must_use]
115    pub fn classify_with_depths(
116        graph: &TaskGraph,
117        longest_path: usize,
118        // NOTE: depths is reserved for future heuristics (e.g. per-level density).
119        // Classification currently only needs longest_path and structural edge counts.
120        _depths: &HashMap<TaskId, usize>,
121    ) -> Topology {
122        let tasks = &graph.tasks;
123        let n = tasks.len();
124
125        if n == 0 {
126            return Topology::AllParallel;
127        }
128
129        let edge_count: usize = tasks.iter().map(|t| t.depends_on.len()).sum();
130
131        if edge_count == 0 {
132            return Topology::AllParallel;
133        }
134
135        // Linear chain: exactly n-1 edges and longest path = n-1.
136        // depths is already computed — check longest_path directly instead of re-computing.
137        if edge_count == n - 1 && longest_path == n - 1 {
138            return Topology::LinearChain;
139        }
140
141        let roots_count = tasks.iter().filter(|t| t.depends_on.is_empty()).count();
142
143        // Fan-out: single root, max depth == 1 (root + one layer of leaves only).
144        if roots_count == 1 && longest_path == 1 {
145            return Topology::FanOut;
146        }
147
148        // FanIn: multiple roots converge to exactly one sink.
149        // The sink has >= 2 dependencies (dep_count >= 2). All other nodes are roots.
150        // Depth must be exactly 1.
151        let non_roots_count = tasks.iter().filter(|t| !t.depends_on.is_empty()).count();
152        if roots_count >= 2 && non_roots_count == 1 && longest_path == 1 {
153            let sink_dep_count = tasks
154                .iter()
155                .filter(|t| !t.depends_on.is_empty())
156                .map(|t| t.depends_on.len())
157                .next()
158                .unwrap_or(0);
159            if sink_dep_count >= 2 {
160                return Topology::FanIn;
161            }
162        }
163
164        // Hierarchical: single root, depth >= 2, max in-degree (dep_count) == 1 for all nodes
165        // (tree-like: no node has multiple parents — ensures no diamond patterns).
166        if roots_count == 1 && longest_path >= 2 {
167            let max_dep_count = tasks.iter().map(|t| t.depends_on.len()).max().unwrap_or(0);
168            if max_dep_count <= 1 {
169                return Topology::Hierarchical;
170            }
171        }
172
173        Topology::Mixed
174    }
175
176    /// Compute the effective `max_parallel` for a given topology and configured base.
177    ///
178    /// Encapsulates the topology-to-parallelism policy in one place so that
179    /// `analyze()` and the `tick()` dirty-reanalysis path use identical logic.
180    ///
181    /// `base` must be the immutable config value (`config.max_parallel`), never a
182    /// previously reduced `self.max_parallel`, to prevent drift across replan cycles.
183    #[must_use]
184    pub fn compute_max_parallel(topology: Topology, base: usize) -> usize {
185        match topology {
186            Topology::AllParallel | Topology::FanOut | Topology::FanIn | Topology::Hierarchical => {
187                base
188            }
189            Topology::LinearChain => 1,
190            Topology::Mixed => (base / 2 + 1).min(base).max(1),
191        }
192    }
193
194    /// Map a `Topology` variant to the appropriate `DispatchStrategy`, considering config.
195    ///
196    /// This is the single source of truth for topology → strategy mapping. All callers
197    /// (including the dirty-reanalysis path in `tick()`) must go through this method.
198    ///
199    /// Pass `config = &OrchestrationConfig::default()` when no config-specific overrides
200    /// are needed (e.g., in unit tests that only care about base topology behaviour).
201    #[must_use]
202    pub fn strategy(topology: Topology, config: &OrchestrationConfig) -> DispatchStrategy {
203        match topology {
204            Topology::FanOut | Topology::FanIn if config.tree_optimized_dispatch => {
205                DispatchStrategy::TreeOptimized
206            }
207            Topology::Mixed if config.cascade_routing => DispatchStrategy::CascadeAware,
208            Topology::AllParallel | Topology::FanOut | Topology::FanIn => {
209                DispatchStrategy::FullParallel
210            }
211            Topology::LinearChain => DispatchStrategy::Sequential,
212            Topology::Hierarchical => DispatchStrategy::LevelBarrier,
213            Topology::Mixed => DispatchStrategy::Adaptive,
214        }
215    }
216
217    /// Compute a complete `TopologyAnalysis` in a single O(|V|+|E|) pass.
218    ///
219    /// When `topology_selection` is disabled in config, returns a default
220    /// `FullParallel` analysis with config's `max_parallel` — zero overhead.
221    ///
222    /// # Performance
223    ///
224    /// Uses a single Kahn's toposort pass to compute both topology classification
225    /// and per-task depths simultaneously.
226    #[must_use]
227    pub fn analyze(graph: &TaskGraph, config: &OrchestrationConfig) -> TopologyAnalysis {
228        let tasks = &graph.tasks;
229        let n = tasks.len();
230
231        if !config.topology_selection || n == 0 {
232            return TopologyAnalysis {
233                topology: Topology::AllParallel,
234                strategy: DispatchStrategy::FullParallel,
235                max_parallel: config.max_parallel as usize,
236                depth: 0,
237                depths: HashMap::new(),
238            };
239        }
240
241        let (longest, depths) = compute_longest_path_and_depths(tasks);
242        let topology = Self::classify_with_depths(graph, longest, &depths);
243        let strategy = Self::strategy(topology, config);
244        let base = config.max_parallel as usize;
245        let max_parallel = Self::compute_max_parallel(topology, base);
246
247        TopologyAnalysis {
248            topology,
249            strategy,
250            max_parallel,
251            depth: longest,
252            depths,
253        }
254    }
255}
256
257/// Compute depths for the scheduler's dirty re-analysis path.
258///
259/// Thin wrapper around `compute_longest_path_and_depths` for use by `DagScheduler::tick()`
260/// when `topology_dirty=true`.
261pub(crate) fn compute_depths_for_scheduler(
262    graph: &TaskGraph,
263) -> (usize, std::collections::HashMap<TaskId, usize>) {
264    compute_longest_path_and_depths(&graph.tasks)
265}
266
267/// Compute the longest path and per-task depth map using Kahn's toposort.
268///
269/// Returns `(longest_path, depths_map)` where `depths_map[task_id] = depth_from_root`.
270///
271/// Single O(|V|+|E|) pass. Assumes a validated DAG (no cycles).
272fn compute_longest_path_and_depths(tasks: &[TaskNode]) -> (usize, HashMap<TaskId, usize>) {
273    let n = tasks.len();
274    if n == 0 {
275        return (0, HashMap::new());
276    }
277
278    let mut in_degree = vec![0usize; n];
279    let mut dependents: Vec<Vec<usize>> = vec![Vec::new(); n];
280    for task in tasks {
281        let i = task.id.index();
282        in_degree[i] = task.depends_on.len();
283        for dep in &task.depends_on {
284            dependents[dep.index()].push(i);
285        }
286    }
287
288    let mut queue: std::collections::VecDeque<usize> = in_degree
289        .iter()
290        .enumerate()
291        .filter(|(_, d)| **d == 0)
292        .map(|(i, _)| i)
293        .collect();
294
295    let mut dist = vec![0usize; n];
296    let mut max_dist = 0usize;
297
298    while let Some(u) = queue.pop_front() {
299        for &v in &dependents[u] {
300            let new_dist = dist[u] + 1;
301            if new_dist > dist[v] {
302                dist[v] = new_dist;
303            }
304            if dist[v] > max_dist {
305                max_dist = dist[v];
306            }
307            in_degree[v] -= 1;
308            if in_degree[v] == 0 {
309                queue.push_back(v);
310            }
311        }
312    }
313
314    let depths: HashMap<TaskId, usize> = tasks.iter().map(|t| (t.id, dist[t.id.index()])).collect();
315
316    (max_dist, depths)
317}
318
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::{TaskGraph, TaskId, TaskNode};

    /// Build a task with id `id` depending on every id listed in `deps`.
    fn make_node(id: u32, deps: &[u32]) -> TaskNode {
        let mut n = TaskNode::new(id, format!("t{id}"), "desc");
        n.depends_on = deps.iter().map(|&d| TaskId(d)).collect();
        n
    }

    /// Wrap `nodes` in a `TaskGraph`; ids are expected to match vector positions.
    fn graph_from(nodes: Vec<TaskNode>) -> TaskGraph {
        let mut g = TaskGraph::new("test");
        g.tasks = nodes;
        g
    }

    fn default_config() -> zeph_config::OrchestrationConfig {
        zeph_config::OrchestrationConfig {
            topology_selection: true,
            max_parallel: 4,
            ..zeph_config::OrchestrationConfig::default()
        }
    }

    // --- classify tests ---

    #[test]
    fn classify_empty_graph() {
        let g = graph_from(vec![]);
        assert_eq!(TopologyClassifier::classify(&g), Topology::AllParallel);
    }

    #[test]
    fn classify_single_task() {
        let g = graph_from(vec![make_node(0, &[])]);
        assert_eq!(TopologyClassifier::classify(&g), Topology::AllParallel);
    }

    #[test]
    fn classify_all_parallel() {
        let g = graph_from(vec![
            make_node(0, &[]),
            make_node(1, &[]),
            make_node(2, &[]),
        ]);
        assert_eq!(TopologyClassifier::classify(&g), Topology::AllParallel);
    }

    #[test]
    fn classify_two_task_chain() {
        // A(0) -> B(1)
        let g = graph_from(vec![make_node(0, &[]), make_node(1, &[0])]);
        assert_eq!(TopologyClassifier::classify(&g), Topology::LinearChain);
    }

    #[test]
    fn classify_linear_chain() {
        // A(0) -> B(1) -> C(2)
        let g = graph_from(vec![
            make_node(0, &[]),
            make_node(1, &[0]),
            make_node(2, &[1]),
        ]);
        assert_eq!(TopologyClassifier::classify(&g), Topology::LinearChain);
    }

    #[test]
    fn classify_fan_out() {
        // A(0) -> {B(1), C(2), D(3)}
        let g = graph_from(vec![
            make_node(0, &[]),
            make_node(1, &[0]),
            make_node(2, &[0]),
            make_node(3, &[0]),
        ]);
        assert_eq!(TopologyClassifier::classify(&g), Topology::FanOut);
    }

    #[test]
    fn classify_fan_in() {
        // A(0), B(1), C(2) all -> D(3): multiple roots, single sink
        let g = graph_from(vec![
            make_node(0, &[]),
            make_node(1, &[]),
            make_node(2, &[]),
            make_node(3, &[0, 1, 2]),
        ]);
        assert_eq!(TopologyClassifier::classify(&g), Topology::FanIn);
    }

    #[test]
    fn classify_fan_in_two_roots() {
        // A(0), B(1) -> C(2)
        let g = graph_from(vec![
            make_node(0, &[]),
            make_node(1, &[]),
            make_node(2, &[0, 1]),
        ]);
        assert_eq!(TopologyClassifier::classify(&g), Topology::FanIn);
    }

    #[test]
    fn classify_hierarchical() {
        // A(0) -> {B(1), C(2)}, B(1) -> D(3), C(2) -> E(4)
        // Single root, depth=2, max in-degree=1 for non-roots
        let g = graph_from(vec![
            make_node(0, &[]),
            make_node(1, &[0]),
            make_node(2, &[0]),
            make_node(3, &[1]),
            make_node(4, &[2]),
        ]);
        assert_eq!(TopologyClassifier::classify(&g), Topology::Hierarchical);
    }

    #[test]
    fn classify_hierarchical_three_levels() {
        // Three levels (0, 1, 2) where fan-out happens at two depths:
        // A(0) -> {B(1), C(2)}, B(1) -> {D(3), E(4)}
        // n-1 edges but longest path (2) != n-1 (4), so it is not a linear chain.
        let g = graph_from(vec![
            make_node(0, &[]),
            make_node(1, &[0]),
            make_node(2, &[0]),
            make_node(3, &[1]),
            make_node(4, &[1]),
        ]);
        assert_eq!(TopologyClassifier::classify(&g), Topology::Hierarchical);
    }

    #[test]
    fn classify_deep_tree_is_hierarchical() {
        // A(0) -> {B(1), C(2)}, B(1) -> D(3) -> E(4): depth 3, every in-degree <= 1.
        let g = graph_from(vec![
            make_node(0, &[]),
            make_node(1, &[0]),
            make_node(2, &[0]),
            make_node(3, &[1]),
            make_node(4, &[3]),
        ]);
        assert_eq!(TopologyClassifier::classify(&g), Topology::Hierarchical);
    }

    #[test]
    fn classify_diamond_is_mixed() {
        // A(0) -> {B(1), C(2)} -> D(3)
        let g = graph_from(vec![
            make_node(0, &[]),
            make_node(1, &[0]),
            make_node(2, &[0]),
            make_node(3, &[1, 2]),
        ]);
        assert_eq!(TopologyClassifier::classify(&g), Topology::Mixed);
    }

    #[test]
    fn classify_chain_with_isolated_task_is_mixed() {
        // A(0) -> B(1), C(2) isolated: two roots, but the only sink has one dependency.
        let g = graph_from(vec![
            make_node(0, &[]),
            make_node(1, &[0]),
            make_node(2, &[]),
        ]);
        assert_eq!(TopologyClassifier::classify(&g), Topology::Mixed);
    }

    #[test]
    fn classify_fan_out_with_chain_on_branch_is_hierarchical() {
        // A(0) -> {B(1), C(2)}, B(1) -> D(3) — single root, depth=2, all in-degrees <= 1 → Hierarchical
        let g = graph_from(vec![
            make_node(0, &[]),
            make_node(1, &[0]),
            make_node(2, &[0]),
            make_node(3, &[1]),
        ]);
        assert_eq!(TopologyClassifier::classify(&g), Topology::Hierarchical);
    }

    // --- strategy tests ---

    fn no_overrides_config() -> zeph_config::OrchestrationConfig {
        zeph_config::OrchestrationConfig {
            topology_selection: true,
            max_parallel: 4,
            cascade_routing: false,
            tree_optimized_dispatch: false,
            ..zeph_config::OrchestrationConfig::default()
        }
    }

    #[test]
    fn strategy_all_parallel_is_full_parallel() {
        assert_eq!(
            TopologyClassifier::strategy(Topology::AllParallel, &no_overrides_config()),
            DispatchStrategy::FullParallel
        );
    }

    #[test]
    fn strategy_fan_out_is_full_parallel() {
        assert_eq!(
            TopologyClassifier::strategy(Topology::FanOut, &no_overrides_config()),
            DispatchStrategy::FullParallel
        );
    }

    #[test]
    fn strategy_fan_in_is_full_parallel() {
        assert_eq!(
            TopologyClassifier::strategy(Topology::FanIn, &no_overrides_config()),
            DispatchStrategy::FullParallel
        );
    }

    #[test]
    fn strategy_linear_chain_is_sequential() {
        assert_eq!(
            TopologyClassifier::strategy(Topology::LinearChain, &no_overrides_config()),
            DispatchStrategy::Sequential
        );
    }

    #[test]
    fn strategy_hierarchical_is_level_barrier() {
        assert_eq!(
            TopologyClassifier::strategy(Topology::Hierarchical, &no_overrides_config()),
            DispatchStrategy::LevelBarrier
        );
    }

    #[test]
    fn strategy_mixed_is_adaptive() {
        assert_eq!(
            TopologyClassifier::strategy(Topology::Mixed, &no_overrides_config()),
            DispatchStrategy::Adaptive
        );
    }

    #[test]
    fn strategy_fan_out_tree_optimized_when_enabled() {
        let mut cfg = no_overrides_config();
        cfg.tree_optimized_dispatch = true;
        assert_eq!(
            TopologyClassifier::strategy(Topology::FanOut, &cfg),
            DispatchStrategy::TreeOptimized
        );
        assert_eq!(
            TopologyClassifier::strategy(Topology::FanIn, &cfg),
            DispatchStrategy::TreeOptimized
        );
    }

    #[test]
    fn strategy_mixed_cascade_aware_when_enabled() {
        let mut cfg = no_overrides_config();
        cfg.cascade_routing = true;
        assert_eq!(
            TopologyClassifier::strategy(Topology::Mixed, &cfg),
            DispatchStrategy::CascadeAware
        );
    }

    #[test]
    fn strategy_tree_optimized_does_not_affect_non_fan_topologies() {
        let mut cfg = no_overrides_config();
        cfg.tree_optimized_dispatch = true;
        assert_eq!(
            TopologyClassifier::strategy(Topology::Hierarchical, &cfg),
            DispatchStrategy::LevelBarrier
        );
        assert_eq!(
            TopologyClassifier::strategy(Topology::LinearChain, &cfg),
            DispatchStrategy::Sequential
        );
        assert_eq!(
            TopologyClassifier::strategy(Topology::Mixed, &cfg),
            DispatchStrategy::Adaptive
        );
    }

    // --- analyze tests ---

    #[test]
    fn analyze_disabled_returns_full_parallel() {
        let mut cfg = default_config();
        cfg.topology_selection = false;
        let g = graph_from(vec![make_node(0, &[]), make_node(1, &[0])]);
        let analysis = TopologyClassifier::analyze(&g, &cfg);
        assert_eq!(analysis.strategy, DispatchStrategy::FullParallel);
        assert_eq!(analysis.max_parallel, 4);
        assert_eq!(analysis.topology, Topology::AllParallel);
    }

    #[test]
    fn analyze_linear_chain_returns_sequential() {
        let cfg = default_config();
        let g = graph_from(vec![
            make_node(0, &[]),
            make_node(1, &[0]),
            make_node(2, &[1]),
        ]);
        let analysis = TopologyClassifier::analyze(&g, &cfg);
        assert_eq!(analysis.topology, Topology::LinearChain);
        assert_eq!(analysis.strategy, DispatchStrategy::Sequential);
        assert_eq!(analysis.max_parallel, 1);
        assert_eq!(analysis.depth, 2);
    }

    #[test]
    fn analyze_hierarchical_returns_level_barrier() {
        let cfg = default_config();
        // A(0) -> {B(1), C(2)}, B(1) -> D(3)
        let g = graph_from(vec![
            make_node(0, &[]),
            make_node(1, &[0]),
            make_node(2, &[0]),
            make_node(3, &[1]),
        ]);
        let analysis = TopologyClassifier::analyze(&g, &cfg);
        assert_eq!(analysis.topology, Topology::Hierarchical);
        assert_eq!(analysis.strategy, DispatchStrategy::LevelBarrier);
        assert_eq!(analysis.max_parallel, 4);
        assert_eq!(analysis.depth, 2);
        // Verify depths
        assert_eq!(analysis.depths[&TaskId(0)], 0);
        assert_eq!(analysis.depths[&TaskId(1)], 1);
        assert_eq!(analysis.depths[&TaskId(2)], 1);
        assert_eq!(analysis.depths[&TaskId(3)], 2);
    }

    #[test]
    fn analyze_deep_tree_reports_depth_three() {
        let cfg = default_config();
        // A(0) -> {B(1), C(2)}, B(1) -> D(3) -> E(4)
        let g = graph_from(vec![
            make_node(0, &[]),
            make_node(1, &[0]),
            make_node(2, &[0]),
            make_node(3, &[1]),
            make_node(4, &[3]),
        ]);
        let analysis = TopologyClassifier::analyze(&g, &cfg);
        assert_eq!(analysis.topology, Topology::Hierarchical);
        assert_eq!(analysis.depth, 3);
        assert_eq!(analysis.depths[&TaskId(2)], 1);
        assert_eq!(analysis.depths[&TaskId(4)], 3);
    }

    #[test]
    fn analyze_depths_follow_longest_path_not_bfs_level() {
        let cfg = default_config();
        // A(0) -> B(1), A(0) -> C(2), B(1) -> C(2): C is one hop from A but its
        // longest dependency chain has two edges.
        let g = graph_from(vec![
            make_node(0, &[]),
            make_node(1, &[0]),
            make_node(2, &[0, 1]),
        ]);
        let analysis = TopologyClassifier::analyze(&g, &cfg);
        assert_eq!(analysis.topology, Topology::Mixed);
        assert_eq!(analysis.depth, 2);
        assert_eq!(analysis.depths[&TaskId(0)], 0);
        assert_eq!(analysis.depths[&TaskId(1)], 1);
        assert_eq!(analysis.depths[&TaskId(2)], 2);
    }

    #[test]
    fn analyze_fan_in_returns_full_parallel() {
        let cfg = default_config();
        // A(0), B(1), C(2) -> D(3)
        let g = graph_from(vec![
            make_node(0, &[]),
            make_node(1, &[]),
            make_node(2, &[]),
            make_node(3, &[0, 1, 2]),
        ]);
        let analysis = TopologyClassifier::analyze(&g, &cfg);
        assert_eq!(analysis.topology, Topology::FanIn);
        assert_eq!(analysis.strategy, DispatchStrategy::FullParallel);
        assert_eq!(analysis.max_parallel, 4);
    }

    #[test]
    fn analyze_mixed_is_conservative() {
        let cfg = default_config(); // max_parallel=4 -> (4/2+1).min(4).max(1) = 3
        let g = graph_from(vec![
            make_node(0, &[]),
            make_node(1, &[0]),
            make_node(2, &[0]),
            make_node(3, &[1, 2]),
        ]);
        let analysis = TopologyClassifier::analyze(&g, &cfg);
        assert_eq!(analysis.topology, Topology::Mixed);
        assert_eq!(analysis.strategy, DispatchStrategy::Adaptive);
        assert_eq!(analysis.max_parallel, 3);
    }

    #[test]
    fn analyze_depths_correct_for_fan_out() {
        let cfg = default_config();
        // A(0) -> {B(1), C(2), D(3)}
        let g = graph_from(vec![
            make_node(0, &[]),
            make_node(1, &[0]),
            make_node(2, &[0]),
            make_node(3, &[0]),
        ]);
        let analysis = TopologyClassifier::analyze(&g, &cfg);
        assert_eq!(analysis.depths[&TaskId(0)], 0);
        assert_eq!(analysis.depths[&TaskId(1)], 1);
        assert_eq!(analysis.depths[&TaskId(2)], 1);
        assert_eq!(analysis.depths[&TaskId(3)], 1);
    }

    #[test]
    fn analyze_mixed_respects_max_parallel_one() {
        let mut cfg = default_config();
        cfg.max_parallel = 1;
        let g = graph_from(vec![
            make_node(0, &[]),
            make_node(1, &[0]),
            make_node(2, &[0]),
            make_node(3, &[1, 2]),
        ]);
        let analysis = TopologyClassifier::analyze(&g, &cfg);
        assert_eq!(analysis.max_parallel, 1);
    }

    // --- classify_with_depths / depth helper tests ---

    #[test]
    fn classify_with_depths_matches_classify_for_all_variants() {
        // Each case pins its expected variant so the test really covers every
        // `Topology` variant rather than merely comparing two code paths.
        let cases = vec![
            (Topology::AllParallel, graph_from(vec![])),
            (
                Topology::AllParallel,
                graph_from(vec![
                    make_node(0, &[]),
                    make_node(1, &[]),
                    make_node(2, &[]),
                ]),
            ),
            (
                Topology::LinearChain,
                graph_from(vec![
                    make_node(0, &[]),
                    make_node(1, &[0]),
                    make_node(2, &[1]),
                ]),
            ),
            (
                Topology::FanOut,
                graph_from(vec![
                    make_node(0, &[]),
                    make_node(1, &[0]),
                    make_node(2, &[0]),
                    make_node(3, &[0]),
                ]),
            ),
            (
                Topology::FanIn,
                graph_from(vec![
                    make_node(0, &[]),
                    make_node(1, &[]),
                    make_node(2, &[]),
                    make_node(3, &[0, 1, 2]),
                ]),
            ),
            (
                Topology::Hierarchical,
                graph_from(vec![
                    make_node(0, &[]),
                    make_node(1, &[0]),
                    make_node(2, &[0]),
                    make_node(3, &[1]),
                ]),
            ),
            (
                Topology::Mixed,
                graph_from(vec![
                    make_node(0, &[]),
                    make_node(1, &[0]),
                    make_node(2, &[0]),
                    make_node(3, &[1, 2]),
                ]),
            ),
        ];

        for (expected, g) in &cases {
            let (longest, depths) = compute_longest_path_and_depths(&g.tasks);
            let via_depths = TopologyClassifier::classify_with_depths(g, longest, &depths);
            let via_classify = TopologyClassifier::classify(g);
            assert_eq!(
                via_classify,
                *expected,
                "classify mismatch for graph with {} tasks",
                g.tasks.len()
            );
            assert_eq!(
                via_depths,
                *expected,
                "classify_with_depths mismatch for graph with {} tasks",
                g.tasks.len()
            );
        }
    }

    #[test]
    fn compute_depths_for_scheduler_matches_analyze() {
        let g = graph_from(vec![
            make_node(0, &[]),
            make_node(1, &[0]),
            make_node(2, &[0]),
            make_node(3, &[1, 2]),
        ]);
        let (longest, depths) = compute_depths_for_scheduler(&g);
        let analysis = TopologyClassifier::analyze(&g, &default_config());
        assert_eq!(longest, analysis.depth);
        assert!(depths == analysis.depths);
    }

    // --- compute_max_parallel tests ---

    #[test]
    fn compute_max_parallel_all_parallel_returns_base() {
        assert_eq!(
            TopologyClassifier::compute_max_parallel(Topology::AllParallel, 8),
            8
        );
    }

    #[test]
    fn compute_max_parallel_fan_out_returns_base() {
        assert_eq!(
            TopologyClassifier::compute_max_parallel(Topology::FanOut, 6),
            6
        );
    }

    #[test]
    fn compute_max_parallel_fan_in_returns_base() {
        assert_eq!(
            TopologyClassifier::compute_max_parallel(Topology::FanIn, 4),
            4
        );
    }

    #[test]
    fn compute_max_parallel_hierarchical_returns_base() {
        assert_eq!(
            TopologyClassifier::compute_max_parallel(Topology::Hierarchical, 10),
            10
        );
    }

    #[test]
    fn compute_max_parallel_linear_chain_returns_one() {
        assert_eq!(
            TopologyClassifier::compute_max_parallel(Topology::LinearChain, 8),
            1
        );
        assert_eq!(
            TopologyClassifier::compute_max_parallel(Topology::LinearChain, 1),
            1
        );
    }

    #[test]
    fn compute_max_parallel_mixed_is_half_plus_one() {
        // base=4: (4/2+1).min(4).max(1) = 3
        assert_eq!(
            TopologyClassifier::compute_max_parallel(Topology::Mixed, 4),
            3
        );
        // base=2: (2/2+1).min(2).max(1) = 2
        assert_eq!(
            TopologyClassifier::compute_max_parallel(Topology::Mixed, 2),
            2
        );
        // base=1: (1/2+1).min(1).max(1) = 1
        assert_eq!(
            TopologyClassifier::compute_max_parallel(Topology::Mixed, 1),
            1
        );
        // base=8: (8/2+1).min(8).max(1) = 5
        assert_eq!(
            TopologyClassifier::compute_max_parallel(Topology::Mixed, 8),
            5
        );
    }
}