1use std::collections::{HashMap, HashSet, VecDeque};
39
40use crate::error::{CoreError, CoreResult, ErrorContext};
41
/// Identifier of a task inside a [`DependencyGraph`]; allocated sequentially
/// from `0` by `DependencyGraph::add_task_with_cost`.
pub type TaskId = u64;
46
/// A single task stored in a [`DependencyGraph`].
#[non_exhaustive]
#[derive(Debug, Clone)]
pub struct DepTaskNode {
    /// Identifier assigned by the owning graph when the task was added.
    pub id: TaskId,
    /// Human-readable label; the graph does not require it to be unique.
    pub name: String,
    /// Scheduling priority. When several tasks are ready at the same time,
    /// Kahn's topological sort emits the higher value first.
    pub priority: i32,
    /// Estimated cost, used as the node weight by `critical_path` and to order
    /// tasks within a layer in `parallel_schedule` (units are caller-defined).
    pub estimated_cost: f64,
    /// Free-form key/value annotations; empty when the task is created.
    pub metadata: HashMap<String, String>,
}
66
/// Algorithm used by `DependencyGraph::topological_sort`.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TopologicalAlgorithm {
    /// Kahn's in-degree algorithm; ready tasks are emitted by descending
    /// priority, with ties broken by ascending id.
    Kahn,
    /// Iterative depth-first search over dependents (reversed post-order);
    /// task priorities are not consulted.
    DfsBased,
}
78
/// Tunables for a [`DependencyGraph`].
#[non_exhaustive]
#[derive(Debug, Clone)]
pub struct DependencyGraphConfig {
    /// When `true`, `add_dependency` rejects an edge if the task is already a
    /// transitive dependency of the new dependency, i.e. an edge that would
    /// close a cycle. Self-loops are rejected regardless of this flag.
    pub enable_cycle_detection: bool,
    /// Maximum DFS path length explored by `find_cycles`; nodes deeper than
    /// this are not visited, so longer cycles may go unreported.
    pub max_depth: usize,
    /// Algorithm selected by `topological_sort` (and therefore also the order
    /// `critical_path` processes tasks in).
    pub topological_order: TopologicalAlgorithm,
}
92
93impl Default for DependencyGraphConfig {
94 fn default() -> Self {
95 Self {
96 enable_cycle_detection: true,
97 max_depth: 1000,
98 topological_order: TopologicalAlgorithm::Kahn,
99 }
100 }
101}
102
/// Directed graph of tasks and their dependencies, kept acyclic by
/// `add_dependency` when cycle detection is enabled.
///
/// Every edge is stored twice: once in `edges` (task -> its dependencies) and
/// once in `rev_edges` (dependency -> its dependents).
pub struct DependencyGraph {
    /// Behaviour switches (cycle detection, sort algorithm, DFS depth cap).
    config: DependencyGraphConfig,
    /// All tasks, keyed by id.
    nodes: HashMap<TaskId, DepTaskNode>,
    /// task -> tasks it depends on (which must run before it).
    edges: HashMap<TaskId, Vec<TaskId>>,
    /// task -> tasks that depend on it (the reverse of `edges`).
    rev_edges: HashMap<TaskId, Vec<TaskId>>,
    /// Id handed to the next task added; incremented on every insertion.
    next_id: TaskId,
}
119
120impl DependencyGraph {
121 pub fn new(config: DependencyGraphConfig) -> Self {
123 Self {
124 config,
125 nodes: HashMap::new(),
126 edges: HashMap::new(),
127 rev_edges: HashMap::new(),
128 next_id: 0,
129 }
130 }
131
132 pub fn add_task(&mut self, name: &str, priority: i32) -> TaskId {
136 self.add_task_with_cost(name, priority, 1.0)
137 }
138
139 pub fn add_task_with_cost(&mut self, name: &str, priority: i32, cost: f64) -> TaskId {
141 let id = self.next_id;
142 self.next_id += 1;
143 let node = DepTaskNode {
144 id,
145 name: name.to_owned(),
146 priority,
147 estimated_cost: cost,
148 metadata: HashMap::new(),
149 };
150 self.nodes.insert(id, node);
151 self.edges.insert(id, Vec::new());
152 self.rev_edges.insert(id, Vec::new());
153 id
154 }
155
156 pub fn add_dependency(&mut self, task: TaskId, dep: TaskId) -> CoreResult<()> {
165 if !self.nodes.contains_key(&task) {
166 return Err(CoreError::InvalidInput(ErrorContext::new(format!(
167 "add_dependency: task {task} not found"
168 ))));
169 }
170 if !self.nodes.contains_key(&dep) {
171 return Err(CoreError::InvalidInput(ErrorContext::new(format!(
172 "add_dependency: dep {dep} not found"
173 ))));
174 }
175 if task == dep {
176 return Err(CoreError::InvalidInput(ErrorContext::new(format!(
177 "add_dependency: self-loop on task {task}"
178 ))));
179 }
180 if self.config.enable_cycle_detection && self.is_reachable(dep, task) {
183 return Err(CoreError::InvalidInput(ErrorContext::new(format!(
184 "add_dependency: cycle detected — dep {dep} is already reachable from task {task}"
185 ))));
186 }
187 let deps = self.edges.entry(task).or_default();
189 if !deps.contains(&dep) {
190 deps.push(dep);
191 }
192 let rev = self.rev_edges.entry(dep).or_default();
193 if !rev.contains(&task) {
194 rev.push(task);
195 }
196 Ok(())
197 }
198
199 pub fn n_tasks(&self) -> usize {
203 self.nodes.len()
204 }
205
206 pub fn n_edges(&self) -> usize {
208 self.edges.values().map(|v| v.len()).sum()
209 }
210
211 pub fn get_task(&self, id: TaskId) -> Option<&DepTaskNode> {
213 self.nodes.get(&id)
214 }
215
216 pub fn dependencies(&self, id: TaskId) -> &[TaskId] {
218 self.edges.get(&id).map(|v| v.as_slice()).unwrap_or(&[])
219 }
220
221 pub fn dependents(&self, id: TaskId) -> Vec<TaskId> {
223 self.rev_edges.get(&id).cloned().unwrap_or_default()
224 }
225
226 pub fn is_ready(&self, id: TaskId, completed: &HashSet<TaskId>) -> bool {
228 self.edges
229 .get(&id)
230 .map(|deps| deps.iter().all(|d| completed.contains(d)))
231 .unwrap_or(true)
232 }
233
234 pub fn topological_sort(&self) -> CoreResult<Vec<TaskId>> {
240 match self.config.topological_order {
241 TopologicalAlgorithm::Kahn => self.topological_sort_kahn(),
242 TopologicalAlgorithm::DfsBased => self.topological_sort_dfs(),
243 }
244 }
245
246 pub fn topological_sort_kahn(&self) -> CoreResult<Vec<TaskId>> {
250 let mut in_degree: HashMap<TaskId, usize> = self
252 .nodes
253 .keys()
254 .map(|&id| (id, self.edges[&id].len()))
255 .collect();
256
257 let mut ready: Vec<TaskId> = in_degree
259 .iter()
260 .filter_map(|(&id, °)| if deg == 0 { Some(id) } else { None })
261 .collect();
262 ready.sort_unstable_by(|&a, &b| {
264 let pa = self.nodes[&a].priority;
265 let pb = self.nodes[&b].priority;
266 pb.cmp(&pa).then(a.cmp(&b))
267 });
268
269 let mut order = Vec::with_capacity(self.nodes.len());
270 while !ready.is_empty() {
271 let id = ready.remove(0);
273 order.push(id);
274 let new_ready: Vec<TaskId> = if let Some(children) = self.rev_edges.get(&id) {
276 children
277 .iter()
278 .filter_map(|&child| {
279 let deg = in_degree.entry(child).or_insert(0);
280 if *deg > 0 {
281 *deg -= 1;
282 }
283 if *deg == 0 {
284 Some(child)
285 } else {
286 None
287 }
288 })
289 .collect()
290 } else {
291 Vec::new()
292 };
293 for nid in new_ready {
295 let pos = ready.partition_point(|&x| {
296 let px = self.nodes[&x].priority;
297 let pn = self.nodes[&nid].priority;
298 px > pn || (px == pn && x < nid)
299 });
300 ready.insert(pos, nid);
301 }
302 }
303
304 if order.len() != self.nodes.len() {
305 return Err(CoreError::InvalidInput(ErrorContext::new(
306 "topological_sort: cycle detected in graph",
307 )));
308 }
309 Ok(order)
310 }
311
312 pub fn topological_sort_dfs(&self) -> CoreResult<Vec<TaskId>> {
320 let mut color: HashMap<TaskId, u8> = self.nodes.keys().map(|&id| (id, 0u8)).collect();
329 let mut result: Vec<TaskId> = Vec::with_capacity(self.nodes.len());
330
331 let mut all_ids: Vec<TaskId> = self.nodes.keys().cloned().collect();
333 all_ids.sort_unstable();
334
335 let mut call_stack: Vec<(TaskId, usize)> = Vec::new();
337
338 for start in all_ids {
339 if color[&start] != 0 {
340 continue;
341 }
342 call_stack.push((start, 0));
343 *color.entry(start).or_insert(0) = 1; while let Some(frame) = call_stack.last_mut() {
346 let (node, idx) = *frame;
347 let successors: Vec<TaskId> =
348 self.rev_edges.get(&node).cloned().unwrap_or_default();
349
350 if idx < successors.len() {
351 let child = successors[idx];
352 frame.1 += 1; match color[&child] {
354 1 => {
355 return Err(CoreError::InvalidInput(ErrorContext::new(
357 "topological_sort_dfs: cycle detected",
358 )));
359 }
360 0 => {
361 *color.entry(child).or_insert(0) = 1;
363 call_stack.push((child, 0));
364 }
365 _ => {} }
367 } else {
368 call_stack.pop();
370 *color.entry(node).or_insert(1) = 2;
371 result.push(node);
372 }
373 }
374 }
375
376 result.reverse();
378 Ok(result)
379 }
380
381 pub fn find_cycles(&self) -> Vec<Vec<TaskId>> {
388 let mut color: HashMap<TaskId, u8> = self.nodes.keys().map(|&id| (id, 0u8)).collect();
390 let mut cycles: Vec<Vec<TaskId>> = Vec::new();
391 let mut stack: Vec<TaskId> = Vec::new();
392
393 for &start in self.nodes.keys() {
394 if color[&start] != 0 {
395 continue;
396 }
397 self.dfs_find_cycles(start, &mut color, &mut stack, &mut cycles);
398 }
399 cycles
400 }
401
402 fn dfs_find_cycles(
403 &self,
404 node: TaskId,
405 color: &mut HashMap<TaskId, u8>,
406 stack: &mut Vec<TaskId>,
407 cycles: &mut Vec<Vec<TaskId>>,
408 ) {
409 if stack.len() >= self.config.max_depth {
410 return;
411 }
412 *color.entry(node).or_insert(0) = 1; stack.push(node);
414
415 let deps: Vec<TaskId> = self.edges.get(&node).cloned().unwrap_or_default();
416 for dep in deps {
417 match *color.entry(dep).or_insert(0) {
418 1 => {
419 if let Some(pos) = stack.iter().position(|&x| x == dep) {
421 let cycle: Vec<TaskId> = stack[pos..].to_vec();
422 cycles.push(cycle);
423 }
424 }
425 0 => self.dfs_find_cycles(dep, color, stack, cycles),
426 _ => {}
427 }
428 }
429 stack.pop();
430 *color.entry(node).or_insert(1) = 2; }
432
433 pub fn critical_path(&self) -> Vec<TaskId> {
441 let order = match self.topological_sort() {
442 Ok(o) => o,
443 Err(_) => return Vec::new(),
444 };
445 let mut dist: HashMap<TaskId, f64> = HashMap::new();
447 let mut prev: HashMap<TaskId, Option<TaskId>> = HashMap::new();
448
449 for &id in &order {
450 let cost = self.nodes.get(&id).map(|n| n.estimated_cost).unwrap_or(1.0);
451 let max_pred_dist = self
452 .edges
453 .get(&id)
454 .map(|deps| {
455 deps.iter()
456 .filter_map(|d| dist.get(d).copied())
457 .fold(f64::NEG_INFINITY, f64::max)
458 })
459 .unwrap_or(f64::NEG_INFINITY);
460 let pred = if max_pred_dist.is_finite() {
461 self.edges.get(&id).and_then(|deps| {
462 deps.iter()
463 .max_by(|&&a, &&b| {
464 dist.get(&a)
465 .copied()
466 .unwrap_or(f64::NEG_INFINITY)
467 .partial_cmp(&dist.get(&b).copied().unwrap_or(f64::NEG_INFINITY))
468 .unwrap_or(std::cmp::Ordering::Equal)
469 })
470 .copied()
471 })
472 } else {
473 None
474 };
475 let d = if max_pred_dist.is_finite() {
476 max_pred_dist + cost
477 } else {
478 cost
479 };
480 dist.insert(id, d);
481 prev.insert(id, pred);
482 }
483
484 let sink = dist
486 .iter()
487 .max_by(|(_, &da), (_, &db)| da.partial_cmp(&db).unwrap_or(std::cmp::Ordering::Equal))
488 .map(|(&id, _)| id);
489
490 let mut path = Vec::new();
491 let mut current = sink;
492 while let Some(id) = current {
493 path.push(id);
494 current = prev.get(&id).and_then(|opt| *opt);
495 }
496 path.reverse();
497 path
498 }
499
500 pub fn execution_layers(&self) -> CoreResult<Vec<Vec<TaskId>>> {
509 let mut in_deg: HashMap<TaskId, usize> = self
511 .nodes
512 .keys()
513 .map(|&id| (id, self.edges[&id].len()))
514 .collect();
515
516 let mut current_layer: Vec<TaskId> = in_deg
517 .iter()
518 .filter_map(|(&id, °)| if deg == 0 { Some(id) } else { None })
519 .collect();
520 current_layer.sort_unstable();
521
522 let mut layers: Vec<Vec<TaskId>> = Vec::new();
523 let mut processed = 0usize;
524
525 while !current_layer.is_empty() {
526 layers.push(current_layer.clone());
527 processed += current_layer.len();
528 let mut next_layer: Vec<TaskId> = Vec::new();
529 for id in ¤t_layer {
530 if let Some(children) = self.rev_edges.get(id) {
531 for &child in children {
532 let deg = in_deg.entry(child).or_insert(0);
533 if *deg > 0 {
534 *deg -= 1;
535 }
536 if *deg == 0 {
537 next_layer.push(child);
538 }
539 }
540 }
541 }
542 next_layer.sort_unstable();
543 current_layer = next_layer;
544 }
545
546 if processed != self.nodes.len() {
547 return Err(CoreError::InvalidInput(ErrorContext::new(
548 "execution_layers: cycle detected",
549 )));
550 }
551 Ok(layers)
552 }
553
554 pub fn parallel_schedule(&self, n_workers: usize) -> CoreResult<Vec<Vec<TaskId>>> {
563 let layers = self.execution_layers()?;
564 let n_workers = n_workers.max(1);
565 let mut schedule: Vec<Vec<TaskId>> = vec![Vec::new(); n_workers];
566
567 let mut worker = 0usize;
569 for layer in &layers {
570 let mut sorted_layer = layer.clone();
572 sorted_layer.sort_unstable_by(|&a, &b| {
573 let ca = self.nodes.get(&a).map(|n| n.estimated_cost).unwrap_or(1.0);
574 let cb = self.nodes.get(&b).map(|n| n.estimated_cost).unwrap_or(1.0);
575 cb.partial_cmp(&ca).unwrap_or(std::cmp::Ordering::Equal)
576 });
577 for task_id in sorted_layer {
578 schedule[worker % n_workers].push(task_id);
579 worker += 1;
580 }
581 }
582 Ok(schedule)
583 }
584
585 fn is_reachable(&self, from: TaskId, target: TaskId) -> bool {
589 let mut visited: HashSet<TaskId> = HashSet::new();
590 let mut queue = VecDeque::new();
591 queue.push_back(from);
592 while let Some(cur) = queue.pop_front() {
593 if cur == target {
594 return true;
595 }
596 if !visited.insert(cur) {
597 continue;
598 }
599 if let Some(deps) = self.edges.get(&cur) {
600 for &dep in deps {
601 if !visited.contains(&dep) {
602 queue.push_back(dep);
603 }
604 }
605 }
606 }
607 false
608 }
609}
610
#[cfg(test)]
mod tests {
    use super::*;

