1use crate::error::{CoreError, ErrorContext, ErrorLocation};
39use std::collections::{HashMap, HashSet, VecDeque};
40use std::sync::{Arc, Mutex};
41use std::thread;
42
43pub type TaskGraphResult<T> = Result<T, CoreError>;
45
46type TaskFn<T> = Box<dyn FnOnce(&HashMap<String, T>) -> TaskGraphResult<T> + Send>;
51
52struct TaskNode<T: Send + 'static> {
54 name: String,
56 dependencies: Vec<String>,
58 func: Option<TaskFn<T>>,
60}
61
62pub struct TaskGraph<T: Clone + Send + 'static> {
71 nodes: HashMap<String, TaskNode<T>>,
73 insertion_order: Vec<String>,
75}
76
77impl<T: Clone + Send + 'static> TaskGraph<T> {
78 pub fn new() -> Self {
80 Self {
81 nodes: HashMap::new(),
82 insertion_order: Vec::new(),
83 }
84 }
85
86 pub fn len(&self) -> usize {
88 self.nodes.len()
89 }
90
91 pub fn is_empty(&self) -> bool {
93 self.nodes.is_empty()
94 }
95
96 pub fn contains_task(&self, name: &str) -> bool {
98 self.nodes.contains_key(name)
99 }
100
101 pub fn dependencies(&self, name: &str) -> Option<&[String]> {
103 self.nodes.get(name).map(|n| n.dependencies.as_slice())
104 }
105
106 pub fn add_task<F>(&mut self, name: &str, deps: &[&str], func: F) -> TaskGraphResult<()>
118 where
119 F: FnOnce(&HashMap<String, T>) -> TaskGraphResult<T> + Send + 'static,
120 {
121 if self.nodes.contains_key(name) {
123 return Err(CoreError::ValueError(
124 ErrorContext::new(format!("Task '{name}' already exists in the graph"))
125 .with_location(ErrorLocation::new(file!(), line!())),
126 ));
127 }
128
129 for dep in deps {
131 if !self.nodes.contains_key(*dep) {
132 return Err(CoreError::ValueError(
133 ErrorContext::new(format!(
134 "Dependency '{dep}' for task '{name}' does not exist in the graph"
135 ))
136 .with_location(ErrorLocation::new(file!(), line!())),
137 ));
138 }
139 }
140
141 let dep_names: Vec<String> = deps.iter().map(|d| d.to_string()).collect();
142
143 self.nodes.insert(
145 name.to_string(),
146 TaskNode {
147 name: name.to_string(),
148 dependencies: dep_names.clone(),
149 func: None, },
151 );
152
153 if self.has_cycle() {
154 self.nodes.remove(name);
156 return Err(CoreError::ComputationError(
157 ErrorContext::new(format!(
158 "Adding task '{name}' would create a cycle in the dependency graph"
159 ))
160 .with_location(ErrorLocation::new(file!(), line!())),
161 ));
162 }
163
164 if let Some(node) = self.nodes.get_mut(name) {
166 node.func = Some(Box::new(func));
167 }
168
169 self.insertion_order.push(name.to_string());
170 Ok(())
171 }
172
173 fn has_cycle(&self) -> bool {
175 let topo = self.topological_sort();
176 topo.len() != self.nodes.len()
178 }
179
180 fn topological_sort(&self) -> Vec<String> {
186 let mut in_degree: HashMap<&str, usize> = HashMap::new();
188 let mut dependents: HashMap<&str, Vec<&str>> = HashMap::new();
189
190 for (name, node) in &self.nodes {
191 in_degree.entry(name.as_str()).or_insert(0);
192 for dep in &node.dependencies {
193 dependents
194 .entry(dep.as_str())
195 .or_default()
196 .push(name.as_str());
197 *in_degree.entry(name.as_str()).or_insert(0) += 1;
198 }
199 }
200
201 let mut queue: VecDeque<&str> = VecDeque::new();
203 for name in &self.insertion_order {
205 if let Some(°) = in_degree.get(name.as_str()) {
206 if deg == 0 {
207 queue.push_back(name.as_str());
208 }
209 }
210 }
211 for name in self.nodes.keys() {
213 if !self.insertion_order.contains(name) {
214 if let Some(°) = in_degree.get(name.as_str()) {
215 if deg == 0 && !queue.contains(&name.as_str()) {
216 queue.push_back(name.as_str());
217 }
218 }
219 }
220 }
221
222 let mut order: Vec<String> = Vec::with_capacity(self.nodes.len());
223
224 while let Some(current) = queue.pop_front() {
225 order.push(current.to_string());
226 if let Some(deps) = dependents.get(current) {
227 for &dep in deps {
228 if let Some(deg) = in_degree.get_mut(dep) {
229 *deg = deg.saturating_sub(1);
230 if *deg == 0 {
231 queue.push_back(dep);
232 }
233 }
234 }
235 }
236 }
237
238 order
239 }
240
241 fn compute_levels(&self) -> Vec<Vec<String>> {
246 let topo = self.topological_sort();
247 let mut level_of: HashMap<String, usize> = HashMap::new();
248
249 for name in &topo {
250 let node = match self.nodes.get(name) {
251 Some(n) => n,
252 None => continue,
253 };
254 let max_dep_level = node
255 .dependencies
256 .iter()
257 .filter_map(|d| level_of.get(d))
258 .copied()
259 .max()
260 .map(|l| l + 1)
261 .unwrap_or(0);
262 level_of.insert(name.clone(), max_dep_level);
263 }
264
265 let max_level = level_of.values().copied().max().unwrap_or(0);
267 let mut levels: Vec<Vec<String>> = vec![Vec::new(); max_level + 1];
268 for (name, level) in &level_of {
269 levels[*level].push(name.clone());
270 }
271
272 levels
273 }
274
275 pub fn execute(mut self) -> TaskGraphResult<HashMap<String, T>> {
286 if self.nodes.is_empty() {
287 return Ok(HashMap::new());
288 }
289
290 let levels = self.compute_levels();
291 let results: Arc<Mutex<HashMap<String, T>>> = Arc::new(Mutex::new(HashMap::new()));
292 let errors: Arc<Mutex<Vec<(String, CoreError)>>> = Arc::new(Mutex::new(Vec::new()));
293
294 let failed_tasks: Arc<Mutex<HashSet<String>>> = Arc::new(Mutex::new(HashSet::new()));
296
297 for level_tasks in &levels {
298 if level_tasks.is_empty() {
299 continue;
300 }
301
302 if level_tasks.len() == 1 {
303 let task_name = &level_tasks[0];
305
306 let dep_failed = {
308 let ft = failed_tasks.lock().map_err(|_| {
309 CoreError::MutexError(
310 ErrorContext::new("Failed to lock failed_tasks".to_string())
311 .with_location(ErrorLocation::new(file!(), line!())),
312 )
313 })?;
314 let node = self.nodes.get(task_name);
315 node.map(|n| n.dependencies.iter().any(|d| ft.contains(d)))
316 .unwrap_or(false)
317 };
318
319 if dep_failed {
320 let mut ft = failed_tasks.lock().map_err(|_| {
321 CoreError::MutexError(
322 ErrorContext::new("Failed to lock failed_tasks".to_string())
323 .with_location(ErrorLocation::new(file!(), line!())),
324 )
325 })?;
326 ft.insert(task_name.clone());
327 continue;
328 }
329
330 if let Some(node) = self.nodes.get_mut(task_name) {
331 if let Some(func) = node.func.take() {
332 let res_snapshot = {
333 let r = results.lock().map_err(|_| {
334 CoreError::MutexError(
335 ErrorContext::new("Failed to lock results".to_string())
336 .with_location(ErrorLocation::new(file!(), line!())),
337 )
338 })?;
339 r.clone()
340 };
341 match func(&res_snapshot) {
342 Ok(val) => {
343 let mut r = results.lock().map_err(|_| {
344 CoreError::MutexError(
345 ErrorContext::new("Failed to lock results".to_string())
346 .with_location(ErrorLocation::new(file!(), line!())),
347 )
348 })?;
349 r.insert(task_name.clone(), val);
350 }
351 Err(e) => {
352 let mut ft = failed_tasks.lock().map_err(|_| {
353 CoreError::MutexError(
354 ErrorContext::new(
355 "Failed to lock failed_tasks".to_string(),
356 )
357 .with_location(ErrorLocation::new(file!(), line!())),
358 )
359 })?;
360 ft.insert(task_name.clone());
361 let mut errs = errors.lock().map_err(|_| {
362 CoreError::MutexError(
363 ErrorContext::new("Failed to lock errors".to_string())
364 .with_location(ErrorLocation::new(file!(), line!())),
365 )
366 })?;
367 errs.push((task_name.clone(), e));
368 }
369 }
370 }
371 }
372 } else {
373 let mut task_closures: Vec<(String, TaskFn<T>, Vec<String>)> = Vec::new();
376 for task_name in level_tasks {
377 if let Some(node) = self.nodes.get_mut(task_name) {
378 if let Some(func) = node.func.take() {
379 task_closures.push((
380 task_name.clone(),
381 func,
382 node.dependencies.clone(),
383 ));
384 }
385 }
386 }
387
388 let res_snapshot = {
390 let r = results.lock().map_err(|_| {
391 CoreError::MutexError(
392 ErrorContext::new("Failed to lock results".to_string())
393 .with_location(ErrorLocation::new(file!(), line!())),
394 )
395 })?;
396 r.clone()
397 };
398
399 let failed_snapshot: HashSet<String> = {
400 let ft = failed_tasks.lock().map_err(|_| {
401 CoreError::MutexError(
402 ErrorContext::new("Failed to lock failed_tasks".to_string())
403 .with_location(ErrorLocation::new(file!(), line!())),
404 )
405 })?;
406 ft.clone()
407 };
408
409 let mut handles: Vec<(String, thread::JoinHandle<Result<T, CoreError>>)> =
411 Vec::new();
412 let mut skipped: Vec<String> = Vec::new();
413
414 for (task_name, func, deps) in task_closures {
415 let dep_failed = deps.iter().any(|d| failed_snapshot.contains(d));
416 if dep_failed {
417 skipped.push(task_name);
418 continue;
419 }
420
421 let snapshot = res_snapshot.clone();
422 let handle = thread::spawn(move || func(&snapshot));
423 handles.push((task_name, handle));
424 }
425
426 {
428 let mut ft = failed_tasks.lock().map_err(|_| {
429 CoreError::MutexError(
430 ErrorContext::new("Failed to lock failed_tasks".to_string())
431 .with_location(ErrorLocation::new(file!(), line!())),
432 )
433 })?;
434 for s in skipped {
435 ft.insert(s);
436 }
437 }
438
439 for (task_name, handle) in handles {
441 match handle.join() {
442 Ok(Ok(val)) => {
443 let mut r = results.lock().map_err(|_| {
444 CoreError::MutexError(
445 ErrorContext::new("Failed to lock results".to_string())
446 .with_location(ErrorLocation::new(file!(), line!())),
447 )
448 })?;
449 r.insert(task_name, val);
450 }
451 Ok(Err(e)) => {
452 let mut ft = failed_tasks.lock().map_err(|_| {
453 CoreError::MutexError(
454 ErrorContext::new("Failed to lock failed_tasks".to_string())
455 .with_location(ErrorLocation::new(file!(), line!())),
456 )
457 })?;
458 ft.insert(task_name.clone());
459 let mut errs = errors.lock().map_err(|_| {
460 CoreError::MutexError(
461 ErrorContext::new("Failed to lock errors".to_string())
462 .with_location(ErrorLocation::new(file!(), line!())),
463 )
464 })?;
465 errs.push((task_name, e));
466 }
467 Err(_panic) => {
468 let mut ft = failed_tasks.lock().map_err(|_| {
469 CoreError::MutexError(
470 ErrorContext::new("Failed to lock failed_tasks".to_string())
471 .with_location(ErrorLocation::new(file!(), line!())),
472 )
473 })?;
474 ft.insert(task_name.clone());
475 let mut errs = errors.lock().map_err(|_| {
476 CoreError::MutexError(
477 ErrorContext::new("Failed to lock errors".to_string())
478 .with_location(ErrorLocation::new(file!(), line!())),
479 )
480 })?;
481 errs.push((
482 task_name,
483 CoreError::ThreadError(
484 ErrorContext::new("Task thread panicked".to_string())
485 .with_location(ErrorLocation::new(file!(), line!())),
486 ),
487 ));
488 }
489 }
490 }
491 }
492 }
493
494 let errs = errors.lock().map_err(|_| {
496 CoreError::MutexError(
497 ErrorContext::new("Failed to lock errors".to_string())
498 .with_location(ErrorLocation::new(file!(), line!())),
499 )
500 })?;
501 if let Some((task_name, err)) = errs.first() {
502 return Err(CoreError::ComputationError(
503 ErrorContext::new(format!("Task '{task_name}' failed: {err}"))
504 .with_location(ErrorLocation::new(file!(), line!())),
505 ));
506 }
507
508 let final_results = results.lock().map_err(|_| {
509 CoreError::MutexError(
510 ErrorContext::new("Failed to lock results".to_string())
511 .with_location(ErrorLocation::new(file!(), line!())),
512 )
513 })?;
514 Ok(final_results.clone())
515 }
516
517 pub fn execute_partial(self) -> TaskGraphResult<HashMap<String, T>> {
524 let levels = self.compute_levels();
527 let mut all_nodes = self.nodes;
528 let results: Arc<Mutex<HashMap<String, T>>> = Arc::new(Mutex::new(HashMap::new()));
529 let failed_tasks: Arc<Mutex<HashSet<String>>> = Arc::new(Mutex::new(HashSet::new()));
530
531 for level_tasks in &levels {
532 for task_name in level_tasks {
533 let dep_failed = {
534 let ft = failed_tasks.lock().map_err(|_| {
535 CoreError::MutexError(
536 ErrorContext::new("Failed to lock failed_tasks".to_string())
537 .with_location(ErrorLocation::new(file!(), line!())),
538 )
539 })?;
540 let node = all_nodes.get(task_name);
541 node.map(|n| n.dependencies.iter().any(|d| ft.contains(d)))
542 .unwrap_or(false)
543 };
544
545 if dep_failed {
546 let mut ft = failed_tasks.lock().map_err(|_| {
547 CoreError::MutexError(
548 ErrorContext::new("Failed to lock failed_tasks".to_string())
549 .with_location(ErrorLocation::new(file!(), line!())),
550 )
551 })?;
552 ft.insert(task_name.clone());
553 continue;
554 }
555
556 if let Some(node) = all_nodes.get_mut(task_name) {
557 if let Some(func) = node.func.take() {
558 let res_snapshot = {
559 let r = results.lock().map_err(|_| {
560 CoreError::MutexError(
561 ErrorContext::new("Failed to lock results".to_string())
562 .with_location(ErrorLocation::new(file!(), line!())),
563 )
564 })?;
565 r.clone()
566 };
567 match func(&res_snapshot) {
568 Ok(val) => {
569 let mut r = results.lock().map_err(|_| {
570 CoreError::MutexError(
571 ErrorContext::new("Failed to lock results".to_string())
572 .with_location(ErrorLocation::new(file!(), line!())),
573 )
574 })?;
575 r.insert(task_name.clone(), val);
576 }
577 Err(_) => {
578 let mut ft = failed_tasks.lock().map_err(|_| {
579 CoreError::MutexError(
580 ErrorContext::new(
581 "Failed to lock failed_tasks".to_string(),
582 )
583 .with_location(ErrorLocation::new(file!(), line!())),
584 )
585 })?;
586 ft.insert(task_name.clone());
587 }
588 }
589 }
590 }
591 }
592 }
593
594 let final_results = results.lock().map_err(|_| {
595 CoreError::MutexError(
596 ErrorContext::new("Failed to lock results".to_string())
597 .with_location(ErrorLocation::new(file!(), line!())),
598 )
599 })?;
600 Ok(final_results.clone())
601 }
602
603 pub fn execution_order(&self) -> Vec<String> {
607 self.topological_sort()
608 }
609
610 pub fn execution_levels(&self) -> Vec<Vec<String>> {
614 self.compute_levels()
615 }
616}
617
618impl<T: Clone + Send + 'static> Default for TaskGraph<T> {
619 fn default() -> Self {
620 Self::new()
621 }
622}
623
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the error returned by deliberately failing tasks.
    fn computation_error(message: &str) -> CoreError {
        CoreError::ComputationError(
            ErrorContext::new(message.to_string())
                .with_location(ErrorLocation::new(file!(), line!())),
        )
    }

    #[test]
    fn test_empty_graph() {
        let graph = TaskGraph::<i32>::new();
        assert!(graph.is_empty());
        assert_eq!(graph.len(), 0);
        let results = graph.execute().expect("empty graph should succeed");
        assert!(results.is_empty());
    }

    #[test]
    fn test_empty_graph_levels() {
        let graph = TaskGraph::<i32>::new();
        // An empty graph still reports one (empty) level.
        assert_eq!(graph.execution_levels(), vec![Vec::<String>::new()]);
        assert!(graph.execution_order().is_empty());
        let results = graph.execute_partial().expect("partial on empty graph");
        assert!(results.is_empty());
    }

    #[test]
    fn test_single_task() {
        let mut graph = TaskGraph::<i32>::new();
        graph.add_task("only", &[], |_| Ok(99)).expect("add failed");
        assert_eq!(graph.len(), 1);
        assert!(graph.contains_task("only"));

        let results = graph.execute().expect("execute failed");
        assert_eq!(results.get("only"), Some(&99));
    }

    #[test]
    fn test_linear_chain() {
        let mut graph = TaskGraph::<i32>::new();
        graph.add_task("a", &[], |_| Ok(1)).expect("add a");
        graph
            .add_task("b", &["a"], |r| Ok(r.get("a").copied().unwrap_or(0) + 10))
            .expect("add b");
        graph
            .add_task("c", &["b"], |r| Ok(r.get("b").copied().unwrap_or(0) + 100))
            .expect("add c");

        let results = graph.execute().expect("execute failed");
        assert_eq!(results.get("a"), Some(&1));
        assert_eq!(results.get("b"), Some(&11));
        assert_eq!(results.get("c"), Some(&111));
    }

    #[test]
    fn test_diamond_dependency() {
        let mut graph = TaskGraph::<i32>::new();
        graph.add_task("a", &[], |_| Ok(1)).expect("add a");
        graph
            .add_task("b", &["a"], |r| Ok(r.get("a").copied().unwrap_or(0) * 2))
            .expect("add b");
        graph
            .add_task("c", &["a"], |r| Ok(r.get("a").copied().unwrap_or(0) * 3))
            .expect("add c");
        graph
            .add_task("d", &["b", "c"], |r| {
                let b = r.get("b").copied().unwrap_or(0);
                let c = r.get("c").copied().unwrap_or(0);
                Ok(b + c)
            })
            .expect("add d");

        let results = graph.execute().expect("execute failed");
        assert_eq!(results.get("a"), Some(&1));
        assert_eq!(results.get("b"), Some(&2));
        assert_eq!(results.get("c"), Some(&3));
        assert_eq!(results.get("d"), Some(&5));
    }

    #[test]
    fn test_parallel_independent_tasks() {
        let mut graph = TaskGraph::<String>::new();
        for i in 0..8 {
            let name = format!("task_{i}");
            graph
                .add_task(&name, &[], move |_| Ok(format!("result_{i}")))
                .expect("add failed");
        }

        let levels = graph.execution_levels();
        assert_eq!(levels.len(), 1);
        assert_eq!(levels[0].len(), 8);

        let results = graph.execute().expect("execute failed");
        assert_eq!(results.len(), 8);
        for i in 0..8 {
            assert_eq!(
                results.get(&format!("task_{i}")),
                Some(&format!("result_{i}"))
            );
        }
    }

    #[test]
    fn test_many_independent_tasks() {
        // Typically more tasks than CPU cores, so the level is shared among fewer workers.
        let mut graph = TaskGraph::<usize>::new();
        for i in 0..64 {
            graph
                .add_task(&format!("t{i}"), &[], move |_| Ok(i * i))
                .expect("add failed");
        }

        let results = graph.execute().expect("execute failed");
        assert_eq!(results.len(), 64);
        for i in 0..64 {
            assert_eq!(results.get(&format!("t{i}")), Some(&(i * i)));
        }
    }

    #[test]
    fn test_duplicate_task_name_rejected() {
        let mut graph = TaskGraph::<i32>::new();
        graph.add_task("x", &[], |_| Ok(1)).expect("add x");
        let err = graph.add_task("x", &[], |_| Ok(2));
        assert!(err.is_err());
    }

    #[test]
    fn test_missing_dependency_rejected() {
        let mut graph = TaskGraph::<i32>::new();
        let err = graph.add_task("x", &["nonexistent"], |_| Ok(1));
        assert!(err.is_err());
    }

    #[test]
    fn test_rejected_task_leaves_graph_unchanged() {
        let mut graph = TaskGraph::<i32>::new();
        graph.add_task("x", &[], |_| Ok(1)).expect("add x");
        assert!(graph.add_task("x", &[], |_| Ok(2)).is_err());
        assert!(graph.add_task("y", &["missing"], |_| Ok(3)).is_err());
        assert_eq!(graph.len(), 1);
        assert!(!graph.contains_task("y"));

        // The original closure for "x" must still be the one that runs.
        let results = graph.execute().expect("execute failed");
        assert_eq!(results.get("x"), Some(&1));
    }

    #[test]
    fn test_cycle_detection() {
        let mut graph = TaskGraph::<i32>::new();
        graph.add_task("a", &[], |_| Ok(1)).expect("add a");
        graph.add_task("b", &["a"], |_| Ok(2)).expect("add b");
        assert!(!graph.has_cycle());

        // Dependencies must already exist, so the only way to attempt a cycle is a
        // self-dependency, which is rejected because the task is not registered yet.
        assert!(graph.add_task("c", &["c"], |_| Ok(3)).is_err());
        assert!(!graph.contains_task("c"));
        assert!(!graph.has_cycle());
    }

    #[test]
    fn test_duplicate_dependency_entries() {
        let mut graph = TaskGraph::<i32>::new();
        graph.add_task("a", &[], |_| Ok(4)).expect("add a");
        graph
            .add_task("b", &["a", "a"], |r| Ok(r.get("a").copied().unwrap_or(0) * 2))
            .expect("add b");

        assert_eq!(graph.execution_order(), vec!["a", "b"]);
        let results = graph.execute().expect("execute failed");
        assert_eq!(results.get("b"), Some(&8));
    }

    #[test]
    fn test_task_failure_propagation() {
        let mut graph = TaskGraph::<i32>::new();
        graph
            .add_task("fail", &[], |_| Err(computation_error("intentional failure")))
            .expect("add fail");
        graph
            .add_task("downstream", &["fail"], |_| Ok(42))
            .expect("add downstream");

        let result = graph.execute();
        assert!(result.is_err());
    }

    #[test]
    fn test_panicking_task_is_reported_as_error() {
        let mut graph = TaskGraph::<i32>::new();
        graph.add_task("fine", &[], |_| Ok(1)).expect("add fine");
        graph
            .add_task("panics", &[], |_| panic!("intentional panic"))
            .expect("add panics");
        graph
            .add_task("after", &["panics"], |_| Ok(2))
            .expect("add after");

        assert!(graph.execute().is_err());
    }

    #[test]
    fn test_partial_execution() {
        let mut graph = TaskGraph::<i32>::new();
        graph.add_task("ok", &[], |_| Ok(10)).expect("add ok");
        graph
            .add_task("fail", &[], |_| Err(computation_error("boom")))
            .expect("add fail");
        graph
            .add_task("depends_on_fail", &["fail"], |_| Ok(20))
            .expect("add depends");

        let results = graph.execute_partial().expect("partial should not error");
        assert_eq!(results.get("ok"), Some(&10));
        assert!(!results.contains_key("fail"));
        assert!(!results.contains_key("depends_on_fail"));
    }

    #[test]
    fn test_partial_execution_keeps_unrelated_branches() {
        let mut graph = TaskGraph::<i32>::new();
        graph.add_task("root", &[], |_| Ok(1)).expect("add root");
        graph
            .add_task("bad", &["root"], |_| Err(computation_error("bad branch")))
            .expect("add bad");
        graph
            .add_task("good", &["root"], |r| {
                Ok(r.get("root").copied().unwrap_or(0) + 1)
            })
            .expect("add good");
        graph
            .add_task("after_bad", &["bad"], |_| Ok(3))
            .expect("add after_bad");
        graph
            .add_task("after_good", &["good"], |r| {
                Ok(r.get("good").copied().unwrap_or(0) * 10)
            })
            .expect("add after_good");

        let results = graph.execute_partial().expect("partial should not error");
        assert_eq!(results.get("root"), Some(&1));
        assert_eq!(results.get("good"), Some(&2));
        assert_eq!(results.get("after_good"), Some(&20));
        assert!(!results.contains_key("bad"));
        assert!(!results.contains_key("after_bad"));
    }

    #[test]
    fn test_execution_order() {
        let mut graph = TaskGraph::<i32>::new();
        graph.add_task("a", &[], |_| Ok(1)).expect("add a");
        graph.add_task("b", &["a"], |_| Ok(2)).expect("add b");
        graph.add_task("c", &["b"], |_| Ok(3)).expect("add c");

        let order = graph.execution_order();
        assert_eq!(order, vec!["a", "b", "c"]);
    }

    #[test]
    fn test_execution_levels_structure() {
        let mut graph = TaskGraph::<i32>::new();
        graph.add_task("a", &[], |_| Ok(1)).expect("add a");
        graph.add_task("b", &[], |_| Ok(2)).expect("add b");
        graph.add_task("c", &["a", "b"], |_| Ok(3)).expect("add c");

        let levels = graph.execution_levels();
        assert_eq!(levels.len(), 2);
        assert_eq!(levels[0].len(), 2);
        assert!(levels[0].contains(&"a".to_string()));
        assert!(levels[0].contains(&"b".to_string()));
        assert_eq!(levels[1].len(), 1);
        assert!(levels[1].contains(&"c".to_string()));
    }

    #[test]
    fn test_wide_fan_in() {
        let mut graph = TaskGraph::<i32>::new();
        let n = 16;
        let mut dep_names: Vec<String> = Vec::new();
        for i in 0..n {
            let name = format!("src_{i}");
            graph
                .add_task(&name, &[], move |_| Ok(i as i32))
                .expect("add src");
            dep_names.push(name);
        }
        let dep_refs: Vec<&str> = dep_names.iter().map(|s| s.as_str()).collect();
        graph
            .add_task("sink", &dep_refs, |r| Ok(r.values().sum::<i32>()))
            .expect("add sink");

        let results = graph.execute().expect("execute failed");
        let expected_sum: i32 = (0..n as i32).sum();
        assert_eq!(results.get("sink"), Some(&expected_sum));
    }

    #[test]
    fn test_dependencies_accessor() {
        let mut graph = TaskGraph::<i32>::new();
        graph.add_task("root", &[], |_| Ok(0)).expect("add root");
        graph
            .add_task("child", &["root"], |_| Ok(1))
            .expect("add child");

        let deps = graph.dependencies("child").expect("should exist");
        assert_eq!(deps, &["root".to_string()]);
        assert!(graph.dependencies("nonexistent").is_none());
    }
}