custom_handler/
custom_handler.rs

1use async_trait::async_trait;
2use std::collections::HashMap;
3use std::sync::Arc;
4use taskflow_rs::{
5    Task, TaskDefinition, TaskFlow, TaskResult, error::TaskFlowError, framework::TaskFlowConfig,
6    task::TaskHandler,
7};
8use tracing_subscriber::fmt::init;
9
10struct MathTaskHandler;
11
12#[async_trait]
13impl TaskHandler for MathTaskHandler {
14    async fn execute(&self, task: &Task) -> Result<TaskResult, TaskFlowError> {
15        let start_time = std::time::Instant::now();
16
17        let operation = task
18            .definition
19            .payload
20            .get("operation")
21            .and_then(|v| v.as_str())
22            .ok_or_else(|| {
23                TaskFlowError::InvalidConfiguration("Missing 'operation' in payload".to_string())
24            })?;
25
26        let a = task
27            .definition
28            .payload
29            .get("a")
30            .and_then(|v| v.as_f64())
31            .ok_or_else(|| {
32                TaskFlowError::InvalidConfiguration("Missing 'a' in payload".to_string())
33            })?;
34
35        let b = task
36            .definition
37            .payload
38            .get("b")
39            .and_then(|v| v.as_f64())
40            .ok_or_else(|| {
41                TaskFlowError::InvalidConfiguration("Missing 'b' in payload".to_string())
42            })?;
43
44        let result = match operation {
45            "add" => a + b,
46            "subtract" => a - b,
47            "multiply" => a * b,
48            "divide" => {
49                if b == 0.0 {
50                    return Ok(TaskResult {
51                        success: false,
52                        output: None,
53                        error: Some("Division by zero".to_string()),
54                        execution_time_ms: start_time.elapsed().as_millis() as u64,
55                        metadata: HashMap::new(),
56                    });
57                }
58                a / b
59            }
60            _ => {
61                return Ok(TaskResult {
62                    success: false,
63                    output: None,
64                    error: Some(format!("Unknown operation: {}", operation)),
65                    execution_time_ms: start_time.elapsed().as_millis() as u64,
66                    metadata: HashMap::new(),
67                });
68            }
69        };
70
71        let execution_time = start_time.elapsed().as_millis() as u64;
72        let mut metadata = HashMap::new();
73        metadata.insert("operation".to_string(), operation.to_string());
74        metadata.insert("operand_a".to_string(), a.to_string());
75        metadata.insert("operand_b".to_string(), b.to_string());
76
77        Ok(TaskResult {
78            success: true,
79            output: Some(result.to_string()),
80            error: None,
81            execution_time_ms: execution_time,
82            metadata,
83        })
84    }
85
86    fn task_type(&self) -> &str {
87        "math_operation"
88    }
89}
90
91struct DataProcessingHandler;
92
93#[async_trait]
94impl TaskHandler for DataProcessingHandler {
95    async fn execute(&self, task: &Task) -> Result<TaskResult, TaskFlowError> {
96        let start_time = std::time::Instant::now();
97
98        let data = task
99            .definition
100            .payload
101            .get("data")
102            .and_then(|v| v.as_array())
103            .ok_or_else(|| {
104                TaskFlowError::InvalidConfiguration("Missing 'data' array in payload".to_string())
105            })?;
106
107        let operation = task
108            .definition
109            .payload
110            .get("operation")
111            .and_then(|v| v.as_str())
112            .unwrap_or("sum");
113
114        let numbers: Result<Vec<f64>, _> = data
115            .iter()
116            .map(|v| {
117                v.as_f64().ok_or_else(|| {
118                    TaskFlowError::InvalidConfiguration(
119                        "All data elements must be numbers".to_string(),
120                    )
121                })
122            })
123            .collect();
124
125        let numbers = numbers?;
126
127        let result = match operation {
128            "sum" => numbers.iter().sum::<f64>(),
129            "average" => {
130                if numbers.is_empty() {
131                    0.0
132                } else {
133                    numbers.iter().sum::<f64>() / numbers.len() as f64
134                }
135            }
136            "max" => numbers.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b)),
137            "min" => numbers.iter().fold(f64::INFINITY, |a, &b| a.min(b)),
138            _ => {
139                return Ok(TaskResult {
140                    success: false,
141                    output: None,
142                    error: Some(format!("Unknown operation: {}", operation)),
143                    execution_time_ms: start_time.elapsed().as_millis() as u64,
144                    metadata: HashMap::new(),
145                });
146            }
147        };
148
149        let execution_time = start_time.elapsed().as_millis() as u64;
150        let mut metadata = HashMap::new();
151        metadata.insert("operation".to_string(), operation.to_string());
152        metadata.insert("data_count".to_string(), numbers.len().to_string());
153
154        Ok(TaskResult {
155            success: true,
156            output: Some(result.to_string()),
157            error: None,
158            execution_time_ms: execution_time,
159            metadata,
160        })
161    }
162
163    fn task_type(&self) -> &str {
164        "data_processing"
165    }
166}
167
168#[tokio::main]
169async fn main() -> Result<(), Box<dyn std::error::Error>> {
170    init();
171
172    let config = TaskFlowConfig::with_in_memory();
173    let taskflow = TaskFlow::new(config).await?;
174
175    taskflow.register_handler(Arc::new(MathTaskHandler)).await;
176    taskflow
177        .register_handler(Arc::new(DataProcessingHandler))
178        .await;
179
180    println!("TaskFlow with custom handlers started!");
181
182    let add_task = TaskDefinition::new("addition", "math_operation")
183        .with_payload("operation", serde_json::Value::String("add".to_string()))
184        .with_payload("a", serde_json::Value::Number(serde_json::Number::from(10)))
185        .with_payload("b", serde_json::Value::Number(serde_json::Number::from(5)));
186
187    let add_task_id = taskflow.submit_task(add_task).await?;
188    println!("Submitted addition task: {}", add_task_id);
189
190    let multiply_task = TaskDefinition::new("multiplication", "math_operation")
191        .with_payload(
192            "operation",
193            serde_json::Value::String("multiply".to_string()),
194        )
195        .with_payload("a", serde_json::Value::Number(serde_json::Number::from(7)))
196        .with_payload("b", serde_json::Value::Number(serde_json::Number::from(3)));
197
198    let multiply_task_id = taskflow.submit_task(multiply_task).await?;
199    println!("Submitted multiplication task: {}", multiply_task_id);
200
201    let data_array = vec![
202        serde_json::Value::Number(serde_json::Number::from(1)),
203        serde_json::Value::Number(serde_json::Number::from(2)),
204        serde_json::Value::Number(serde_json::Number::from(3)),
205        serde_json::Value::Number(serde_json::Number::from(4)),
206        serde_json::Value::Number(serde_json::Number::from(5)),
207    ];
208
209    let sum_task = TaskDefinition::new("sum_data", "data_processing")
210        .with_payload("operation", serde_json::Value::String("sum".to_string()))
211        .with_payload("data", serde_json::Value::Array(data_array.clone()));
212
213    let sum_task_id = taskflow.submit_task(sum_task).await?;
214    println!("Submitted sum task: {}", sum_task_id);
215
216    let avg_task = TaskDefinition::new("average_data", "data_processing")
217        .with_payload(
218            "operation",
219            serde_json::Value::String("average".to_string()),
220        )
221        .with_payload("data", serde_json::Value::Array(data_array))
222        .with_dependencies(vec![sum_task_id.clone()]);
223
224    let avg_task_id = taskflow.submit_task(avg_task).await?;
225    println!("Submitted average task (depends on sum): {}", avg_task_id);
226
227    let taskflow_clone = std::sync::Arc::new(taskflow);
228    let taskflow_for_execution = taskflow_clone.clone();
229
230    let execution_handle = tokio::spawn(async move {
231        if let Err(e) = taskflow_for_execution.start().await {
232            eprintln!("TaskFlow execution failed: {}", e);
233        }
234    });
235
236    tokio::time::sleep(std::time::Duration::from_secs(1)).await;
237
238    loop {
239        let metrics = taskflow_clone.get_task_metrics().await?;
240        println!(
241            "Task metrics: pending={}, running={}, completed={}, failed={}",
242            metrics.pending, metrics.running, metrics.completed, metrics.failed
243        );
244
245        if metrics.pending == 0 && metrics.running == 0 {
246            break;
247        }
248
249        tokio::time::sleep(std::time::Duration::from_secs(1)).await;
250    }
251
252    println!("\nAll tasks completed! Results:");
253
254    let tasks = taskflow_clone.list_tasks(None).await?;
255    for task in tasks {
256        println!(
257            "\nTask: {} ({})",
258            task.definition.name, task.definition.task_type
259        );
260        println!("  Status: {:?}", task.status);
261        if let Some(result) = &task.result {
262            if result.success {
263                println!(
264                    "  Result: {}",
265                    result.output.as_ref().unwrap_or(&"No output".to_string())
266                );
267                println!("  Execution time: {}ms", result.execution_time_ms);
268                if !result.metadata.is_empty() {
269                    println!("  Metadata: {:?}", result.metadata);
270                }
271            } else {
272                println!(
273                    "  Error: {}",
274                    result
275                        .error
276                        .as_ref()
277                        .unwrap_or(&"Unknown error".to_string())
278                );
279            }
280        }
281    }
282
283    let final_metrics = taskflow_clone.get_task_metrics().await?;
284    println!("\nFinal metrics:");
285    println!("  Total tasks: {}", final_metrics.total);
286    println!(
287        "  Success rate: {:.1}%",
288        final_metrics.success_rate() * 100.0
289    );
290
291    taskflow_clone.shutdown().await?;
292    execution_handle.abort();
293
294    Ok(())
295}