    /// Default configuration: cycle detection on, Kahn ordering.
    fn make_config() -> DependencyGraphConfig {
        DependencyGraphConfig::default()
    }

    /// Configuration with cycle detection off, so tests can inject cycles.
    fn permissive_config() -> DependencyGraphConfig {
        DependencyGraphConfig {
            enable_cycle_detection: false,
            ..DependencyGraphConfig::default()
        }
    }

    /// Builds the chain A <- B <- C (B depends on A, C depends on B).
    fn build_chain() -> (DependencyGraph, TaskId, TaskId, TaskId) {
        let mut g = DependencyGraph::new(make_config());
        let a = g.add_task("A", 0);
        let b = g.add_task("B", 0);
        let c = g.add_task("C", 0);
        g.add_dependency(b, a).expect("b depends on a");
        g.add_dependency(c, b).expect("c depends on b");
        (g, a, b, c)
    }

    #[test]
    fn test_add_task_and_dependency_no_error() {
        let mut g = DependencyGraph::new(make_config());
        let a = g.add_task("a", 0);
        let b = g.add_task("b", 0);
        g.add_dependency(b, a).expect("valid DAG edge");
        assert_eq!(g.n_tasks(), 2);
        assert_eq!(g.n_edges(), 1);
    }

    #[test]
    fn test_add_dependency_duplicate_is_noop() {
        let mut g = DependencyGraph::new(make_config());
        let a = g.add_task("a", 0);
        let b = g.add_task("b", 0);
        g.add_dependency(b, a).expect("first edge");
        g.add_dependency(b, a).expect("duplicate edge is accepted");
        assert_eq!(g.n_edges(), 1, "duplicate edge must not be stored twice");
        assert_eq!(g.dependents(a), vec![b]);
    }

