taskflow_rs/scheduler/
dependency_resolver.rs1use crate::error::{Result, TaskFlowError};
2use crate::storage::TaskStorage;
3use crate::task::{Task, TaskStatus};
4use std::collections::HashSet;
5use std::pin::Pin;
6use std::sync::Arc;
7
8pub struct DependencyResolver {
9 storage: Arc<dyn TaskStorage>,
10}
11
12impl DependencyResolver {
13 pub fn new(storage: Arc<dyn TaskStorage>) -> Self {
14 Self { storage }
15 }
16
17 pub async fn validate_dependencies(&self, task: &Task) -> Result<()> {
18 for dep_id in &task.definition.dependencies {
19 if self.storage.get_task(dep_id).await?.is_none() {
20 return Err(TaskFlowError::InvalidConfiguration(format!(
21 "Dependency task not found: {}",
22 dep_id
23 )));
24 }
25 }
26
27 if self
28 .has_circular_dependency(task, &mut HashSet::new())
29 .await?
30 {
31 return Err(TaskFlowError::InvalidConfiguration(
32 "Circular dependency detected".to_string(),
33 ));
34 }
35
36 Ok(())
37 }
38
39 pub async fn are_dependencies_satisfied(&self, task: &Task) -> Result<bool> {
40 for dep_id in &task.definition.dependencies {
41 if let Some(dep_task) = self.storage.get_task(dep_id).await? {
42 if !matches!(dep_task.status, TaskStatus::Completed) {
43 return Ok(false);
44 }
45 } else {
46 return Err(TaskFlowError::TaskNotFound(dep_id.clone()));
47 }
48 }
49 Ok(true)
50 }
51
52 fn has_circular_dependency<'a>(
53 &'a self,
54 task: &'a Task,
55 visited: &'a mut HashSet<String>,
56 ) -> Pin<Box<dyn std::future::Future<Output = Result<bool>> + Send + 'a>> {
57 Box::pin(async move {
58 if visited.contains(&task.definition.id) {
59 return Ok(true);
60 }
61
62 visited.insert(task.definition.id.clone());
63
64 for dep_id in &task.definition.dependencies {
65 if let Some(dep_task) = self.storage.get_task(dep_id).await? {
66 if self.has_circular_dependency(&dep_task, visited).await? {
67 return Ok(true);
68 }
69 }
70 }
71
72 visited.remove(&task.definition.id);
73 Ok(false)
74 })
75 }
76}