1use async_trait::async_trait;
2use std::collections::HashMap;
3use std::sync::Arc;
4use taskflow_rs::{
5 Task, TaskDefinition, TaskFlow, TaskResult, error::TaskFlowError, framework::TaskFlowConfig,
6 task::TaskHandler,
7};
8use tracing_subscriber::fmt::init;
9
10struct MathTaskHandler;
11
12#[async_trait]
13impl TaskHandler for MathTaskHandler {
14 async fn execute(&self, task: &Task) -> Result<TaskResult, TaskFlowError> {
15 let start_time = std::time::Instant::now();
16
17 let operation = task
18 .definition
19 .payload
20 .get("operation")
21 .and_then(|v| v.as_str())
22 .ok_or_else(|| {
23 TaskFlowError::InvalidConfiguration("Missing 'operation' in payload".to_string())
24 })?;
25
26 let a = task
27 .definition
28 .payload
29 .get("a")
30 .and_then(|v| v.as_f64())
31 .ok_or_else(|| {
32 TaskFlowError::InvalidConfiguration("Missing 'a' in payload".to_string())
33 })?;
34
35 let b = task
36 .definition
37 .payload
38 .get("b")
39 .and_then(|v| v.as_f64())
40 .ok_or_else(|| {
41 TaskFlowError::InvalidConfiguration("Missing 'b' in payload".to_string())
42 })?;
43
44 let result = match operation {
45 "add" => a + b,
46 "subtract" => a - b,
47 "multiply" => a * b,
48 "divide" => {
49 if b == 0.0 {
50 return Ok(TaskResult {
51 success: false,
52 output: None,
53 error: Some("Division by zero".to_string()),
54 execution_time_ms: start_time.elapsed().as_millis() as u64,
55 metadata: HashMap::new(),
56 });
57 }
58 a / b
59 }
60 _ => {
61 return Ok(TaskResult {
62 success: false,
63 output: None,
64 error: Some(format!("Unknown operation: {}", operation)),
65 execution_time_ms: start_time.elapsed().as_millis() as u64,
66 metadata: HashMap::new(),
67 });
68 }
69 };
70
71 let execution_time = start_time.elapsed().as_millis() as u64;
72 let mut metadata = HashMap::new();
73 metadata.insert("operation".to_string(), operation.to_string());
74 metadata.insert("operand_a".to_string(), a.to_string());
75 metadata.insert("operand_b".to_string(), b.to_string());
76
77 Ok(TaskResult {
78 success: true,
79 output: Some(result.to_string()),
80 error: None,
81 execution_time_ms: execution_time,
82 metadata,
83 })
84 }
85
86 fn task_type(&self) -> &str {
87 "math_operation"
88 }
89}
90
91struct DataProcessingHandler;
92
93#[async_trait]
94impl TaskHandler for DataProcessingHandler {
95 async fn execute(&self, task: &Task) -> Result<TaskResult, TaskFlowError> {
96 let start_time = std::time::Instant::now();
97
98 let data = task
99 .definition
100 .payload
101 .get("data")
102 .and_then(|v| v.as_array())
103 .ok_or_else(|| {
104 TaskFlowError::InvalidConfiguration("Missing 'data' array in payload".to_string())
105 })?;
106
107 let operation = task
108 .definition
109 .payload
110 .get("operation")
111 .and_then(|v| v.as_str())
112 .unwrap_or("sum");
113
114 let numbers: Result<Vec<f64>, _> = data
115 .iter()
116 .map(|v| {
117 v.as_f64().ok_or_else(|| {
118 TaskFlowError::InvalidConfiguration(
119 "All data elements must be numbers".to_string(),
120 )
121 })
122 })
123 .collect();
124
125 let numbers = numbers?;
126
127 let result = match operation {
128 "sum" => numbers.iter().sum::<f64>(),
129 "average" => {
130 if numbers.is_empty() {
131 0.0
132 } else {
133 numbers.iter().sum::<f64>() / numbers.len() as f64
134 }
135 }
136 "max" => numbers.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b)),
137 "min" => numbers.iter().fold(f64::INFINITY, |a, &b| a.min(b)),
138 _ => {
139 return Ok(TaskResult {
140 success: false,
141 output: None,
142 error: Some(format!("Unknown operation: {}", operation)),
143 execution_time_ms: start_time.elapsed().as_millis() as u64,
144 metadata: HashMap::new(),
145 });
146 }
147 };
148
149 let execution_time = start_time.elapsed().as_millis() as u64;
150 let mut metadata = HashMap::new();
151 metadata.insert("operation".to_string(), operation.to_string());
152 metadata.insert("data_count".to_string(), numbers.len().to_string());
153
154 Ok(TaskResult {
155 success: true,
156 output: Some(result.to_string()),
157 error: None,
158 execution_time_ms: execution_time,
159 metadata,
160 })
161 }
162
163 fn task_type(&self) -> &str {
164 "data_processing"
165 }
166}
167
168#[tokio::main]
169async fn main() -> Result<(), Box<dyn std::error::Error>> {
170 init();
171
172 let config = TaskFlowConfig::with_in_memory();
173 let taskflow = TaskFlow::new(config).await?;
174
175 taskflow.register_handler(Arc::new(MathTaskHandler)).await;
176 taskflow
177 .register_handler(Arc::new(DataProcessingHandler))
178 .await;
179
180 println!("TaskFlow with custom handlers started!");
181
182 let add_task = TaskDefinition::new("addition", "math_operation")
183 .with_payload("operation", serde_json::Value::String("add".to_string()))
184 .with_payload("a", serde_json::Value::Number(serde_json::Number::from(10)))
185 .with_payload("b", serde_json::Value::Number(serde_json::Number::from(5)));
186
187 let add_task_id = taskflow.submit_task(add_task).await?;
188 println!("Submitted addition task: {}", add_task_id);
189
190 let multiply_task = TaskDefinition::new("multiplication", "math_operation")
191 .with_payload(
192 "operation",
193 serde_json::Value::String("multiply".to_string()),
194 )
195 .with_payload("a", serde_json::Value::Number(serde_json::Number::from(7)))
196 .with_payload("b", serde_json::Value::Number(serde_json::Number::from(3)));
197
198 let multiply_task_id = taskflow.submit_task(multiply_task).await?;
199 println!("Submitted multiplication task: {}", multiply_task_id);
200
201 let data_array = vec![
202 serde_json::Value::Number(serde_json::Number::from(1)),
203 serde_json::Value::Number(serde_json::Number::from(2)),
204 serde_json::Value::Number(serde_json::Number::from(3)),
205 serde_json::Value::Number(serde_json::Number::from(4)),
206 serde_json::Value::Number(serde_json::Number::from(5)),
207 ];
208
209 let sum_task = TaskDefinition::new("sum_data", "data_processing")
210 .with_payload("operation", serde_json::Value::String("sum".to_string()))
211 .with_payload("data", serde_json::Value::Array(data_array.clone()));
212
213 let sum_task_id = taskflow.submit_task(sum_task).await?;
214 println!("Submitted sum task: {}", sum_task_id);
215
216 let avg_task = TaskDefinition::new("average_data", "data_processing")
217 .with_payload(
218 "operation",
219 serde_json::Value::String("average".to_string()),
220 )
221 .with_payload("data", serde_json::Value::Array(data_array))
222 .with_dependencies(vec![sum_task_id.clone()]);
223
224 let avg_task_id = taskflow.submit_task(avg_task).await?;
225 println!("Submitted average task (depends on sum): {}", avg_task_id);
226
227 let taskflow_clone = std::sync::Arc::new(taskflow);
228 let taskflow_for_execution = taskflow_clone.clone();
229
230 let execution_handle = tokio::spawn(async move {
231 if let Err(e) = taskflow_for_execution.start().await {
232 eprintln!("TaskFlow execution failed: {}", e);
233 }
234 });
235
236 tokio::time::sleep(std::time::Duration::from_secs(1)).await;
237
238 loop {
239 let metrics = taskflow_clone.get_task_metrics().await?;
240 println!(
241 "Task metrics: pending={}, running={}, completed={}, failed={}",
242 metrics.pending, metrics.running, metrics.completed, metrics.failed
243 );
244
245 if metrics.pending == 0 && metrics.running == 0 {
246 break;
247 }
248
249 tokio::time::sleep(std::time::Duration::from_secs(1)).await;
250 }
251
252 println!("\nAll tasks completed! Results:");
253
254 let tasks = taskflow_clone.list_tasks(None).await?;
255 for task in tasks {
256 println!(
257 "\nTask: {} ({})",
258 task.definition.name, task.definition.task_type
259 );
260 println!(" Status: {:?}", task.status);
261 if let Some(result) = &task.result {
262 if result.success {
263 println!(
264 " Result: {}",
265 result.output.as_ref().unwrap_or(&"No output".to_string())
266 );
267 println!(" Execution time: {}ms", result.execution_time_ms);
268 if !result.metadata.is_empty() {
269 println!(" Metadata: {:?}", result.metadata);
270 }
271 } else {
272 println!(
273 " Error: {}",
274 result
275 .error
276 .as_ref()
277 .unwrap_or(&"Unknown error".to_string())
278 );
279 }
280 }
281 }
282
283 let final_metrics = taskflow_clone.get_task_metrics().await?;
284 println!("\nFinal metrics:");
285 println!(" Total tasks: {}", final_metrics.total);
286 println!(
287 " Success rate: {:.1}%",
288 final_metrics.success_rate() * 100.0
289 );
290
291 taskflow_clone.shutdown().await?;
292 execution_handle.abort();
293
294 Ok(())
295}