    #[test]
    fn test_add_dependency_rejects_self_loop_and_unknown_ids() {
        let mut g = DependencyGraph::new(permissive_config());
        let a = g.add_task("a", 0);
        assert!(g.add_dependency(a, a).is_err(), "self-loop must be rejected");
        assert!(g.add_dependency(a, a + 100).is_err(), "unknown dep");
        assert!(g.add_dependency(a + 100, a).is_err(), "unknown task");
        assert_eq!(g.n_edges(), 0);
    }

    #[test]
    fn test_topological_sort_chain() {
        let (g, a, b, c) = build_chain();
        let order = g.topological_sort().expect("acyclic");
        assert_eq!(order, vec![a, b, c]);
    }

    #[test]
    fn test_topological_sort_orders_ready_tasks_by_priority() {
        let mut g = DependencyGraph::new(make_config());
        let a = g.add_task("a", 0);
        let b = g.add_task("b", 5);
        let c = g.add_task("c", 1);
        let d = g.add_task("d", 0);
        g.add_dependency(c, a).expect("c depends on a");
        // b (highest priority) first; a beats d on id; c becomes ready after
        // a and outranks d.
        assert_eq!(g.topological_sort().expect("acyclic"), vec![b, a, c, d]);
    }

    #[test]
    fn test_topological_sort_cycle_returns_err() {
        let mut g = DependencyGraph::new(permissive_config());
        let a = g.add_task("a", 0);
        let b = g.add_task("b", 0);
        // Inject the cycle directly; add_dependency would refuse it.
        g.edges.get_mut(&a).expect("a edges").push(b);
        g.edges.get_mut(&b).expect("b edges").push(a);
        assert!(
            g.topological_sort_kahn().is_err(),
            "cycle must be detected by Kahn"
        );
    }

