Skip to main content

zeph_orchestration/
topology.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4//! Heuristic topology classification for `TaskGraph` DAGs.
5
6use std::collections::HashMap;
7
8use serde::{Deserialize, Serialize};
9use zeph_config::OrchestrationConfig;
10
11use super::graph::{TaskGraph, TaskId, TaskNode};
12
13/// Structural classification of a `TaskGraph`.
14#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
15#[serde(rename_all = "snake_case")]
16pub enum Topology {
17    /// All tasks are independent (zero edges). Max parallelism applies.
18    AllParallel,
19    /// Strict linear chain: each task depends on exactly the previous one.
20    LinearChain,
21    /// Single root fans out to multiple independent leaves.
22    FanOut,
23    /// Multiple independent roots converge to a single sink node. Dual of `FanOut`.
24    ///
25    /// Detection: single node with in-degree >= 2 that is the sole non-root sink,
26    /// all other nodes are roots (in-degree 0).
27    FanIn,
28    /// Multi-level DAG with fan-out at multiple depths (tree-like structure).
29    ///
30    /// Detection: single root, `longest_path` >= 2, max in-degree == 1 for all non-root nodes.
31    Hierarchical,
32    /// None of the above; mixed dependency patterns.
33    Mixed,
34}
35
36/// How the scheduler should dispatch tasks based on topology analysis.
37#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
38#[serde(rename_all = "snake_case")]
39pub enum DispatchStrategy {
40    /// Dispatch all ready tasks immediately up to `max_parallel`.
41    ///
42    /// Used for: `AllParallel`, `FanOut`, `FanIn`.
43    FullParallel,
44    /// Dispatch tasks one at a time in dependency order.
45    ///
46    /// Used for: `LinearChain`.
47    Sequential,
48    /// Dispatch tasks level-by-level with a barrier between levels.
49    ///
50    /// Used for: `Hierarchical`.
51    LevelBarrier,
52    /// Mix of parallel and sequential based on local subgraph structure.
53    ///
54    /// Scheduler falls back to default ready-task dispatch with conservative parallelism.
55    /// Used for: `Mixed`.
56    Adaptive,
57}
58
59/// Complete topology analysis result computed in a single O(|V|+|E|) pass.
60#[derive(Debug, Clone)]
61pub struct TopologyAnalysis {
62    pub topology: Topology,
63    pub strategy: DispatchStrategy,
64    pub max_parallel: usize,
65    /// Longest path in the DAG (critical path length).
66    pub depth: usize,
67    /// Per-task depth from root (BFS level). Used by `LevelBarrier` dispatch.
68    ///
69    /// Uses `HashMap` so new tasks injected via `inject_tasks()` can be
70    /// added without index-out-of-bounds on `Vec` access (critic S3).
71    pub depths: HashMap<TaskId, usize>,
72}
73
/// Zero-sized, stateless entry point for DAG topology classification.
///
/// All functionality lives in associated functions, so no instance is ever needed.
pub struct TopologyClassifier;
76
77impl TopologyClassifier {
78    /// Classify the topology of a `TaskGraph`.
79    ///
80    /// Empty graphs return `AllParallel` (no constraints).
81    ///
82    /// Calls `compute_longest_path_and_depths` once and delegates to
83    /// [`classify_with_depths`]. Use [`classify_with_depths`] directly when
84    /// depths have already been computed to avoid redundant work.
85    #[must_use]
86    pub fn classify(graph: &TaskGraph) -> Topology {
87        let tasks = &graph.tasks;
88        if tasks.is_empty() {
89            return Topology::AllParallel;
90        }
91        // Early exit for no-edge graphs avoids the toposort in classify_with_depths.
92        let edge_count: usize = tasks.iter().map(|t| t.depends_on.len()).sum();
93        if edge_count == 0 {
94            return Topology::AllParallel;
95        }
96        let (longest, depths) = compute_longest_path_and_depths(tasks);
97        Self::classify_with_depths(graph, longest, &depths)
98    }
99
100    /// Classify the topology of a `TaskGraph` using pre-computed depth values.
101    ///
102    /// Accepts the `longest_path` and per-task `depths` map produced by a prior
103    /// call to `compute_longest_path_and_depths` (or `compute_depths_for_scheduler`),
104    /// avoiding a redundant toposort pass when those values are already available.
105    #[must_use]
106    pub fn classify_with_depths(
107        graph: &TaskGraph,
108        longest_path: usize,
109        // NOTE: depths is reserved for future heuristics (e.g. per-level density).
110        // Classification currently only needs longest_path and structural edge counts.
111        _depths: &HashMap<TaskId, usize>,
112    ) -> Topology {
113        let tasks = &graph.tasks;
114        let n = tasks.len();
115
116        if n == 0 {
117            return Topology::AllParallel;
118        }
119
120        let edge_count: usize = tasks.iter().map(|t| t.depends_on.len()).sum();
121
122        if edge_count == 0 {
123            return Topology::AllParallel;
124        }
125
126        // Linear chain: exactly n-1 edges and longest path = n-1.
127        // depths is already computed — check longest_path directly instead of re-computing.
128        if edge_count == n - 1 && longest_path == n - 1 {
129            return Topology::LinearChain;
130        }
131
132        let roots_count = tasks.iter().filter(|t| t.depends_on.is_empty()).count();
133
134        // Fan-out: single root, max depth == 1 (root + one layer of leaves only).
135        if roots_count == 1 && longest_path == 1 {
136            return Topology::FanOut;
137        }
138
139        // FanIn: multiple roots converge to exactly one sink.
140        // The sink has >= 2 dependencies (dep_count >= 2). All other nodes are roots.
141        // Depth must be exactly 1.
142        let non_roots_count = tasks.iter().filter(|t| !t.depends_on.is_empty()).count();
143        if roots_count >= 2 && non_roots_count == 1 && longest_path == 1 {
144            let sink_dep_count = tasks
145                .iter()
146                .filter(|t| !t.depends_on.is_empty())
147                .map(|t| t.depends_on.len())
148                .next()
149                .unwrap_or(0);
150            if sink_dep_count >= 2 {
151                return Topology::FanIn;
152            }
153        }
154
155        // Hierarchical: single root, depth >= 2, max in-degree (dep_count) == 1 for all nodes
156        // (tree-like: no node has multiple parents — ensures no diamond patterns).
157        if roots_count == 1 && longest_path >= 2 {
158            let max_dep_count = tasks.iter().map(|t| t.depends_on.len()).max().unwrap_or(0);
159            if max_dep_count <= 1 {
160                return Topology::Hierarchical;
161            }
162        }
163
164        Topology::Mixed
165    }
166
167    /// Compute the effective `max_parallel` for a given topology and configured base.
168    ///
169    /// Encapsulates the topology-to-parallelism policy in one place so that
170    /// `analyze()` and the `tick()` dirty-reanalysis path use identical logic.
171    ///
172    /// `base` must be the immutable config value (`config.max_parallel`), never a
173    /// previously reduced `self.max_parallel`, to prevent drift across replan cycles.
174    #[must_use]
175    pub fn compute_max_parallel(topology: Topology, base: usize) -> usize {
176        match topology {
177            Topology::AllParallel | Topology::FanOut | Topology::FanIn | Topology::Hierarchical => {
178                base
179            }
180            Topology::LinearChain => 1,
181            Topology::Mixed => (base / 2 + 1).min(base).max(1),
182        }
183    }
184
185    /// Map a `Topology` variant to the appropriate `DispatchStrategy`.
186    #[must_use]
187    pub fn strategy(topology: Topology) -> DispatchStrategy {
188        match topology {
189            Topology::AllParallel | Topology::FanOut | Topology::FanIn => {
190                DispatchStrategy::FullParallel
191            }
192            Topology::LinearChain => DispatchStrategy::Sequential,
193            Topology::Hierarchical => DispatchStrategy::LevelBarrier,
194            Topology::Mixed => DispatchStrategy::Adaptive,
195        }
196    }
197
198    /// Compute a complete `TopologyAnalysis` in a single O(|V|+|E|) pass.
199    ///
200    /// When `topology_selection` is disabled in config, returns a default
201    /// `FullParallel` analysis with config's `max_parallel` — zero overhead.
202    ///
203    /// # Performance
204    ///
205    /// Uses a single Kahn's toposort pass to compute both topology classification
206    /// and per-task depths simultaneously.
207    #[must_use]
208    pub fn analyze(graph: &TaskGraph, config: &OrchestrationConfig) -> TopologyAnalysis {
209        let tasks = &graph.tasks;
210        let n = tasks.len();
211
212        if !config.topology_selection || n == 0 {
213            return TopologyAnalysis {
214                topology: Topology::AllParallel,
215                strategy: DispatchStrategy::FullParallel,
216                max_parallel: config.max_parallel as usize,
217                depth: 0,
218                depths: HashMap::new(),
219            };
220        }
221
222        let (longest, depths) = compute_longest_path_and_depths(tasks);
223        let topology = Self::classify_with_depths(graph, longest, &depths);
224        let strategy = Self::strategy(topology);
225        let base = config.max_parallel as usize;
226        let max_parallel = Self::compute_max_parallel(topology, base);
227
228        TopologyAnalysis {
229            topology,
230            strategy,
231            max_parallel,
232            depth: longest,
233            depths,
234        }
235    }
236}
237
238/// Compute depths for the scheduler's dirty re-analysis path.
239///
240/// Thin wrapper around `compute_longest_path_and_depths` for use by `DagScheduler::tick()`
241/// when `topology_dirty=true`.
242pub(crate) fn compute_depths_for_scheduler(
243    graph: &TaskGraph,
244) -> (usize, std::collections::HashMap<TaskId, usize>) {
245    compute_longest_path_and_depths(&graph.tasks)
246}
247
248/// Compute the longest path and per-task depth map using Kahn's toposort.
249///
250/// Returns `(longest_path, depths_map)` where `depths_map[task_id] = depth_from_root`.
251///
252/// Single O(|V|+|E|) pass. Assumes a validated DAG (no cycles).
253fn compute_longest_path_and_depths(tasks: &[TaskNode]) -> (usize, HashMap<TaskId, usize>) {
254    let n = tasks.len();
255    if n == 0 {
256        return (0, HashMap::new());
257    }
258
259    let mut in_degree = vec![0usize; n];
260    let mut dependents: Vec<Vec<usize>> = vec![Vec::new(); n];
261    for task in tasks {
262        let i = task.id.index();
263        in_degree[i] = task.depends_on.len();
264        for dep in &task.depends_on {
265            dependents[dep.index()].push(i);
266        }
267    }
268
269    let mut queue: std::collections::VecDeque<usize> = in_degree
270        .iter()
271        .enumerate()
272        .filter(|(_, d)| **d == 0)
273        .map(|(i, _)| i)
274        .collect();
275
276    let mut dist = vec![0usize; n];
277    let mut max_dist = 0usize;
278
279    while let Some(u) = queue.pop_front() {
280        for &v in &dependents[u] {
281            let new_dist = dist[u] + 1;
282            if new_dist > dist[v] {
283                dist[v] = new_dist;
284            }
285            if dist[v] > max_dist {
286                max_dist = dist[v];
287            }
288            in_degree[v] -= 1;
289            if in_degree[v] == 0 {
290                queue.push_back(v);
291            }
292        }
293    }
294
295    let depths: HashMap<TaskId, usize> = tasks.iter().map(|t| (t.id, dist[t.id.index()])).collect();
296
297    (max_dist, depths)
298}
299
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::{TaskGraph, TaskId, TaskNode};

    /// Build a task node `t{id}` that depends on the tasks listed in `deps`.
    fn make_node(id: u32, deps: &[u32]) -> TaskNode {
        let mut n = TaskNode::new(id, format!("t{id}"), "desc");
        n.depends_on = deps.iter().map(|&d| TaskId(d)).collect();
        n
    }

    /// Wrap `nodes` in a fresh `TaskGraph` named "test".
    fn graph_from(nodes: Vec<TaskNode>) -> TaskGraph {
        let mut g = TaskGraph::new("test");
        g.tasks = nodes;
        g
    }

    /// Orchestration config with topology selection enabled and `max_parallel = 4`.
    fn default_config() -> zeph_config::OrchestrationConfig {
        zeph_config::OrchestrationConfig {
            topology_selection: true,
            max_parallel: 4,
            ..zeph_config::OrchestrationConfig::default()
        }
    }

    // --- classify tests ---

    #[test]
    fn classify_empty_graph() {
        let g = graph_from(vec![]);
        assert_eq!(TopologyClassifier::classify(&g), Topology::AllParallel);
    }

    #[test]
    fn classify_single_task() {
        let g = graph_from(vec![make_node(0, &[])]);
        assert_eq!(TopologyClassifier::classify(&g), Topology::AllParallel);
    }

    #[test]
    fn classify_all_parallel() {
        let g = graph_from(vec![
            make_node(0, &[]),
            make_node(1, &[]),
            make_node(2, &[]),
        ]);
        assert_eq!(TopologyClassifier::classify(&g), Topology::AllParallel);
    }

    #[test]
    fn classify_two_task_chain() {
        // A(0) -> B(1)
        let g = graph_from(vec![make_node(0, &[]), make_node(1, &[0])]);
        assert_eq!(TopologyClassifier::classify(&g), Topology::LinearChain);
    }

    #[test]
    fn classify_linear_chain() {
        // A(0) -> B(1) -> C(2)
        let g = graph_from(vec![
            make_node(0, &[]),
            make_node(1, &[0]),
            make_node(2, &[1]),
        ]);
        assert_eq!(TopologyClassifier::classify(&g), Topology::LinearChain);
    }

    #[test]
    fn classify_fan_out() {
        // A(0) -> {B(1), C(2), D(3)}
        let g = graph_from(vec![
            make_node(0, &[]),
            make_node(1, &[0]),
            make_node(2, &[0]),
            make_node(3, &[0]),
        ]);
        assert_eq!(TopologyClassifier::classify(&g), Topology::FanOut);
    }

    #[test]
    fn classify_fan_in() {
        // A(0), B(1), C(2) all -> D(3): multiple roots, single sink
        let g = graph_from(vec![
            make_node(0, &[]),
            make_node(1, &[]),
            make_node(2, &[]),
            make_node(3, &[0, 1, 2]),
        ]);
        assert_eq!(TopologyClassifier::classify(&g), Topology::FanIn);
    }

    #[test]
    fn classify_fan_in_two_roots() {
        // A(0), B(1) -> C(2)
        let g = graph_from(vec![
            make_node(0, &[]),
            make_node(1, &[]),
            make_node(2, &[0, 1]),
        ]);
        assert_eq!(TopologyClassifier::classify(&g), Topology::FanIn);
    }

    #[test]
    fn classify_two_roots_single_dependency_sink_is_mixed() {
        // A(0) -> C(2), B(1) isolated: the only non-root has a single dependency,
        // so this is neither FanIn (needs >= 2 deps) nor FanOut/Hierarchical (two roots).
        let g = graph_from(vec![
            make_node(0, &[]),
            make_node(1, &[]),
            make_node(2, &[0]),
        ]);
        assert_eq!(TopologyClassifier::classify(&g), Topology::Mixed);
    }

    #[test]
    fn classify_hierarchical() {
        // A(0) -> {B(1), C(2)}, B(1) -> D(3), C(2) -> E(4)
        // Single root, depth=2, max in-degree=1 for non-roots
        let g = graph_from(vec![
            make_node(0, &[]),
            make_node(1, &[0]),
            make_node(2, &[0]),
            make_node(3, &[1]),
            make_node(4, &[2]),
        ]);
        assert_eq!(TopologyClassifier::classify(&g), Topology::Hierarchical);
    }

    #[test]
    fn classify_hierarchical_three_levels() {
        // A(0) -> {B(1), C(2)}, B(1) -> D(3) -> E(4)
        // n=5 with 4 edges but longest path 3 (not 4), so it is not a linear chain.
        // Single root, depth=3, every in-degree <= 1 => Hierarchical.
        let g = graph_from(vec![
            make_node(0, &[]),
            make_node(1, &[0]),
            make_node(2, &[0]),
            make_node(3, &[1]),
            make_node(4, &[3]),
        ]);
        assert_eq!(TopologyClassifier::classify(&g), Topology::Hierarchical);
    }

    #[test]
    fn classify_diamond_is_mixed() {
        // A(0) -> {B(1), C(2)} -> D(3)
        let g = graph_from(vec![
            make_node(0, &[]),
            make_node(1, &[0]),
            make_node(2, &[0]),
            make_node(3, &[1, 2]),
        ]);
        assert_eq!(TopologyClassifier::classify(&g), Topology::Mixed);
    }

    #[test]
    fn classify_fan_out_with_chain_on_branch_is_hierarchical() {
        // A(0) -> {B(1), C(2)}, B(1) -> D(3) — single root, depth=2, all in-degrees <= 1 → Hierarchical
        let g = graph_from(vec![
            make_node(0, &[]),
            make_node(1, &[0]),
            make_node(2, &[0]),
            make_node(3, &[1]),
        ]);
        assert_eq!(TopologyClassifier::classify(&g), Topology::Hierarchical);
    }

    // --- strategy tests ---

    #[test]
    fn strategy_all_parallel_is_full_parallel() {
        assert_eq!(
            TopologyClassifier::strategy(Topology::AllParallel),
            DispatchStrategy::FullParallel
        );
    }

    #[test]
    fn strategy_fan_out_is_full_parallel() {
        assert_eq!(
            TopologyClassifier::strategy(Topology::FanOut),
            DispatchStrategy::FullParallel
        );
    }

    #[test]
    fn strategy_fan_in_is_full_parallel() {
        assert_eq!(
            TopologyClassifier::strategy(Topology::FanIn),
            DispatchStrategy::FullParallel
        );
    }

    #[test]
    fn strategy_linear_chain_is_sequential() {
        assert_eq!(
            TopologyClassifier::strategy(Topology::LinearChain),
            DispatchStrategy::Sequential
        );
    }

    #[test]
    fn strategy_hierarchical_is_level_barrier() {
        assert_eq!(
            TopologyClassifier::strategy(Topology::Hierarchical),
            DispatchStrategy::LevelBarrier
        );
    }

    #[test]
    fn strategy_mixed_is_adaptive() {
        assert_eq!(
            TopologyClassifier::strategy(Topology::Mixed),
            DispatchStrategy::Adaptive
        );
    }

    // --- analyze tests ---

    #[test]
    fn analyze_disabled_returns_full_parallel() {
        let mut cfg = default_config();
        cfg.topology_selection = false;
        let g = graph_from(vec![make_node(0, &[]), make_node(1, &[0])]);
        let analysis = TopologyClassifier::analyze(&g, &cfg);
        assert_eq!(analysis.strategy, DispatchStrategy::FullParallel);
        assert_eq!(analysis.max_parallel, 4);
        assert_eq!(analysis.topology, Topology::AllParallel);
    }

    #[test]
    fn analyze_empty_graph_returns_defaults() {
        let cfg = default_config();
        let g = graph_from(vec![]);
        let analysis = TopologyClassifier::analyze(&g, &cfg);
        assert_eq!(analysis.topology, Topology::AllParallel);
        assert_eq!(analysis.strategy, DispatchStrategy::FullParallel);
        assert_eq!(analysis.max_parallel, 4);
        assert_eq!(analysis.depth, 0);
        assert!(analysis.depths.is_empty());
    }

    #[test]
    fn analyze_linear_chain_returns_sequential() {
        let cfg = default_config();
        let g = graph_from(vec![
            make_node(0, &[]),
            make_node(1, &[0]),
            make_node(2, &[1]),
        ]);
        let analysis = TopologyClassifier::analyze(&g, &cfg);
        assert_eq!(analysis.topology, Topology::LinearChain);
        assert_eq!(analysis.strategy, DispatchStrategy::Sequential);
        assert_eq!(analysis.max_parallel, 1);
        assert_eq!(analysis.depth, 2);
    }

    #[test]
    fn analyze_hierarchical_returns_level_barrier() {
        let cfg = default_config();
        // A(0) -> {B(1), C(2)}, B(1) -> D(3)
        let g = graph_from(vec![
            make_node(0, &[]),
            make_node(1, &[0]),
            make_node(2, &[0]),
            make_node(3, &[1]),
        ]);
        let analysis = TopologyClassifier::analyze(&g, &cfg);
        assert_eq!(analysis.topology, Topology::Hierarchical);
        assert_eq!(analysis.strategy, DispatchStrategy::LevelBarrier);
        assert_eq!(analysis.max_parallel, 4);
        assert_eq!(analysis.depth, 2);
        // Verify depths
        assert_eq!(analysis.depths[&TaskId(0)], 0);
        assert_eq!(analysis.depths[&TaskId(1)], 1);
        assert_eq!(analysis.depths[&TaskId(2)], 1);
        assert_eq!(analysis.depths[&TaskId(3)], 2);
    }

    #[test]
    fn analyze_fan_in_returns_full_parallel() {
        let cfg = default_config();
        // A(0), B(1), C(2) -> D(3)
        let g = graph_from(vec![
            make_node(0, &[]),
            make_node(1, &[]),
            make_node(2, &[]),
            make_node(3, &[0, 1, 2]),
        ]);
        let analysis = TopologyClassifier::analyze(&g, &cfg);
        assert_eq!(analysis.topology, Topology::FanIn);
        assert_eq!(analysis.strategy, DispatchStrategy::FullParallel);
        assert_eq!(analysis.max_parallel, 4);
    }

    #[test]
    fn analyze_mixed_is_conservative() {
        let cfg = default_config(); // max_parallel=4 -> (4/2+1).min(4).max(1) = 3
        let g = graph_from(vec![
            make_node(0, &[]),
            make_node(1, &[0]),
            make_node(2, &[0]),
            make_node(3, &[1, 2]),
        ]);
        let analysis = TopologyClassifier::analyze(&g, &cfg);
        assert_eq!(analysis.topology, Topology::Mixed);
        assert_eq!(analysis.strategy, DispatchStrategy::Adaptive);
        assert_eq!(analysis.max_parallel, 3);
    }

    #[test]
    fn analyze_depths_correct_for_fan_out() {
        let cfg = default_config();
        // A(0) -> {B(1), C(2), D(3)}
        let g = graph_from(vec![
            make_node(0, &[]),
            make_node(1, &[0]),
            make_node(2, &[0]),
            make_node(3, &[0]),
        ]);
        let analysis = TopologyClassifier::analyze(&g, &cfg);
        assert_eq!(analysis.depths[&TaskId(0)], 0);
        assert_eq!(analysis.depths[&TaskId(1)], 1);
        assert_eq!(analysis.depths[&TaskId(2)], 1);
        assert_eq!(analysis.depths[&TaskId(3)], 1);
    }

    #[test]
    fn analyze_depths_are_longest_path_not_bfs() {
        let cfg = default_config();
        // A(0) -> B(1) -> C(2), plus the shortcut A(0) -> C(2).
        // BFS level of C would be 1, but the longest dependency chain to C has length 2.
        let g = graph_from(vec![
            make_node(0, &[]),
            make_node(1, &[0]),
            make_node(2, &[0, 1]),
        ]);
        let analysis = TopologyClassifier::analyze(&g, &cfg);
        assert_eq!(analysis.topology, Topology::Mixed);
        assert_eq!(analysis.depth, 2);
        assert_eq!(analysis.depths[&TaskId(0)], 0);
        assert_eq!(analysis.depths[&TaskId(1)], 1);
        assert_eq!(analysis.depths[&TaskId(2)], 2);
    }

    #[test]
    fn analyze_mixed_respects_max_parallel_one() {
        let mut cfg = default_config();
        cfg.max_parallel = 1;
        let g = graph_from(vec![
            make_node(0, &[]),
            make_node(1, &[0]),
            make_node(2, &[0]),
            make_node(3, &[1, 2]),
        ]);
        let analysis = TopologyClassifier::analyze(&g, &cfg);
        assert_eq!(analysis.max_parallel, 1);
    }

    // --- depth helper tests ---

    #[test]
    fn compute_depths_for_scheduler_matches_analyze() {
        let cfg = default_config();
        // A(0) -> {B(1), C(2)}, B(1) -> D(3)
        let g = graph_from(vec![
            make_node(0, &[]),
            make_node(1, &[0]),
            make_node(2, &[0]),
            make_node(3, &[1]),
        ]);
        let (longest, depths) = compute_depths_for_scheduler(&g);
        let analysis = TopologyClassifier::analyze(&g, &cfg);
        assert_eq!(longest, analysis.depth);
        assert_eq!(depths, analysis.depths);
    }

    // --- classify_with_depths tests ---

    #[test]
    fn classify_with_depths_matches_classify_for_all_variants() {
        let graphs = vec![
            // Empty
            graph_from(vec![]),
            // AllParallel
            graph_from(vec![
                make_node(0, &[]),
                make_node(1, &[]),
                make_node(2, &[]),
            ]),
            // LinearChain
            graph_from(vec![
                make_node(0, &[]),
                make_node(1, &[0]),
                make_node(2, &[1]),
            ]),
            // FanOut
            graph_from(vec![
                make_node(0, &[]),
                make_node(1, &[0]),
                make_node(2, &[0]),
                make_node(3, &[0]),
            ]),
            // FanIn
            graph_from(vec![
                make_node(0, &[]),
                make_node(1, &[]),
                make_node(2, &[]),
                make_node(3, &[0, 1, 2]),
            ]),
            // Hierarchical
            graph_from(vec![
                make_node(0, &[]),
                make_node(1, &[0]),
                make_node(2, &[0]),
                make_node(3, &[1]),
            ]),
            // Hierarchical (three levels)
            graph_from(vec![
                make_node(0, &[]),
                make_node(1, &[0]),
                make_node(2, &[0]),
                make_node(3, &[1]),
                make_node(4, &[3]),
            ]),
            // Mixed (diamond)
            graph_from(vec![
                make_node(0, &[]),
                make_node(1, &[0]),
                make_node(2, &[0]),
                make_node(3, &[1, 2]),
            ]),
            // Mixed (two roots, single-dependency sink)
            graph_from(vec![
                make_node(0, &[]),
                make_node(1, &[]),
                make_node(2, &[0]),
            ]),
        ];

        for g in &graphs {
            let expected = TopologyClassifier::classify(g);
            // Compute depths exactly as analyze() does, then classify from them.
            let (longest, depths) = compute_longest_path_and_depths(&g.tasks);
            let actual = TopologyClassifier::classify_with_depths(g, longest, &depths);
            assert_eq!(
                actual,
                expected,
                "classify_with_depths mismatch for graph with {} tasks",
                g.tasks.len()
            );
        }
    }

    // --- compute_max_parallel tests ---

    #[test]
    fn compute_max_parallel_all_parallel_returns_base() {
        assert_eq!(
            TopologyClassifier::compute_max_parallel(Topology::AllParallel, 8),
            8
        );
    }

    #[test]
    fn compute_max_parallel_fan_out_returns_base() {
        assert_eq!(
            TopologyClassifier::compute_max_parallel(Topology::FanOut, 6),
            6
        );
    }

    #[test]
    fn compute_max_parallel_fan_in_returns_base() {
        assert_eq!(
            TopologyClassifier::compute_max_parallel(Topology::FanIn, 4),
            4
        );
    }

    #[test]
    fn compute_max_parallel_hierarchical_returns_base() {
        assert_eq!(
            TopologyClassifier::compute_max_parallel(Topology::Hierarchical, 10),
            10
        );
    }

    #[test]
    fn compute_max_parallel_linear_chain_returns_one() {
        assert_eq!(
            TopologyClassifier::compute_max_parallel(Topology::LinearChain, 8),
            1
        );
        assert_eq!(
            TopologyClassifier::compute_max_parallel(Topology::LinearChain, 1),
            1
        );
    }

    #[test]
    fn compute_max_parallel_mixed_is_half_plus_one() {
        // base=4: (4/2+1).min(4).max(1) = 3
        assert_eq!(
            TopologyClassifier::compute_max_parallel(Topology::Mixed, 4),
            3
        );
        // base=2: (2/2+1).min(2).max(1) = 2
        assert_eq!(
            TopologyClassifier::compute_max_parallel(Topology::Mixed, 2),
            2
        );
        // base=1: (1/2+1).min(1).max(1) = 1
        assert_eq!(
            TopologyClassifier::compute_max_parallel(Topology::Mixed, 1),
            1
        );
        // base=8: (8/2+1).min(8).max(1) = 5
        assert_eq!(
            TopologyClassifier::compute_max_parallel(Topology::Mixed, 8),
            5
        );
    }

    #[test]
    fn compute_max_parallel_mixed_never_below_one() {
        // base=0: (0/2+1).min(0) = 0, then .max(1) = 1
        assert_eq!(
            TopologyClassifier::compute_max_parallel(Topology::Mixed, 0),
            1
        );
    }
}