1use std::collections::{HashMap, HashSet, VecDeque};
35use std::fmt;
36use std::sync::{Arc, Mutex};
37use std::time::{Duration, Instant};
38
39use crate::error::{CoreError, CoreResult, ErrorContext};
40
/// Opaque identifier for a task within a [`TaskGraph`].
///
/// Wraps the insertion index assigned by the graph; ordering and equality
/// follow creation order.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct TaskId(usize);

impl fmt::Display for TaskId {
    /// Renders as `Task(<index>)`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let TaskId(index) = self;
        write!(f, "Task({index})")
    }
}
54
/// Outcome of a single task execution as recorded in a [`TaskResult`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TaskStatus {
    /// The task ran to completion and produced a value.
    Success,
    /// The task was not run because a dependency did not end in `Success`.
    Skipped,
    /// The task failed; the payload carries a human-readable reason.
    /// NOTE(review): no code path in this file constructs `Failed` —
    /// `TaskNode::execute` always reports `Success`; confirm external users.
    Failed(String),
}
69
/// Record of one task's execution, as returned by the schedulers.
#[derive(Debug, Clone)]
pub struct TaskResult<T: Clone> {
    /// Identifier of the task that produced this result.
    pub task_id: TaskId,
    /// Human-readable task name, copied from the node.
    pub task_name: String,
    /// The computed value; `None` when the task was skipped.
    pub value: Option<T>,
    /// Success / Skipped / Failed classification.
    pub status: TaskStatus,
    /// Wall-clock duration of the task body (`Duration::ZERO` for skips).
    pub elapsed: Duration,
    /// Timestamp taken when execution (or the skip decision) began.
    pub started_at: Instant,
}
90
/// A single unit of work: a name, a closure producing a `T`, and
/// scheduling hints (estimated duration and memory footprint).
pub struct TaskNode<T: Clone + Send + 'static> {
    // Unique id; must be distinct per graph (see `TaskGraph::add_task`).
    id: TaskId,
    // Display name, copied into `TaskResult::task_name`.
    name: String,
    // The work itself; `Send + Sync` so schedulers may invoke it off-thread.
    compute: Box<dyn Fn() -> T + Send + Sync>,
    // Estimated runtime in milliseconds; consumed by `CriticalPath::compute`.
    estimated_ms: u64,
    // Estimated peak memory; consumed by `ResourceConstrainedScheduler`.
    memory_bytes: usize,
}
107
108impl<T: Clone + Send + 'static> TaskNode<T> {
109 pub fn new<F>(id: TaskId, name: impl Into<String>, f: F) -> Self
111 where
112 F: Fn() -> T + Send + Sync + 'static,
113 {
114 Self {
115 id,
116 name: name.into(),
117 compute: Box::new(f),
118 estimated_ms: 1,
119 memory_bytes: 0,
120 }
121 }
122
123 pub fn with_estimated_ms(mut self, ms: u64) -> Self {
125 self.estimated_ms = ms;
126 self
127 }
128
129 pub fn with_memory_bytes(mut self, bytes: usize) -> Self {
131 self.memory_bytes = bytes;
132 self
133 }
134
135 fn execute(&self) -> TaskResult<T> {
137 let started_at = Instant::now();
138 let value = (self.compute)();
139 let elapsed = started_at.elapsed();
140 TaskResult {
141 task_id: self.id,
142 task_name: self.name.clone(),
143 value: Some(value),
144 status: TaskStatus::Success,
145 elapsed,
146 started_at,
147 }
148 }
149}
150
/// Directed acyclic graph of tasks, with both forward and reverse adjacency.
pub struct TaskGraph<T: Clone + Send + 'static> {
    // All nodes, keyed by id.
    nodes: HashMap<TaskId, TaskNode<T>>,
    // deps[a] = set of tasks `a` depends on (must finish before `a` runs).
    deps: HashMap<TaskId, HashSet<TaskId>>,
    // dependents[a] = set of tasks that depend on `a` (reverse edges).
    dependents: HashMap<TaskId, HashSet<TaskId>>,
    // Next index handed out by `add_task`.
    next_id: usize,
}
167
168impl<T: Clone + Send + 'static> TaskGraph<T> {
169 pub fn new() -> Self {
171 Self {
172 nodes: HashMap::new(),
173 deps: HashMap::new(),
174 dependents: HashMap::new(),
175 next_id: 0,
176 }
177 }
178
179 pub fn add_task<F>(&mut self, name: impl Into<String>, f: F) -> TaskId
181 where
182 F: Fn() -> T + Send + Sync + 'static,
183 {
184 let id = TaskId(self.next_id);
185 self.next_id += 1;
186 let node = TaskNode::new(id, name, f);
187 self.nodes.insert(id, node);
188 self.deps.insert(id, HashSet::new());
189 self.dependents.insert(id, HashSet::new());
190 id
191 }
192
193 pub fn add_node(&mut self, node: TaskNode<T>) -> TaskId {
195 let id = node.id;
196 self.nodes.insert(id, node);
197 self.deps.entry(id).or_default();
198 self.dependents.entry(id).or_default();
199 id
200 }
201
202 pub fn add_dependency(&mut self, dependent: TaskId, dependency: TaskId) -> CoreResult<()> {
207 if !self.nodes.contains_key(&dependent) {
208 return Err(CoreError::InvalidInput(ErrorContext::new(format!(
209 "add_dependency: {dependent} not found"
210 ))));
211 }
212 if !self.nodes.contains_key(&dependency) {
213 return Err(CoreError::InvalidInput(ErrorContext::new(format!(
214 "add_dependency: {dependency} not found"
215 ))));
216 }
217 if dependent == dependency {
218 return Err(CoreError::InvalidInput(ErrorContext::new(format!(
219 "add_dependency: self-loop on {dependent}"
220 ))));
221 }
222 if self.is_reachable(dependency, dependent) {
225 return Err(CoreError::InvalidInput(ErrorContext::new(format!(
226 "add_dependency: cycle detected ({dependency} already depends on {dependent})"
227 ))));
228 }
229 self.deps.entry(dependent).or_default().insert(dependency);
230 self.dependents
231 .entry(dependency)
232 .or_default()
233 .insert(dependent);
234 Ok(())
235 }
236
237 fn is_reachable(&self, from: TaskId, target: TaskId) -> bool {
239 let mut visited = HashSet::new();
240 let mut queue = VecDeque::new();
241 queue.push_back(from);
242 while let Some(current) = queue.pop_front() {
243 if current == target {
244 return true;
245 }
246 if visited.contains(¤t) {
247 continue;
248 }
249 visited.insert(current);
250 if let Some(deps) = self.deps.get(¤t) {
251 for dep in deps {
252 if !visited.contains(dep) {
253 queue.push_back(*dep);
254 }
255 }
256 }
257 }
258 false
259 }
260
261 pub fn topological_order(&self) -> CoreResult<Vec<TaskId>> {
264 let mut in_degree: HashMap<TaskId, usize> = self
265 .nodes
266 .keys()
267 .map(|id| (*id, self.deps[id].len()))
268 .collect();
269
270 let mut ready: VecDeque<TaskId> = in_degree
271 .iter()
272 .filter(|(_, °)| deg == 0)
273 .map(|(id, _)| *id)
274 .collect();
275
276 let mut order = Vec::with_capacity(self.nodes.len());
277
278 while let Some(id) = ready.pop_front() {
279 order.push(id);
280 if let Some(children) = self.dependents.get(&id) {
281 for child in children {
282 let deg = in_degree.entry(*child).or_insert(0);
283 if *deg > 0 {
284 *deg -= 1;
285 }
286 if *deg == 0 {
287 ready.push_back(*child);
288 }
289 }
290 }
291 }
292
293 if order.len() != self.nodes.len() {
294 return Err(CoreError::InvalidInput(ErrorContext::new(
295 "topological_order: cycle detected",
296 )));
297 }
298 Ok(order)
299 }
300
301 pub fn len(&self) -> usize {
303 self.nodes.len()
304 }
305
306 pub fn is_empty(&self) -> bool {
308 self.nodes.is_empty()
309 }
310
311 pub fn dependencies(&self, id: TaskId) -> Option<&HashSet<TaskId>> {
313 self.deps.get(&id)
314 }
315
316 pub fn dependents_of(&self, id: TaskId) -> Option<&HashSet<TaskId>> {
318 self.dependents.get(&id)
319 }
320}
321
/// Result of a critical-path (longest estimated chain) analysis.
pub struct CriticalPath {
    /// Task ids along the longest chain, in execution order.
    pub path: Vec<TaskId>,
    /// Sum of `estimated_ms` over the tasks in `path`.
    pub total_estimated_ms: u64,
}
336
337impl CriticalPath {
338 pub fn compute<T: Clone + Send + 'static>(graph: &TaskGraph<T>) -> CoreResult<Self> {
342 let order = graph.topological_order()?;
343
344 let mut earliest_finish: HashMap<TaskId, u64> = HashMap::new();
346 let mut predecessor: HashMap<TaskId, Option<TaskId>> = HashMap::new();
348
349 for &id in &order {
350 let node = &graph.nodes[&id];
351 let max_pred_finish = graph
352 .deps
353 .get(&id)
354 .map(|deps| {
355 deps.iter()
356 .map(|d| earliest_finish.get(d).copied().unwrap_or(0))
357 .max()
358 .unwrap_or(0)
359 })
360 .unwrap_or(0);
361
362 let ef = max_pred_finish + node.estimated_ms;
363 earliest_finish.insert(id, ef);
364
365 let pred = graph.deps.get(&id).and_then(|deps| {
367 deps.iter()
368 .max_by_key(|d| earliest_finish.get(d).copied().unwrap_or(0))
369 .copied()
370 });
371 predecessor.insert(id, pred);
372 }
373
374 let sink = earliest_finish
376 .iter()
377 .max_by_key(|(_, &ef)| ef)
378 .map(|(id, _)| *id);
379
380 let total_ms = sink
381 .and_then(|id| earliest_finish.get(&id).copied())
382 .unwrap_or(0);
383
384 let mut path = Vec::new();
386 let mut current = sink;
387 while let Some(id) = current {
388 path.push(id);
389 current = predecessor.get(&id).and_then(|opt| *opt);
390 }
391 path.reverse();
392
393 Ok(CriticalPath {
394 path,
395 total_estimated_ms: total_ms,
396 })
397 }
398}
399
/// Executes a [`TaskGraph`] in dependency order — serially, or (with the
/// `parallel` feature) in rayon-powered waves of independent tasks.
pub struct TopologicalScheduler<T: Clone + Send + 'static> {
    graph: TaskGraph<T>,
}
413
impl<T: Clone + Send + 'static> TopologicalScheduler<T> {
    /// Wraps a task graph for execution in dependency order.
    pub fn new(graph: TaskGraph<T>) -> Self {
        Self { graph }
    }

    /// Runs every task one at a time in topological order.
    ///
    /// A task whose dependency did not end in `Success` is recorded as
    /// `Skipped` with no value; because `Skipped != Success`, skips cascade
    /// transitively. Results are returned in topological order.
    ///
    /// # Errors
    /// Propagates the cycle error from `topological_order`.
    pub fn run_serial(&self) -> CoreResult<Vec<TaskResult<T>>> {
        let order = self.graph.topological_order()?;
        let mut results: HashMap<TaskId, TaskResult<T>> = HashMap::new();

        for id in &order {
            // A dependency with no recorded result counts as not-failed;
            // in topological order that case should not occur.
            let any_dep_failed = self
                .graph
                .deps
                .get(id)
                .map(|deps| {
                    deps.iter().any(|d| {
                        results
                            .get(d)
                            .map(|r| r.status != TaskStatus::Success)
                            .unwrap_or(false)
                    })
                })
                .unwrap_or(false);

            let node = &self.graph.nodes[id];
            let result = if any_dep_failed {
                // Synthesize a zero-duration Skipped record without running
                // the task body.
                TaskResult {
                    task_id: *id,
                    task_name: node.name.clone(),
                    value: None,
                    status: TaskStatus::Skipped,
                    elapsed: Duration::ZERO,
                    started_at: Instant::now(),
                }
            } else {
                node.execute()
            };
            results.insert(*id, result);
        }

        // Re-emit in topological order (HashMap iteration is unordered).
        Ok(order
            .into_iter()
            .filter_map(|id| results.remove(&id))
            .collect())
    }

    /// Runs the graph in parallel waves when the `parallel` feature is
    /// enabled; otherwise falls back to `run_serial`.
    pub fn run_parallel(&self) -> CoreResult<Vec<TaskResult<T>>> {
        #[cfg(feature = "parallel")]
        {
            self.run_parallel_impl()
        }
        #[cfg(not(feature = "parallel"))]
        {
            self.run_serial()
        }
    }

    /// Wave-based parallel execution: repeatedly collects every remaining
    /// task whose dependencies are all complete and runs that wave via
    /// rayon, merging results into a shared map between waves.
    ///
    /// NOTE(review): results are returned in wave-completion order, not
    /// necessarily the serial topological order — confirm callers don't
    /// rely on ordering.
    #[cfg(feature = "parallel")]
    fn run_parallel_impl(&self) -> CoreResult<Vec<TaskResult<T>>> {
        use rayon::prelude::*;

        let order = self.graph.topological_order()?;
        // Shared across rayon workers; guards the per-task result map.
        let results_map: Arc<Mutex<HashMap<TaskId, TaskResult<T>>>> =
            Arc::new(Mutex::new(HashMap::new()));

        let mut remaining: HashSet<TaskId> = order.iter().cloned().collect();
        let mut all_results: Vec<TaskResult<T>> = Vec::new();

        while !remaining.is_empty() {
            // Snapshot the set of finished tasks under the lock.
            let completed: HashSet<TaskId> = {
                let rm = results_map.lock().map_err(|_| {
                    CoreError::InvalidInput(ErrorContext::new("parallel_run: mutex poisoned"))
                })?;
                rm.keys().cloned().collect()
            };

            // The next wave: every remaining task whose deps are all done.
            let wave: Vec<TaskId> = remaining
                .iter()
                .filter(|id| {
                    self.graph
                        .deps
                        .get(id)
                        .map(|deps| deps.iter().all(|d| completed.contains(d)))
                        .unwrap_or(true)
                })
                .cloned()
                .collect();

            // With a successful topological sort this should be unreachable;
            // guard anyway so a logic error cannot spin forever.
            if wave.is_empty() {
                return Err(CoreError::InvalidInput(ErrorContext::new(
                    "parallel_run: deadlock — no runnable tasks remain",
                )));
            }

            let wave_results: Vec<TaskResult<T>> = wave
                .par_iter()
                .map(|id| {
                    // Re-check dependency status under the lock; tasks in the
                    // same wave are independent of each other by construction.
                    let any_dep_failed = self
                        .graph
                        .deps
                        .get(id)
                        .map(|deps| {
                            let rm = results_map.lock().ok();
                            deps.iter().any(|d| {
                                rm.as_ref()
                                    .and_then(|r| r.get(d))
                                    .map(|r| r.status != TaskStatus::Success)
                                    .unwrap_or(false)
                            })
                        })
                        .unwrap_or(false);

                    let node = &self.graph.nodes[id];
                    if any_dep_failed {
                        TaskResult {
                            task_id: *id,
                            task_name: node.name.clone(),
                            value: None,
                            status: TaskStatus::Skipped,
                            elapsed: Duration::ZERO,
                            started_at: Instant::now(),
                        }
                    } else {
                        node.execute()
                    }
                })
                .collect();

            // Merge this wave's results so the next wave sees them.
            {
                let mut rm = results_map.lock().map_err(|_| {
                    CoreError::InvalidInput(ErrorContext::new(
                        "parallel_run: mutex poisoned (merge)",
                    ))
                })?;
                for r in &wave_results {
                    rm.insert(r.task_id, r.clone());
                }
            }

            for id in &wave {
                remaining.remove(id);
            }
            all_results.extend(wave_results);
        }

        Ok(all_results)
    }

    /// Consumes the scheduler, returning the underlying graph.
    pub fn into_graph(self) -> TaskGraph<T> {
        self.graph
    }
}
583
/// Limits on how much work a scheduler may have in flight.
#[derive(Debug, Clone)]
pub struct ResourceConstraints {
    /// Maximum number of tasks allowed to run at the same time.
    pub max_concurrent: usize,
    /// Ceiling on the combined `memory_bytes` of in-flight tasks.
    pub max_memory_bytes: usize,
}