    #[test]
    fn test_add_dependency_cycle_rejected() {
        let mut g = DependencyGraph::new(make_config());
        let a = g.add_task("a", 0);
        let b = g.add_task("b", 0);
        g.add_dependency(b, a).expect("b → a");
        assert!(g.add_dependency(a, b).is_err(), "cycle must be rejected");
        assert_eq!(g.n_edges(), 1, "rejected edge must not be stored");
    }

    #[test]
    fn test_find_cycles_returns_cycle() {
        let mut g = DependencyGraph::new(permissive_config());
        let a = g.add_task("a", 0);
        let b = g.add_task("b", 0);
        g.edges.get_mut(&a).expect("a").push(b);
        g.edges.get_mut(&b).expect("b").push(a);
        let cycles = g.find_cycles();
        assert_eq!(cycles.len(), 1, "exactly one cycle: {cycles:?}");
        let members: HashSet<TaskId> = cycles[0].iter().copied().collect();
        assert_eq!(members, HashSet::from([a, b]));
    }

    #[test]
    fn test_find_cycles_empty_for_dag() {
        let (g, _a, _b, _c) = build_chain();
        assert!(g.find_cycles().is_empty());
    }

    #[test]
    fn test_execution_layers_independent_tasks_in_layer_0() {
        let mut g = DependencyGraph::new(make_config());
        let x = g.add_task("x", 0);
        let y = g.add_task("y", 0);
        let z = g.add_task("z", 0);
        let layers = g.execution_layers().expect("acyclic");
        assert_eq!(layers, vec![vec![x, y, z]], "all independent tasks in one layer");
    }

