taskflow_rs/framework/
mod.rs

1use crate::error::{Result, TaskFlowError};
2use crate::executor::{
3    Executor,
4    handlers::{
5        FileTaskHandler, HttpTaskHandler, NotificationTaskHandler, PythonTaskHandler,
6        ShellTaskHandler,
7    },
8};
9use crate::scheduler::Scheduler;
10use crate::storage::{InMemoryStorage, TaskStorage};
11use crate::task::{TaskDefinition, TaskHandler, TaskStatus};
12use std::path::Path;
13use std::sync::Arc;
14use tracing::{error, info};
15
16pub mod config;
17pub mod metrics;
18pub mod yaml_config;
19
20pub use config::{StorageType, TaskFlowConfig};
21pub use metrics::TaskMetrics;
22pub use yaml_config::{load_from_yaml_file, load_from_yaml_str};
23
24pub struct TaskFlow {
25    scheduler: Arc<Scheduler>,
26    executor: Arc<tokio::sync::Mutex<Executor>>,
27    storage: Arc<dyn TaskStorage>,
28}
29
30impl TaskFlow {
31    pub async fn new(config: TaskFlowConfig) -> Result<Self> {
32        let storage: Arc<dyn TaskStorage> = match config.storage_type {
33            StorageType::InMemory => {
34                info!("Using in-memory storage");
35                Arc::new(InMemoryStorage::new())
36            }
37        };
38
39        let scheduler = Arc::new(Scheduler::new(Arc::clone(&storage), config.scheduler));
40        let mut executor = Executor::new(Arc::clone(&scheduler), config.executor);
41
42        executor.register_handler(Arc::new(HttpTaskHandler::new()));
43        executor.register_handler(Arc::new(ShellTaskHandler::new()));
44        executor.register_handler(Arc::new(PythonTaskHandler::new()));
45        executor.register_handler(Arc::new(FileTaskHandler::new()));
46        executor.register_handler(Arc::new(NotificationTaskHandler::new()));
47
48        Ok(Self {
49            scheduler,
50            executor: Arc::new(tokio::sync::Mutex::new(executor)),
51            storage,
52        })
53    }
54
55    pub async fn from_yaml_file<P: AsRef<Path>>(path: P) -> Result<Self> {
56        let config = load_from_yaml_file(path)?;
57        Self::new(config).await
58    }
59
60    pub async fn from_yaml_str(config_content: &str) -> Result<Self> {
61        let config = load_from_yaml_str(config_content)?;
62        Self::new(config).await
63    }
64
65    pub async fn register_handler(&self, handler: Arc<dyn TaskHandler>) {
66        let mut executor = self.executor.lock().await;
67        executor.register_handler(handler);
68    }
69
70    pub async fn submit_task(&self, definition: TaskDefinition) -> Result<String> {
71        self.scheduler.submit_task(definition).await
72    }
73
74    pub async fn submit_http_task(
75        &self,
76        name: &str,
77        url: &str,
78        method: Option<&str>,
79    ) -> Result<String> {
80        let mut definition = TaskDefinition::new(name, "http_request")
81            .with_payload("url", serde_json::Value::String(url.to_string()));
82
83        if let Some(method) = method {
84            definition =
85                definition.with_payload("method", serde_json::Value::String(method.to_string()));
86        }
87
88        self.submit_task(definition).await
89    }
90
91    pub async fn submit_shell_task(
92        &self,
93        name: &str,
94        command: &str,
95        args: Vec<&str>,
96    ) -> Result<String> {
97        let args_json: Vec<serde_json::Value> = args
98            .into_iter()
99            .map(|arg| serde_json::Value::String(arg.to_string()))
100            .collect();
101
102        let definition = TaskDefinition::new(name, "shell_command")
103            .with_payload("command", serde_json::Value::String(command.to_string()))
104            .with_payload("args", serde_json::Value::Array(args_json));
105
106        self.submit_task(definition).await
107    }
108
109    pub async fn get_task_status(&self, task_id: &str) -> Result<Option<TaskStatus>> {
110        self.scheduler.get_task_status(task_id).await
111    }
112
113    pub async fn cancel_task(&self, task_id: &str) -> Result<()> {
114        self.scheduler.cancel_task(task_id).await
115    }
116
117    pub async fn list_tasks(&self, status: Option<TaskStatus>) -> Result<Vec<crate::task::Task>> {
118        self.scheduler.list_tasks(status).await
119    }
120
121    pub async fn start(&self) -> Result<()> {
122        info!("Starting TaskFlow framework");
123
124        let scheduler_clone = Arc::clone(&self.scheduler);
125        let scheduler_task = tokio::spawn(async move {
126            if let Err(e) = scheduler_clone.start().await {
127                error!("Scheduler failed: {}", e);
128            }
129        });
130
131        let executor_clone = Arc::clone(&self.executor);
132        let executor_task = tokio::spawn(async move {
133            let executor = executor_clone.lock().await;
134            if let Err(e) = executor.start().await {
135                error!("Executor failed: {}", e);
136            }
137        });
138
139        tokio::select! {
140            result = scheduler_task => {
141                if let Err(e) = result {
142                    error!("Scheduler task panicked: {}", e);
143                }
144            }
145            result = executor_task => {
146                if let Err(e) = result {
147                    error!("Executor task panicked: {}", e);
148                }
149            }
150        }
151
152        Ok(())
153    }
154
155    pub async fn shutdown(&self) -> Result<()> {
156        info!("Shutting down TaskFlow framework");
157
158        let executor = self.executor.lock().await;
159        executor.shutdown().await;
160
161        Ok(())
162    }
163
164    pub async fn wait_for_completion(
165        &self,
166        task_id: &str,
167        timeout_seconds: Option<u64>,
168    ) -> Result<crate::task::Task> {
169        let timeout_duration = timeout_seconds.unwrap_or(300);
170        let start_time = std::time::Instant::now();
171
172        loop {
173            if let Some(task) = self.storage.get_task(task_id).await? {
174                if task.is_finished() {
175                    return Ok(task);
176                }
177            } else {
178                return Err(TaskFlowError::TaskNotFound(task_id.to_string()));
179            }
180
181            if start_time.elapsed().as_secs() > timeout_duration {
182                return Err(TaskFlowError::TimeoutError);
183            }
184
185            tokio::time::sleep(tokio::time::Duration::from_millis(500)).await;
186        }
187    }
188
189    pub async fn get_task_metrics(&self) -> Result<TaskMetrics> {
190        let pending = self.list_tasks(Some(TaskStatus::Pending)).await?.len();
191        let running = self.list_tasks(Some(TaskStatus::Running)).await?.len();
192        let completed = self.list_tasks(Some(TaskStatus::Completed)).await?.len();
193        let failed = self.list_tasks(Some(TaskStatus::Failed)).await?.len();
194        let cancelled = self.list_tasks(Some(TaskStatus::Cancelled)).await?.len();
195
196        Ok(TaskMetrics {
197            pending,
198            running,
199            completed,
200            failed,
201            cancelled,
202            total: pending + running + completed + failed + cancelled,
203        })
204    }
205}