taskflow_rs/scheduler/
dependency_resolver.rs

1use crate::error::{Result, TaskFlowError};
2use crate::storage::TaskStorage;
3use crate::task::{Task, TaskStatus};
4use std::collections::HashSet;
5use std::pin::Pin;
6use std::sync::Arc;
7
8pub struct DependencyResolver {
9    storage: Arc<dyn TaskStorage>,
10}
11
12impl DependencyResolver {
13    pub fn new(storage: Arc<dyn TaskStorage>) -> Self {
14        Self { storage }
15    }
16
17    pub async fn validate_dependencies(&self, task: &Task) -> Result<()> {
18        for dep_id in &task.definition.dependencies {
19            if self.storage.get_task(dep_id).await?.is_none() {
20                return Err(TaskFlowError::InvalidConfiguration(format!(
21                    "Dependency task not found: {}",
22                    dep_id
23                )));
24            }
25        }
26
27        if self
28            .has_circular_dependency(task, &mut HashSet::new())
29            .await?
30        {
31            return Err(TaskFlowError::InvalidConfiguration(
32                "Circular dependency detected".to_string(),
33            ));
34        }
35
36        Ok(())
37    }
38
39    pub async fn are_dependencies_satisfied(&self, task: &Task) -> Result<bool> {
40        for dep_id in &task.definition.dependencies {
41            if let Some(dep_task) = self.storage.get_task(dep_id).await? {
42                if !matches!(dep_task.status, TaskStatus::Completed) {
43                    return Ok(false);
44                }
45            } else {
46                return Err(TaskFlowError::TaskNotFound(dep_id.clone()));
47            }
48        }
49        Ok(true)
50    }
51
52    fn has_circular_dependency<'a>(
53        &'a self,
54        task: &'a Task,
55        visited: &'a mut HashSet<String>,
56    ) -> Pin<Box<dyn std::future::Future<Output = Result<bool>> + Send + 'a>> {
57        Box::pin(async move {
58            if visited.contains(&task.definition.id) {
59                return Ok(true);
60            }
61
62            visited.insert(task.definition.id.clone());
63
64            for dep_id in &task.definition.dependencies {
65                if let Some(dep_task) = self.storage.get_task(dep_id).await? {
66                    if self.has_circular_dependency(&dep_task, visited).await? {
67                        return Ok(true);
68                    }
69                }
70            }
71
72            visited.remove(&task.definition.id);
73            Ok(false)
74        })
75    }
76}