    #[test]
    fn test_execution_layers_chain() {
        let (g, a, b, c) = build_chain();
        let layers = g.execution_layers().expect("acyclic");
        assert_eq!(layers, vec![vec![a], vec![b], vec![c]], "chain has 3 layers");
    }

    #[test]
    fn test_execution_layers_diamond() {
        let mut g = DependencyGraph::new(make_config());
        let top = g.add_task("top", 0);
        let left = g.add_task("left", 0);
        let right = g.add_task("right", 0);
        let bottom = g.add_task("bottom", 0);
        g.add_dependency(left, top).expect("left dep top");
        g.add_dependency(right, top).expect("right dep top");
        g.add_dependency(bottom, left).expect("bottom dep left");
        g.add_dependency(bottom, right).expect("bottom dep right");
        let layers = g.execution_layers().expect("acyclic");
        assert_eq!(layers, vec![vec![top], vec![left, right], vec![bottom]]);
    }

    #[test]
    fn test_critical_path_selects_longest_cost_path() {
        let mut g = DependencyGraph::new(make_config());
        // Diamond: source -> {cheap, expensive} -> sink.
        let source = g.add_task_with_cost("source", 0, 1.0);
        let cheap = g.add_task_with_cost("cheap", 0, 1.0);
        let expensive = g.add_task_with_cost("expensive", 0, 10.0);
        let sink = g.add_task_with_cost("sink", 0, 1.0);
        g.add_dependency(cheap, source).expect("cheap dep");
        g.add_dependency(expensive, source).expect("expensive dep");
        g.add_dependency(sink, cheap).expect("sink dep cheap");
        g.add_dependency(sink, expensive).expect("sink dep expensive");

        assert_eq!(
            g.critical_path(),
            vec![source, expensive, sink],
            "critical path must go through 'expensive' node"
        );
    }