impl Default for ResourceConstraints {
    /// Defaults to 4 concurrent tasks and a 1 GiB memory budget.
    fn default() -> Self {
        ResourceConstraints {
            max_concurrent: 4,
            // 1 << 30 bytes == 1 GiB.
            max_memory_bytes: 1 << 30,
        }
    }
}
605
/// Scheduler that honours [`ResourceConstraints`] while executing a graph.
pub struct ResourceConstrainedScheduler<T: Clone + Send + 'static> {
    graph: TaskGraph<T>,
    constraints: ResourceConstraints,
}
614
615impl<T: Clone + Send + 'static> ResourceConstrainedScheduler<T> {
616 pub fn new(graph: TaskGraph<T>, constraints: ResourceConstraints) -> Self {
618 Self { graph, constraints }
619 }
620
621 pub fn run(&self) -> CoreResult<Vec<TaskResult<T>>> {
628 let order = self.graph.topological_order()?;
629 let mut completed: HashSet<TaskId> = HashSet::new();
630 let mut results: Vec<TaskResult<T>> = Vec::new();
631 let mut remaining: Vec<TaskId> = order;
632 let mut in_flight_memory: usize = 0;
633
634 loop {
635 let ready_idx = remaining.iter().position(|id| {
637 let deps_done = self
638 .graph
639 .deps
640 .get(id)
641 .map(|deps| deps.iter().all(|d| completed.contains(d)))
642 .unwrap_or(true);
643 if !deps_done {
644 return false;
645 }
646 let mem = self
647 .graph
648 .nodes
649 .get(id)
650 .map(|n| n.memory_bytes)
651 .unwrap_or(0);
652 in_flight_memory + mem <= self.constraints.max_memory_bytes
653 });
654
655 match ready_idx {
656 None => {
657 if remaining.is_empty() {
658 break;
659 }
660 let fallback = remaining.iter().position(|id| {
663 self.graph
664 .deps
665 .get(id)
666 .map(|deps| deps.iter().all(|d| completed.contains(d)))
667 .unwrap_or(true)
668 });
669 match fallback {
670 None => break, Some(idx) => {
672 let id = remaining.remove(idx);
673 let node = &self.graph.nodes[&id];
674 let mem = node.memory_bytes;
675 in_flight_memory = in_flight_memory.saturating_add(mem);
676 let r = node.execute();
677 in_flight_memory = in_flight_memory.saturating_sub(mem);
678 completed.insert(id);
679 results.push(r);
680 }
681 }
682 }
683 Some(idx) => {
684 let id = remaining.remove(idx);
685 let node = &self.graph.nodes[&id];
686 let mem = node.memory_bytes;
687 in_flight_memory = in_flight_memory.saturating_add(mem);
688 let r = node.execute();
689 in_flight_memory = in_flight_memory.saturating_sub(mem);
690 completed.insert(id);
691 results.push(r);
692 }
693 }
694 }
695
696 Ok(results)
697 }
698}
699
700pub mod dependency_graph;
702
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a three-task chain: `b` depends on `a`, `c` depends on `b`.
    fn build_linear_graph() -> TaskGraph<u64> {
        let mut g = TaskGraph::new();
        let t1 = g.add_task("a", || 1u64);
        let t2 = g.add_task("b", || 2u64);
        let t3 = g.add_task("c", || 3u64);
        g.add_dependency(t2, t1).expect("dep b→a");
        g.add_dependency(t3, t2).expect("dep c→b");
        g
    }

    #[test]
    fn topological_order_linear() {
        let g = build_linear_graph();
        let order = g.topological_order().expect("acyclic");
        // A chain has exactly one valid order: a, b, c.
        assert_eq!(order, vec![TaskId(0), TaskId(1), TaskId(2)]);
    }

    #[test]
    fn cycle_detection() {
        let mut g: TaskGraph<u64> = TaskGraph::new();
        let a = g.add_task("a", || 0u64);
        let b = g.add_task("b", || 0u64);
        g.add_dependency(b, a).expect("b→a");
        assert!(g.add_dependency(a, b).is_err(), "cycle should be rejected");
        assert!(g.add_dependency(a, a).is_err(), "self-loop should be rejected");
    }

    #[test]
    fn topological_scheduler_serial() {
        let g = build_linear_graph();
        let sched = TopologicalScheduler::new(g);
        let results = sched.run_serial().expect("serial run");
        assert_eq!(results.len(), 3);
        assert!(results.iter().all(|r| r.status == TaskStatus::Success));
        let names: Vec<&str> = results.iter().map(|r| r.task_name.as_str()).collect();
        assert_eq!(names, vec!["a", "b", "c"]);
        let values: Vec<u64> = results.iter().filter_map(|r| r.value).collect();
        assert_eq!(values, vec![1, 2, 3]);
    }

    #[test]
    fn topological_scheduler_parallel() {
        let g = build_linear_graph();
        let sched = TopologicalScheduler::new(g);
        let results = sched.run_parallel().expect("parallel run");
        assert_eq!(results.len(), 3);
        assert!(results.iter().all(|r| r.status == TaskStatus::Success));
    }

    #[test]
    fn critical_path_linear() {
        // Build via `add_node` + `add_dependency` instead of reaching into
        // private fields by hand; `add_node` accepts pre-built nodes with
        // duration estimates attached.
        let mut g: TaskGraph<u64> = TaskGraph::new();
        let t1 = g.add_node(TaskNode::new(TaskId(0), "a", || 0u64).with_estimated_ms(10));
        let t2 = g.add_node(TaskNode::new(TaskId(1), "b", || 0u64).with_estimated_ms(20));
        let t3 = g.add_node(TaskNode::new(TaskId(2), "c", || 0u64).with_estimated_ms(15));
        g.add_dependency(t2, t1).expect("b→a");
        g.add_dependency(t3, t2).expect("c→b");

        let cp = CriticalPath::compute(&g).expect("critical path");
        assert_eq!(cp.total_estimated_ms, 45, "10 + 20 + 15 = 45");
        // The chain itself is the critical path, in execution order.
        assert_eq!(cp.path, vec![t1, t2, t3]);
    }

    #[test]
    fn resource_constrained_scheduler_basic() {
        let mut g: TaskGraph<u64> = TaskGraph::new();
        g.add_task("a", || 1u64);
        g.add_task("b", || 2u64);
        g.add_task("c", || 3u64);

        let sched = ResourceConstrainedScheduler::new(
            g,
            ResourceConstraints {
                max_concurrent: 2,
                max_memory_bytes: 1024,
            },
        );
        let results = sched.run().expect("constrained run");
        assert_eq!(results.len(), 3);
        assert!(results.iter().all(|r| r.status == TaskStatus::Success));
    }

    #[test]
    fn skip_on_dep_failure() {
        let mut g: TaskGraph<Result<u64, String>> = TaskGraph::new();
        let a = g.add_task("fail", || Err::<u64, _>("error".to_string()));
        let b = g.add_task("skip_me", || Ok::<u64, _>(42));
        g.add_dependency(b, a).expect("b→a");

        let sched = TopologicalScheduler::new(g);
        let results = sched.run_serial().expect("run");
        assert_eq!(results.len(), 2);
        // The scheduler does not inspect a task's `Result` payload: the
        // `Err` value is still a `Success` status, so `skip_me` runs too.
        assert!(results.iter().all(|r| r.status == TaskStatus::Success));
    }
}