1use std::collections::HashMap;
7
8use serde::{Deserialize, Serialize};
9use zeph_config::OrchestrationConfig;
10
11use super::graph::{TaskGraph, TaskId, TaskNode};
12
13#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
15#[serde(rename_all = "snake_case")]
16pub enum Topology {
17 AllParallel,
19 LinearChain,
21 FanOut,
23 FanIn,
28 Hierarchical,
32 Mixed,
34}
35
36#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
38#[serde(rename_all = "snake_case")]
39pub enum DispatchStrategy {
40 FullParallel,
44 Sequential,
48 LevelBarrier,
52 Adaptive,
57}
58
59#[derive(Debug, Clone)]
61pub struct TopologyAnalysis {
62 pub topology: Topology,
63 pub strategy: DispatchStrategy,
64 pub max_parallel: usize,
65 pub depth: usize,
67 pub depths: HashMap<TaskId, usize>,
72}
73
/// Stateless namespace for the graph-topology heuristics
/// (`classify`, `classify_with_depths`, `strategy`, `compute_max_parallel`,
/// `analyze`).
///
/// It carries no data, so it derives the common standard traits (`Debug`,
/// `Clone`, `Copy`, `Default`) that public types are expected to provide.
#[derive(Debug, Clone, Copy, Default)]
pub struct TopologyClassifier;
76
77impl TopologyClassifier {
78 #[must_use]
86 pub fn classify(graph: &TaskGraph) -> Topology {
87 let tasks = &graph.tasks;
88 if tasks.is_empty() {
89 return Topology::AllParallel;
90 }
91 let edge_count: usize = tasks.iter().map(|t| t.depends_on.len()).sum();
93 if edge_count == 0 {
94 return Topology::AllParallel;
95 }
96 let (longest, depths) = compute_longest_path_and_depths(tasks);
97 Self::classify_with_depths(graph, longest, &depths)
98 }
99
100 #[must_use]
106 pub fn classify_with_depths(
107 graph: &TaskGraph,
108 longest_path: usize,
109 _depths: &HashMap<TaskId, usize>,
112 ) -> Topology {
113 let tasks = &graph.tasks;
114 let n = tasks.len();
115
116 if n == 0 {
117 return Topology::AllParallel;
118 }
119
120 let edge_count: usize = tasks.iter().map(|t| t.depends_on.len()).sum();
121
122 if edge_count == 0 {
123 return Topology::AllParallel;
124 }
125
126 if edge_count == n - 1 && longest_path == n - 1 {
129 return Topology::LinearChain;
130 }
131
132 let roots_count = tasks.iter().filter(|t| t.depends_on.is_empty()).count();
133
134 if roots_count == 1 && longest_path == 1 {
136 return Topology::FanOut;
137 }
138
139 let non_roots_count = tasks.iter().filter(|t| !t.depends_on.is_empty()).count();
143 if roots_count >= 2 && non_roots_count == 1 && longest_path == 1 {
144 let sink_dep_count = tasks
145 .iter()
146 .filter(|t| !t.depends_on.is_empty())
147 .map(|t| t.depends_on.len())
148 .next()
149 .unwrap_or(0);
150 if sink_dep_count >= 2 {
151 return Topology::FanIn;
152 }
153 }
154
155 if roots_count == 1 && longest_path >= 2 {
158 let max_dep_count = tasks.iter().map(|t| t.depends_on.len()).max().unwrap_or(0);
159 if max_dep_count <= 1 {
160 return Topology::Hierarchical;
161 }
162 }
163
164 Topology::Mixed
165 }
166
167 #[must_use]
175 pub fn compute_max_parallel(topology: Topology, base: usize) -> usize {
176 match topology {
177 Topology::AllParallel | Topology::FanOut | Topology::FanIn | Topology::Hierarchical => {
178 base
179 }
180 Topology::LinearChain => 1,
181 Topology::Mixed => (base / 2 + 1).min(base).max(1),
182 }
183 }
184
185 #[must_use]
187 pub fn strategy(topology: Topology) -> DispatchStrategy {
188 match topology {
189 Topology::AllParallel | Topology::FanOut | Topology::FanIn => {
190 DispatchStrategy::FullParallel
191 }
192 Topology::LinearChain => DispatchStrategy::Sequential,
193 Topology::Hierarchical => DispatchStrategy::LevelBarrier,
194 Topology::Mixed => DispatchStrategy::Adaptive,
195 }
196 }
197
198 #[must_use]
208 pub fn analyze(graph: &TaskGraph, config: &OrchestrationConfig) -> TopologyAnalysis {
209 let tasks = &graph.tasks;
210 let n = tasks.len();
211
212 if !config.topology_selection || n == 0 {
213 return TopologyAnalysis {
214 topology: Topology::AllParallel,
215 strategy: DispatchStrategy::FullParallel,
216 max_parallel: config.max_parallel as usize,
217 depth: 0,
218 depths: HashMap::new(),
219 };
220 }
221
222 let (longest, depths) = compute_longest_path_and_depths(tasks);
223 let topology = Self::classify_with_depths(graph, longest, &depths);
224 let strategy = Self::strategy(topology);
225 let base = config.max_parallel as usize;
226 let max_parallel = Self::compute_max_parallel(topology, base);
227
228 TopologyAnalysis {
229 topology,
230 strategy,
231 max_parallel,
232 depth: longest,
233 depths,
234 }
235 }
236}
237
238pub(crate) fn compute_depths_for_scheduler(
243 graph: &TaskGraph,
244) -> (usize, std::collections::HashMap<TaskId, usize>) {
245 compute_longest_path_and_depths(&graph.tasks)
246}
247
248fn compute_longest_path_and_depths(tasks: &[TaskNode]) -> (usize, HashMap<TaskId, usize>) {
254 let n = tasks.len();
255 if n == 0 {
256 return (0, HashMap::new());
257 }
258
259 let mut in_degree = vec![0usize; n];
260 let mut dependents: Vec<Vec<usize>> = vec![Vec::new(); n];
261 for task in tasks {
262 let i = task.id.index();
263 in_degree[i] = task.depends_on.len();
264 for dep in &task.depends_on {
265 dependents[dep.index()].push(i);
266 }
267 }
268
269 let mut queue: std::collections::VecDeque<usize> = in_degree
270 .iter()
271 .enumerate()
272 .filter(|(_, d)| **d == 0)
273 .map(|(i, _)| i)
274 .collect();
275
276 let mut dist = vec![0usize; n];
277 let mut max_dist = 0usize;
278
279 while let Some(u) = queue.pop_front() {
280 for &v in &dependents[u] {
281 let new_dist = dist[u] + 1;
282 if new_dist > dist[v] {
283 dist[v] = new_dist;
284 }
285 if dist[v] > max_dist {
286 max_dist = dist[v];
287 }
288 in_degree[v] -= 1;
289 if in_degree[v] == 0 {
290 queue.push_back(v);
291 }
292 }
293 }
294
295 let depths: HashMap<TaskId, usize> = tasks.iter().map(|t| (t.id, dist[t.id.index()])).collect();
296
297 (max_dist, depths)
298}
299
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::{TaskGraph, TaskId, TaskNode};

    /// Builds task `id` depending on every id in `deps`.
    fn make_node(id: u32, deps: &[u32]) -> TaskNode {
        let mut n = TaskNode::new(id, format!("t{id}"), "desc");
        n.depends_on = deps.iter().map(|&d| TaskId(d)).collect();
        n
    }

    /// Wraps `nodes` in a fresh graph named "test".
    fn graph_from(nodes: Vec<TaskNode>) -> TaskGraph {
        let mut g = TaskGraph::new("test");
        g.tasks = nodes;
        g
    }

    /// Config with topology selection enabled and a parallelism cap of 4.
    fn default_config() -> zeph_config::OrchestrationConfig {
        zeph_config::OrchestrationConfig {
            topology_selection: true,
            max_parallel: 4,
            ..zeph_config::OrchestrationConfig::default()
        }
    }

    // --- classify ---

    #[test]
    fn classify_empty_graph() {
        let g = graph_from(vec![]);
        assert_eq!(TopologyClassifier::classify(&g), Topology::AllParallel);
    }

    #[test]
    fn classify_single_task() {
        let g = graph_from(vec![make_node(0, &[])]);
        assert_eq!(TopologyClassifier::classify(&g), Topology::AllParallel);
    }

    #[test]
    fn classify_all_parallel() {
        let g = graph_from(vec![
            make_node(0, &[]),
            make_node(1, &[]),
            make_node(2, &[]),
        ]);
        assert_eq!(TopologyClassifier::classify(&g), Topology::AllParallel);
    }

    #[test]
    fn classify_two_task_chain() {
        let g = graph_from(vec![make_node(0, &[]), make_node(1, &[0])]);
        assert_eq!(TopologyClassifier::classify(&g), Topology::LinearChain);
    }

    #[test]
    fn classify_linear_chain() {
        let g = graph_from(vec![
            make_node(0, &[]),
            make_node(1, &[0]),
            make_node(2, &[1]),
        ]);
        assert_eq!(TopologyClassifier::classify(&g), Topology::LinearChain);
    }

    #[test]
    fn classify_fan_out() {
        let g = graph_from(vec![
            make_node(0, &[]),
            make_node(1, &[0]),
            make_node(2, &[0]),
            make_node(3, &[0]),
        ]);
        assert_eq!(TopologyClassifier::classify(&g), Topology::FanOut);
    }

    #[test]
    fn classify_fan_in() {
        let g = graph_from(vec![
            make_node(0, &[]),
            make_node(1, &[]),
            make_node(2, &[]),
            make_node(3, &[0, 1, 2]),
        ]);
        assert_eq!(TopologyClassifier::classify(&g), Topology::FanIn);
    }

    #[test]
    fn classify_fan_in_two_roots() {
        let g = graph_from(vec![
            make_node(0, &[]),
            make_node(1, &[]),
            make_node(2, &[0, 1]),
        ]);
        assert_eq!(TopologyClassifier::classify(&g), Topology::FanIn);
    }

    #[test]
    fn classify_hierarchical() {
        let g = graph_from(vec![
            make_node(0, &[]),
            make_node(1, &[0]),
            make_node(2, &[0]),
            make_node(3, &[1]),
            make_node(4, &[2]),
        ]);
        assert_eq!(TopologyClassifier::classify(&g), Topology::Hierarchical);
    }

    #[test]
    fn classify_hierarchical_three_levels() {
        let g = graph_from(vec![
            make_node(0, &[]),
            make_node(1, &[0]),
            make_node(2, &[0]),
            make_node(3, &[1]),
        ]);
        assert_eq!(TopologyClassifier::classify(&g), Topology::Hierarchical);
    }

    #[test]
    fn classify_diamond_is_mixed() {
        let g = graph_from(vec![
            make_node(0, &[]),
            make_node(1, &[0]),
            make_node(2, &[0]),
            make_node(3, &[1, 2]),
        ]);
        assert_eq!(TopologyClassifier::classify(&g), Topology::Mixed);
    }

    #[test]
    fn classify_fan_out_with_chain_on_branch_is_hierarchical() {
        let g = graph_from(vec![
            make_node(0, &[]),
            make_node(1, &[0]),
            make_node(2, &[0]),
            make_node(3, &[1]),
        ]);
        assert_eq!(TopologyClassifier::classify(&g), Topology::Hierarchical);
    }

    #[test]
    fn classify_forest_of_chains_is_mixed() {
        // Two independent two-task chains: multiple roots, multiple sinks.
        let g = graph_from(vec![
            make_node(0, &[]),
            make_node(1, &[]),
            make_node(2, &[0]),
            make_node(3, &[1]),
        ]);
        assert_eq!(TopologyClassifier::classify(&g), Topology::Mixed);
    }

    // --- strategy ---

    #[test]
    fn strategy_all_parallel_is_full_parallel() {
        assert_eq!(
            TopologyClassifier::strategy(Topology::AllParallel),
            DispatchStrategy::FullParallel
        );
    }

    #[test]
    fn strategy_fan_out_is_full_parallel() {
        assert_eq!(
            TopologyClassifier::strategy(Topology::FanOut),
            DispatchStrategy::FullParallel
        );
    }

    #[test]
    fn strategy_fan_in_is_full_parallel() {
        assert_eq!(
            TopologyClassifier::strategy(Topology::FanIn),
            DispatchStrategy::FullParallel
        );
    }

    #[test]
    fn strategy_linear_chain_is_sequential() {
        assert_eq!(
            TopologyClassifier::strategy(Topology::LinearChain),
            DispatchStrategy::Sequential
        );
    }

    #[test]
    fn strategy_hierarchical_is_level_barrier() {
        assert_eq!(
            TopologyClassifier::strategy(Topology::Hierarchical),
            DispatchStrategy::LevelBarrier
        );
    }

    #[test]
    fn strategy_mixed_is_adaptive() {
        assert_eq!(
            TopologyClassifier::strategy(Topology::Mixed),
            DispatchStrategy::Adaptive
        );
    }

    // --- analyze ---

    #[test]
    fn analyze_disabled_returns_full_parallel() {
        let mut cfg = default_config();
        cfg.topology_selection = false;
        let g = graph_from(vec![make_node(0, &[]), make_node(1, &[0])]);
        let analysis = TopologyClassifier::analyze(&g, &cfg);
        assert_eq!(analysis.strategy, DispatchStrategy::FullParallel);
        assert_eq!(analysis.max_parallel, 4);
        assert_eq!(analysis.topology, Topology::AllParallel);
        // The depth pass is skipped entirely when selection is disabled.
        assert_eq!(analysis.depth, 0);
        assert!(analysis.depths.is_empty());
    }

    #[test]
    fn analyze_empty_graph_with_selection_enabled() {
        let g = graph_from(vec![]);
        let analysis = TopologyClassifier::analyze(&g, &default_config());
        assert_eq!(analysis.topology, Topology::AllParallel);
        assert_eq!(analysis.strategy, DispatchStrategy::FullParallel);
        assert_eq!(analysis.max_parallel, 4);
        assert_eq!(analysis.depth, 0);
        assert!(analysis.depths.is_empty());
    }

    #[test]
    fn analyze_linear_chain_returns_sequential() {
        let cfg = default_config();
        let g = graph_from(vec![
            make_node(0, &[]),
            make_node(1, &[0]),
            make_node(2, &[1]),
        ]);
        let analysis = TopologyClassifier::analyze(&g, &cfg);
        assert_eq!(analysis.topology, Topology::LinearChain);
        assert_eq!(analysis.strategy, DispatchStrategy::Sequential);
        assert_eq!(analysis.max_parallel, 1);
        assert_eq!(analysis.depth, 2);
    }

    #[test]
    fn analyze_hierarchical_returns_level_barrier() {
        let cfg = default_config();
        let g = graph_from(vec![
            make_node(0, &[]),
            make_node(1, &[0]),
            make_node(2, &[0]),
            make_node(3, &[1]),
        ]);
        let analysis = TopologyClassifier::analyze(&g, &cfg);
        assert_eq!(analysis.topology, Topology::Hierarchical);
        assert_eq!(analysis.strategy, DispatchStrategy::LevelBarrier);
        assert_eq!(analysis.max_parallel, 4);
        assert_eq!(analysis.depth, 2);
        assert_eq!(analysis.depths[&TaskId(0)], 0);
        assert_eq!(analysis.depths[&TaskId(1)], 1);
        assert_eq!(analysis.depths[&TaskId(2)], 1);
        assert_eq!(analysis.depths[&TaskId(3)], 2);
    }

    #[test]
    fn analyze_fan_in_returns_full_parallel() {
        let cfg = default_config();
        let g = graph_from(vec![
            make_node(0, &[]),
            make_node(1, &[]),
            make_node(2, &[]),
            make_node(3, &[0, 1, 2]),
        ]);
        let analysis = TopologyClassifier::analyze(&g, &cfg);
        assert_eq!(analysis.topology, Topology::FanIn);
        assert_eq!(analysis.strategy, DispatchStrategy::FullParallel);
        assert_eq!(analysis.max_parallel, 4);
    }

    #[test]
    fn analyze_mixed_is_conservative() {
        let cfg = default_config();
        let g = graph_from(vec![
            make_node(0, &[]),
            make_node(1, &[0]),
            make_node(2, &[0]),
            make_node(3, &[1, 2]),
        ]);
        let analysis = TopologyClassifier::analyze(&g, &cfg);
        assert_eq!(analysis.topology, Topology::Mixed);
        assert_eq!(analysis.strategy, DispatchStrategy::Adaptive);
        assert_eq!(analysis.max_parallel, 3);
    }

    #[test]
    fn analyze_depths_correct_for_fan_out() {
        let cfg = default_config();
        let g = graph_from(vec![
            make_node(0, &[]),
            make_node(1, &[0]),
            make_node(2, &[0]),
            make_node(3, &[0]),
        ]);
        let analysis = TopologyClassifier::analyze(&g, &cfg);
        assert_eq!(analysis.depths[&TaskId(0)], 0);
        assert_eq!(analysis.depths[&TaskId(1)], 1);
        assert_eq!(analysis.depths[&TaskId(2)], 1);
        assert_eq!(analysis.depths[&TaskId(3)], 1);
    }

    #[test]
    fn analyze_mixed_respects_max_parallel_one() {
        let mut cfg = default_config();
        cfg.max_parallel = 1;
        let g = graph_from(vec![
            make_node(0, &[]),
            make_node(1, &[0]),
            make_node(2, &[0]),
            make_node(3, &[1, 2]),
        ]);
        let analysis = TopologyClassifier::analyze(&g, &cfg);
        assert_eq!(analysis.max_parallel, 1);
    }

    // --- depth computation ---

    #[test]
    fn compute_depths_uses_longest_not_shortest_path() {
        // Task 3 is reachable from 0 directly and via 0 -> 1 -> 2; its depth is
        // the longer route.
        let g = graph_from(vec![
            make_node(0, &[]),
            make_node(1, &[0]),
            make_node(2, &[1]),
            make_node(3, &[0, 2]),
        ]);
        let (longest, depths) = compute_depths_for_scheduler(&g);
        assert_eq!(longest, 3);
        assert_eq!(depths[&TaskId(0)], 0);
        assert_eq!(depths[&TaskId(1)], 1);
        assert_eq!(depths[&TaskId(2)], 2);
        assert_eq!(depths[&TaskId(3)], 3);
    }

    #[test]
    fn compute_depths_for_empty_graph() {
        let g = graph_from(vec![]);
        let (longest, depths) = compute_depths_for_scheduler(&g);
        assert_eq!(longest, 0);
        assert!(depths.is_empty());
    }

    #[test]
    fn classify_with_depths_matches_classify_for_all_variants() {
        let graphs = vec![
            graph_from(vec![]),
            graph_from(vec![
                make_node(0, &[]),
                make_node(1, &[]),
                make_node(2, &[]),
            ]),
            graph_from(vec![
                make_node(0, &[]),
                make_node(1, &[0]),
                make_node(2, &[1]),
            ]),
            graph_from(vec![
                make_node(0, &[]),
                make_node(1, &[0]),
                make_node(2, &[0]),
                make_node(3, &[0]),
            ]),
            graph_from(vec![
                make_node(0, &[]),
                make_node(1, &[]),
                make_node(2, &[]),
                make_node(3, &[0, 1, 2]),
            ]),
            graph_from(vec![
                make_node(0, &[]),
                make_node(1, &[0]),
                make_node(2, &[0]),
                make_node(3, &[1]),
            ]),
            graph_from(vec![
                make_node(0, &[]),
                make_node(1, &[0]),
                make_node(2, &[0]),
                make_node(3, &[1, 2]),
            ]),
        ];

        for g in &graphs {
            let expected = TopologyClassifier::classify(g);
            // Depths come straight from the shared computation, so the check
            // does not depend on how `analyze` treats the config.
            let (longest, depths) = compute_depths_for_scheduler(g);
            let actual = TopologyClassifier::classify_with_depths(g, longest, &depths);
            assert_eq!(
                actual,
                expected,
                "classify_with_depths mismatch for graph with {} tasks",
                g.tasks.len()
            );
        }
    }

    // --- compute_max_parallel ---

    #[test]
    fn compute_max_parallel_all_parallel_returns_base() {
        assert_eq!(
            TopologyClassifier::compute_max_parallel(Topology::AllParallel, 8),
            8
        );
    }

    #[test]
    fn compute_max_parallel_fan_out_returns_base() {
        assert_eq!(
            TopologyClassifier::compute_max_parallel(Topology::FanOut, 6),
            6
        );
    }

    #[test]
    fn compute_max_parallel_fan_in_returns_base() {
        assert_eq!(
            TopologyClassifier::compute_max_parallel(Topology::FanIn, 4),
            4
        );
    }

    #[test]
    fn compute_max_parallel_hierarchical_returns_base() {
        assert_eq!(
            TopologyClassifier::compute_max_parallel(Topology::Hierarchical, 10),
            10
        );
    }

    #[test]
    fn compute_max_parallel_linear_chain_returns_one() {
        assert_eq!(
            TopologyClassifier::compute_max_parallel(Topology::LinearChain, 8),
            1
        );
        assert_eq!(
            TopologyClassifier::compute_max_parallel(Topology::LinearChain, 1),
            1
        );
    }

    #[test]
    fn compute_max_parallel_mixed_is_half_plus_one() {
        // 4 / 2 + 1 = 3
        assert_eq!(
            TopologyClassifier::compute_max_parallel(Topology::Mixed, 4),
            3
        );
        // 2 / 2 + 1 = 2, already equal to base
        assert_eq!(
            TopologyClassifier::compute_max_parallel(Topology::Mixed, 2),
            2
        );
        // 1 / 2 + 1 = 1, capped at base
        assert_eq!(
            TopologyClassifier::compute_max_parallel(Topology::Mixed, 1),
            1
        );
        // 8 / 2 + 1 = 5
        assert_eq!(
            TopologyClassifier::compute_max_parallel(Topology::Mixed, 8),
            5
        );
    }

    #[test]
    fn compute_max_parallel_mixed_never_returns_zero() {
        // The `.max(1)` floor applies even when the configured base is 0.
        assert_eq!(
            TopologyClassifier::compute_max_parallel(Topology::Mixed, 0),
            1
        );
    }
}