    #[test]
    fn test_critical_path_empty_graph() {
        let g = DependencyGraph::new(make_config());
        assert!(g.critical_path().is_empty());
    }

    #[test]
    fn test_parallel_schedule_all_tasks_covered() {
        let (g, a, b, c) = build_chain();
        let schedule = g.parallel_schedule(2).expect("valid schedule");
        let flat: Vec<TaskId> = schedule.into_iter().flatten().collect();
        assert_eq!(flat.len(), 3, "no task may be scheduled twice");
        let unique: HashSet<TaskId> = flat.into_iter().collect();
        assert_eq!(unique, HashSet::from([a, b, c]), "all tasks must be in schedule");
    }

    #[test]
    fn test_parallel_schedule_zero_workers_clamped_to_one() {
        let (g, a, b, c) = build_chain();
        let schedule = g.parallel_schedule(0).expect("valid schedule");
        assert_eq!(schedule, vec![vec![a, b, c]]);
    }

    #[test]
    fn test_dependency_graph_config_default() {
        let cfg = DependencyGraphConfig::default();
        assert!(cfg.enable_cycle_detection);
        assert_eq!(cfg.max_depth, 1000);
        assert_eq!(cfg.topological_order, TopologicalAlgorithm::Kahn);
    }

    #[test]
    fn test_is_ready_task_with_all_deps_complete() {
        let mut g = DependencyGraph::new(make_config());
        let a = g.add_task("a", 0);
        let b = g.add_task("b", 0);
        g.add_dependency(b, a).expect("b dep a");
        let completed: HashSet<TaskId> = HashSet::from([a]);
        assert!(g.is_ready(b, &completed), "b is ready when a is complete");
        let empty: HashSet<TaskId> = HashSet::new();
        assert!(!g.is_ready(b, &empty), "b not ready when a is incomplete");
        assert!(g.is_ready(a, &empty), "a has no dependencies");
    }

