1use std::collections::HashMap;
7
8use serde::{Deserialize, Serialize};
9use zeph_config::OrchestrationConfig;
10
11use super::graph::{TaskGraph, TaskId, TaskNode};
12
13#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
15#[serde(rename_all = "snake_case")]
16pub enum Topology {
17 AllParallel,
19 LinearChain,
21 FanOut,
23 FanIn,
28 Hierarchical,
32 Mixed,
34}
35
36#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
38#[serde(rename_all = "snake_case")]
39pub enum DispatchStrategy {
40 FullParallel,
44 Sequential,
48 LevelBarrier,
52 Adaptive,
57 TreeOptimized,
62 CascadeAware,
66}
67
68#[derive(Debug, Clone)]
70pub struct TopologyAnalysis {
71 pub topology: Topology,
72 pub strategy: DispatchStrategy,
73 pub max_parallel: usize,
74 pub depth: usize,
76 pub depths: HashMap<TaskId, usize>,
81}
82
83pub struct TopologyClassifier;
85
86impl TopologyClassifier {
87 #[must_use]
95 pub fn classify(graph: &TaskGraph) -> Topology {
96 let tasks = &graph.tasks;
97 if tasks.is_empty() {
98 return Topology::AllParallel;
99 }
100 let edge_count: usize = tasks.iter().map(|t| t.depends_on.len()).sum();
102 if edge_count == 0 {
103 return Topology::AllParallel;
104 }
105 let (longest, depths) = compute_longest_path_and_depths(tasks);
106 Self::classify_with_depths(graph, longest, &depths)
107 }
108
109 #[must_use]
115 pub fn classify_with_depths(
116 graph: &TaskGraph,
117 longest_path: usize,
118 _depths: &HashMap<TaskId, usize>,
121 ) -> Topology {
122 let tasks = &graph.tasks;
123 let n = tasks.len();
124
125 if n == 0 {
126 return Topology::AllParallel;
127 }
128
129 let edge_count: usize = tasks.iter().map(|t| t.depends_on.len()).sum();
130
131 if edge_count == 0 {
132 return Topology::AllParallel;
133 }
134
135 if edge_count == n - 1 && longest_path == n - 1 {
138 return Topology::LinearChain;
139 }
140
141 let roots_count = tasks.iter().filter(|t| t.depends_on.is_empty()).count();
142
143 if roots_count == 1 && longest_path == 1 {
145 return Topology::FanOut;
146 }
147
148 let non_roots_count = tasks.iter().filter(|t| !t.depends_on.is_empty()).count();
152 if roots_count >= 2 && non_roots_count == 1 && longest_path == 1 {
153 let sink_dep_count = tasks
154 .iter()
155 .filter(|t| !t.depends_on.is_empty())
156 .map(|t| t.depends_on.len())
157 .next()
158 .unwrap_or(0);
159 if sink_dep_count >= 2 {
160 return Topology::FanIn;
161 }
162 }
163
164 if roots_count == 1 && longest_path >= 2 {
167 let max_dep_count = tasks.iter().map(|t| t.depends_on.len()).max().unwrap_or(0);
168 if max_dep_count <= 1 {
169 return Topology::Hierarchical;
170 }
171 }
172
173 Topology::Mixed
174 }
175
176 #[must_use]
184 pub fn compute_max_parallel(topology: Topology, base: usize) -> usize {
185 match topology {
186 Topology::AllParallel | Topology::FanOut | Topology::FanIn | Topology::Hierarchical => {
187 base
188 }
189 Topology::LinearChain => 1,
190 Topology::Mixed => (base / 2 + 1).min(base).max(1),
191 }
192 }
193
194 #[must_use]
202 pub fn strategy(topology: Topology, config: &OrchestrationConfig) -> DispatchStrategy {
203 match topology {
204 Topology::FanOut | Topology::FanIn if config.tree_optimized_dispatch => {
205 DispatchStrategy::TreeOptimized
206 }
207 Topology::Mixed if config.cascade_routing => DispatchStrategy::CascadeAware,
208 Topology::AllParallel | Topology::FanOut | Topology::FanIn => {
209 DispatchStrategy::FullParallel
210 }
211 Topology::LinearChain => DispatchStrategy::Sequential,
212 Topology::Hierarchical => DispatchStrategy::LevelBarrier,
213 Topology::Mixed => DispatchStrategy::Adaptive,
214 }
215 }
216
217 #[must_use]
227 pub fn analyze(graph: &TaskGraph, config: &OrchestrationConfig) -> TopologyAnalysis {
228 let tasks = &graph.tasks;
229 let n = tasks.len();
230
231 if !config.topology_selection || n == 0 {
232 return TopologyAnalysis {
233 topology: Topology::AllParallel,
234 strategy: DispatchStrategy::FullParallel,
235 max_parallel: config.max_parallel as usize,
236 depth: 0,
237 depths: HashMap::new(),
238 };
239 }
240
241 let (longest, depths) = compute_longest_path_and_depths(tasks);
242 let topology = Self::classify_with_depths(graph, longest, &depths);
243 let strategy = Self::strategy(topology, config);
244 let base = config.max_parallel as usize;
245 let max_parallel = Self::compute_max_parallel(topology, base);
246
247 TopologyAnalysis {
248 topology,
249 strategy,
250 max_parallel,
251 depth: longest,
252 depths,
253 }
254 }
255}
256
257pub(crate) fn compute_depths_for_scheduler(
262 graph: &TaskGraph,
263) -> (usize, std::collections::HashMap<TaskId, usize>) {
264 compute_longest_path_and_depths(&graph.tasks)
265}
266
267fn compute_longest_path_and_depths(tasks: &[TaskNode]) -> (usize, HashMap<TaskId, usize>) {
273 let n = tasks.len();
274 if n == 0 {
275 return (0, HashMap::new());
276 }
277
278 let mut in_degree = vec![0usize; n];
279 let mut dependents: Vec<Vec<usize>> = vec![Vec::new(); n];
280 for task in tasks {
281 let i = task.id.index();
282 in_degree[i] = task.depends_on.len();
283 for dep in &task.depends_on {
284 dependents[dep.index()].push(i);
285 }
286 }
287
288 let mut queue: std::collections::VecDeque<usize> = in_degree
289 .iter()
290 .enumerate()
291 .filter(|(_, d)| **d == 0)
292 .map(|(i, _)| i)
293 .collect();
294
295 let mut dist = vec![0usize; n];
296 let mut max_dist = 0usize;
297
298 while let Some(u) = queue.pop_front() {
299 for &v in &dependents[u] {
300 let new_dist = dist[u] + 1;
301 if new_dist > dist[v] {
302 dist[v] = new_dist;
303 }
304 if dist[v] > max_dist {
305 max_dist = dist[v];
306 }
307 in_degree[v] -= 1;
308 if in_degree[v] == 0 {
309 queue.push_back(v);
310 }
311 }
312 }
313
314 let depths: HashMap<TaskId, usize> = tasks.iter().map(|t| (t.id, dist[t.id.index()])).collect();
315
316 (max_dist, depths)
317}
318
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::{TaskGraph, TaskId, TaskNode};

    /// Builds task `id` depending on the tasks listed in `deps`.
    fn make_node(id: u32, deps: &[u32]) -> TaskNode {
        let mut n = TaskNode::new(id, format!("t{id}"), "desc");
        n.depends_on = deps.iter().map(|&d| TaskId(d)).collect();
        n
    }

    /// Wraps `nodes` in a fresh graph.
    fn graph_from(nodes: Vec<TaskNode>) -> TaskGraph {
        let mut g = TaskGraph::new("test");
        g.tasks = nodes;
        g
    }

    /// Topology selection enabled, `max_parallel = 4`, everything else default.
    fn default_config() -> zeph_config::OrchestrationConfig {
        zeph_config::OrchestrationConfig {
            topology_selection: true,
            max_parallel: 4,
            ..zeph_config::OrchestrationConfig::default()
        }
    }

    #[test]
    fn classify_empty_graph() {
        let g = graph_from(vec![]);
        assert_eq!(TopologyClassifier::classify(&g), Topology::AllParallel);
    }

    #[test]
    fn classify_single_task() {
        let g = graph_from(vec![make_node(0, &[])]);
        assert_eq!(TopologyClassifier::classify(&g), Topology::AllParallel);
    }

    #[test]
    fn classify_all_parallel() {
        let g = graph_from(vec![
            make_node(0, &[]),
            make_node(1, &[]),
            make_node(2, &[]),
        ]);
        assert_eq!(TopologyClassifier::classify(&g), Topology::AllParallel);
    }

    #[test]
    fn classify_two_task_chain() {
        let g = graph_from(vec![make_node(0, &[]), make_node(1, &[0])]);
        assert_eq!(TopologyClassifier::classify(&g), Topology::LinearChain);
    }

    #[test]
    fn classify_linear_chain() {
        let g = graph_from(vec![
            make_node(0, &[]),
            make_node(1, &[0]),
            make_node(2, &[1]),
        ]);
        assert_eq!(TopologyClassifier::classify(&g), Topology::LinearChain);
    }

    #[test]
    fn classify_fan_out() {
        let g = graph_from(vec![
            make_node(0, &[]),
            make_node(1, &[0]),
            make_node(2, &[0]),
            make_node(3, &[0]),
        ]);
        assert_eq!(TopologyClassifier::classify(&g), Topology::FanOut);
    }

    #[test]
    fn classify_fan_in() {
        let g = graph_from(vec![
            make_node(0, &[]),
            make_node(1, &[]),
            make_node(2, &[]),
            make_node(3, &[0, 1, 2]),
        ]);
        assert_eq!(TopologyClassifier::classify(&g), Topology::FanIn);
    }

    #[test]
    fn classify_fan_in_two_roots() {
        let g = graph_from(vec![
            make_node(0, &[]),
            make_node(1, &[]),
            make_node(2, &[0, 1]),
        ]);
        assert_eq!(TopologyClassifier::classify(&g), Topology::FanIn);
    }

    #[test]
    fn classify_sink_with_single_dependency_is_mixed() {
        // Two roots but the only sink waits on just one of them: not a fan-in.
        let g = graph_from(vec![
            make_node(0, &[]),
            make_node(1, &[]),
            make_node(2, &[0]),
        ]);
        assert_eq!(TopologyClassifier::classify(&g), Topology::Mixed);
    }

    #[test]
    fn classify_hierarchical() {
        let g = graph_from(vec![
            make_node(0, &[]),
            make_node(1, &[0]),
            make_node(2, &[0]),
            make_node(3, &[1]),
            make_node(4, &[2]),
        ]);
        assert_eq!(TopologyClassifier::classify(&g), Topology::Hierarchical);
    }

    #[test]
    fn classify_hierarchical_three_levels() {
        let g = graph_from(vec![
            make_node(0, &[]),
            make_node(1, &[0]),
            make_node(2, &[0]),
            make_node(3, &[1]),
        ]);
        assert_eq!(TopologyClassifier::classify(&g), Topology::Hierarchical);
    }

    #[test]
    fn classify_diamond_is_mixed() {
        let g = graph_from(vec![
            make_node(0, &[]),
            make_node(1, &[0]),
            make_node(2, &[0]),
            make_node(3, &[1, 2]),
        ]);
        assert_eq!(TopologyClassifier::classify(&g), Topology::Mixed);
    }

    #[test]
    fn classify_two_independent_chains_is_mixed() {
        // A forest with two roots is not classified as hierarchical.
        let g = graph_from(vec![
            make_node(0, &[]),
            make_node(1, &[0]),
            make_node(2, &[]),
            make_node(3, &[2]),
        ]);
        assert_eq!(TopologyClassifier::classify(&g), Topology::Mixed);
    }

    #[test]
    fn classify_fan_out_with_chain_on_branch_is_hierarchical() {
        let g = graph_from(vec![
            make_node(0, &[]),
            make_node(1, &[0]),
            make_node(2, &[0]),
            make_node(3, &[1]),
        ]);
        assert_eq!(TopologyClassifier::classify(&g), Topology::Hierarchical);
    }

    /// Topology selection enabled with both strategy upgrades switched off.
    fn no_overrides_config() -> zeph_config::OrchestrationConfig {
        zeph_config::OrchestrationConfig {
            topology_selection: true,
            max_parallel: 4,
            cascade_routing: false,
            tree_optimized_dispatch: false,
            ..zeph_config::OrchestrationConfig::default()
        }
    }

    #[test]
    fn strategy_all_parallel_is_full_parallel() {
        assert_eq!(
            TopologyClassifier::strategy(Topology::AllParallel, &no_overrides_config()),
            DispatchStrategy::FullParallel
        );
    }

    #[test]
    fn strategy_fan_out_is_full_parallel() {
        assert_eq!(
            TopologyClassifier::strategy(Topology::FanOut, &no_overrides_config()),
            DispatchStrategy::FullParallel
        );
    }

    #[test]
    fn strategy_fan_in_is_full_parallel() {
        assert_eq!(
            TopologyClassifier::strategy(Topology::FanIn, &no_overrides_config()),
            DispatchStrategy::FullParallel
        );
    }

    #[test]
    fn strategy_linear_chain_is_sequential() {
        assert_eq!(
            TopologyClassifier::strategy(Topology::LinearChain, &no_overrides_config()),
            DispatchStrategy::Sequential
        );
    }

    #[test]
    fn strategy_hierarchical_is_level_barrier() {
        assert_eq!(
            TopologyClassifier::strategy(Topology::Hierarchical, &no_overrides_config()),
            DispatchStrategy::LevelBarrier
        );
    }

    #[test]
    fn strategy_mixed_is_adaptive() {
        assert_eq!(
            TopologyClassifier::strategy(Topology::Mixed, &no_overrides_config()),
            DispatchStrategy::Adaptive
        );
    }

    #[test]
    fn strategy_fan_out_tree_optimized_when_enabled() {
        let mut cfg = no_overrides_config();
        cfg.tree_optimized_dispatch = true;
        assert_eq!(
            TopologyClassifier::strategy(Topology::FanOut, &cfg),
            DispatchStrategy::TreeOptimized
        );
        assert_eq!(
            TopologyClassifier::strategy(Topology::FanIn, &cfg),
            DispatchStrategy::TreeOptimized
        );
    }

    #[test]
    fn strategy_mixed_cascade_aware_when_enabled() {
        let mut cfg = no_overrides_config();
        cfg.cascade_routing = true;
        assert_eq!(
            TopologyClassifier::strategy(Topology::Mixed, &cfg),
            DispatchStrategy::CascadeAware
        );
    }

    #[test]
    fn strategy_tree_optimized_does_not_affect_non_fan_topologies() {
        let mut cfg = no_overrides_config();
        cfg.tree_optimized_dispatch = true;
        assert_eq!(
            TopologyClassifier::strategy(Topology::AllParallel, &cfg),
            DispatchStrategy::FullParallel
        );
        assert_eq!(
            TopologyClassifier::strategy(Topology::Hierarchical, &cfg),
            DispatchStrategy::LevelBarrier
        );
        assert_eq!(
            TopologyClassifier::strategy(Topology::LinearChain, &cfg),
            DispatchStrategy::Sequential
        );
        assert_eq!(
            TopologyClassifier::strategy(Topology::Mixed, &cfg),
            DispatchStrategy::Adaptive
        );
    }

    #[test]
    fn analyze_disabled_returns_full_parallel() {
        let mut cfg = default_config();
        cfg.topology_selection = false;
        let g = graph_from(vec![make_node(0, &[]), make_node(1, &[0])]);
        let analysis = TopologyClassifier::analyze(&g, &cfg);
        assert_eq!(analysis.strategy, DispatchStrategy::FullParallel);
        assert_eq!(analysis.max_parallel, 4);
        assert_eq!(analysis.topology, Topology::AllParallel);
    }

    #[test]
    fn analyze_empty_graph_returns_defaults() {
        let g = graph_from(vec![]);
        let analysis = TopologyClassifier::analyze(&g, &default_config());
        assert_eq!(analysis.topology, Topology::AllParallel);
        assert_eq!(analysis.strategy, DispatchStrategy::FullParallel);
        assert_eq!(analysis.max_parallel, 4);
        assert_eq!(analysis.depth, 0);
        assert!(analysis.depths.is_empty());
    }

    #[test]
    fn analyze_linear_chain_returns_sequential() {
        let cfg = default_config();
        let g = graph_from(vec![
            make_node(0, &[]),
            make_node(1, &[0]),
            make_node(2, &[1]),
        ]);
        let analysis = TopologyClassifier::analyze(&g, &cfg);
        assert_eq!(analysis.topology, Topology::LinearChain);
        assert_eq!(analysis.strategy, DispatchStrategy::Sequential);
        assert_eq!(analysis.max_parallel, 1);
        assert_eq!(analysis.depth, 2);
    }

    #[test]
    fn analyze_hierarchical_returns_level_barrier() {
        let cfg = default_config();
        let g = graph_from(vec![
            make_node(0, &[]),
            make_node(1, &[0]),
            make_node(2, &[0]),
            make_node(3, &[1]),
        ]);
        let analysis = TopologyClassifier::analyze(&g, &cfg);
        assert_eq!(analysis.topology, Topology::Hierarchical);
        assert_eq!(analysis.strategy, DispatchStrategy::LevelBarrier);
        assert_eq!(analysis.max_parallel, 4);
        assert_eq!(analysis.depth, 2);
        assert_eq!(analysis.depths[&TaskId(0)], 0);
        assert_eq!(analysis.depths[&TaskId(1)], 1);
        assert_eq!(analysis.depths[&TaskId(2)], 1);
        assert_eq!(analysis.depths[&TaskId(3)], 2);
    }

    #[test]
    fn analyze_fan_in_returns_full_parallel() {
        let cfg = default_config();
        let g = graph_from(vec![
            make_node(0, &[]),
            make_node(1, &[]),
            make_node(2, &[]),
            make_node(3, &[0, 1, 2]),
        ]);
        let analysis = TopologyClassifier::analyze(&g, &cfg);
        assert_eq!(analysis.topology, Topology::FanIn);
        assert_eq!(analysis.strategy, DispatchStrategy::FullParallel);
        assert_eq!(analysis.max_parallel, 4);
    }

    #[test]
    fn analyze_mixed_is_conservative() {
        let cfg = default_config();
        let g = graph_from(vec![
            make_node(0, &[]),
            make_node(1, &[0]),
            make_node(2, &[0]),
            make_node(3, &[1, 2]),
        ]);
        let analysis = TopologyClassifier::analyze(&g, &cfg);
        assert_eq!(analysis.topology, Topology::Mixed);
        assert_eq!(analysis.strategy, DispatchStrategy::Adaptive);
        assert_eq!(analysis.max_parallel, 3);
    }

    #[test]
    fn analyze_depths_correct_for_fan_out() {
        let cfg = default_config();
        let g = graph_from(vec![
            make_node(0, &[]),
            make_node(1, &[0]),
            make_node(2, &[0]),
            make_node(3, &[0]),
        ]);
        let analysis = TopologyClassifier::analyze(&g, &cfg);
        assert_eq!(analysis.depths[&TaskId(0)], 0);
        assert_eq!(analysis.depths[&TaskId(1)], 1);
        assert_eq!(analysis.depths[&TaskId(2)], 1);
        assert_eq!(analysis.depths[&TaskId(3)], 1);
    }

    #[test]
    fn analyze_mixed_respects_max_parallel_one() {
        let mut cfg = default_config();
        cfg.max_parallel = 1;
        let g = graph_from(vec![
            make_node(0, &[]),
            make_node(1, &[0]),
            make_node(2, &[0]),
            make_node(3, &[1, 2]),
        ]);
        let analysis = TopologyClassifier::analyze(&g, &cfg);
        assert_eq!(analysis.max_parallel, 1);
    }

    #[test]
    fn compute_depths_for_scheduler_matches_analyze() {
        let g = graph_from(vec![
            make_node(0, &[]),
            make_node(1, &[0]),
            make_node(2, &[0]),
            make_node(3, &[1]),
        ]);
        let (longest, depths) = compute_depths_for_scheduler(&g);
        let analysis = TopologyClassifier::analyze(&g, &default_config());
        assert_eq!(longest, analysis.depth);
        assert_eq!(depths, analysis.depths);
    }

    #[test]
    fn classify_with_depths_matches_classify_for_all_variants() {
        let graphs = vec![
            graph_from(vec![]),
            graph_from(vec![
                make_node(0, &[]),
                make_node(1, &[]),
                make_node(2, &[]),
            ]),
            graph_from(vec![
                make_node(0, &[]),
                make_node(1, &[0]),
                make_node(2, &[1]),
            ]),
            graph_from(vec![
                make_node(0, &[]),
                make_node(1, &[0]),
                make_node(2, &[0]),
                make_node(3, &[0]),
            ]),
            graph_from(vec![
                make_node(0, &[]),
                make_node(1, &[]),
                make_node(2, &[]),
                make_node(3, &[0, 1, 2]),
            ]),
            graph_from(vec![
                make_node(0, &[]),
                make_node(1, &[0]),
                make_node(2, &[0]),
                make_node(3, &[1]),
            ]),
            graph_from(vec![
                make_node(0, &[]),
                make_node(1, &[0]),
                make_node(2, &[0]),
                make_node(3, &[1, 2]),
            ]),
        ];

        // `analyze` already yields `(0, {})` for an empty graph, so no special case is needed.
        let cfg = default_config();
        for g in &graphs {
            let expected = TopologyClassifier::classify(g);
            let analysis = TopologyClassifier::analyze(g, &cfg);
            let actual =
                TopologyClassifier::classify_with_depths(g, analysis.depth, &analysis.depths);
            assert_eq!(
                actual,
                expected,
                "classify_with_depths mismatch for graph with {} tasks",
                g.tasks.len()
            );
        }
    }

    #[test]
    fn compute_max_parallel_all_parallel_returns_base() {
        assert_eq!(
            TopologyClassifier::compute_max_parallel(Topology::AllParallel, 8),
            8
        );
    }

    #[test]
    fn compute_max_parallel_fan_out_returns_base() {
        assert_eq!(
            TopologyClassifier::compute_max_parallel(Topology::FanOut, 6),
            6
        );
    }

    #[test]
    fn compute_max_parallel_fan_in_returns_base() {
        assert_eq!(
            TopologyClassifier::compute_max_parallel(Topology::FanIn, 4),
            4
        );
    }

    #[test]
    fn compute_max_parallel_hierarchical_returns_base() {
        assert_eq!(
            TopologyClassifier::compute_max_parallel(Topology::Hierarchical, 10),
            10
        );
    }

    #[test]
    fn compute_max_parallel_linear_chain_returns_one() {
        assert_eq!(
            TopologyClassifier::compute_max_parallel(Topology::LinearChain, 8),
            1
        );
        assert_eq!(
            TopologyClassifier::compute_max_parallel(Topology::LinearChain, 1),
            1
        );
    }

    #[test]
    fn compute_max_parallel_mixed_is_half_plus_one() {
        assert_eq!(
            TopologyClassifier::compute_max_parallel(Topology::Mixed, 4),
            3
        );
        assert_eq!(
            TopologyClassifier::compute_max_parallel(Topology::Mixed, 2),
            2
        );
        assert_eq!(
            TopologyClassifier::compute_max_parallel(Topology::Mixed, 1),
            1
        );
        assert_eq!(
            TopologyClassifier::compute_max_parallel(Topology::Mixed, 8),
            5
        );
    }
}