    #[test]
    fn test_topological_sort_dfs_chain() {
        let (g, a, b, c) = build_chain();
        let order = g.topological_sort_dfs().expect("acyclic DFS");
        assert_eq!(order, vec![a, b, c], "DFS: A, B, C in dependency order");
    }

    #[test]
    fn test_dependencies_and_dependents() {
        let mut g = DependencyGraph::new(make_config());
        let a = g.add_task("a", 0);
        let b = g.add_task("b", 0);
        g.add_dependency(b, a).expect("b dep a");
        assert_eq!(g.dependencies(b), &[a]);
        assert_eq!(g.dependents(a), vec![b]);
        assert!(g.dependencies(a).is_empty());
        assert!(g.dependents(b).is_empty());
    }

    #[test]
    fn test_get_task_metadata() {
        let mut g = DependencyGraph::new(make_config());
        let id = g.add_task_with_cost("my_task", 5, 42.0);
        let node = g.get_task(id).expect("task should exist");
        assert_eq!(node.id, id);
        assert_eq!(node.name, "my_task");
        assert_eq!(node.priority, 5);
        assert!((node.estimated_cost - 42.0).abs() < f64::EPSILON);
        assert!(node.metadata.is_empty());
        assert!(g.get_task(id + 1).is_none());
    }
}
799}