1use std::any::Any;
7use std::collections::HashMap;
8use std::future::Future;
9use std::pin::Pin;
10use std::time::Duration;
11
12use crossbeam_channel::{Receiver, Sender};
13use tokio::runtime::Handle;
14
15use crate::types::TaskId;
16
/// Lifecycle state of a task handled by the async task system.
///
/// The `AsyncTaskResult` constructors assign `Completed` (`success`),
/// `Failed` (`failure`) and `Timeout` (`timeout`).
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum TaskStatus {
    /// Not assigned by any constructor in this module.
    Pending,
    /// Not assigned by any constructor in this module.
    Running,
    /// The task finished and produced a payload.
    Completed,
    /// The task reported an error (or panicked).
    Failed,
    /// The task exceeded its timeout.
    Timeout,
}
26
27pub trait AsyncTask: Send + Sync + 'static {
29 fn task_type(&self) -> &str;
31
32 fn execute(&self) -> Pin<Box<dyn Future<Output = AsyncTaskResult> + Send>>;
34
35 fn timeout(&self) -> Duration {
37 Duration::from_secs(30)
38 }
39}
40
41pub struct AsyncTaskResult {
43 pub task_id: TaskId,
45 pub task_type: String,
47 pub payload: Option<Box<dyn Any + Send>>,
49 pub metadata: TaskMetadata,
51}
52
53impl AsyncTaskResult {
54 pub fn success<T: Any + Send + 'static>(
55 task_id: TaskId,
56 task_type: impl Into<String>,
57 payload: T,
58 duration: Duration,
59 ) -> Self {
60 Self {
61 task_id,
62 task_type: task_type.into(),
63 payload: Some(Box::new(payload)),
64 metadata: TaskMetadata {
65 duration,
66 status: TaskStatus::Completed,
67 error: None,
68 },
69 }
70 }
71
72 pub fn failure(
73 task_id: TaskId,
74 task_type: impl Into<String>,
75 error: String,
76 duration: Duration,
77 ) -> Self {
78 Self {
79 task_id,
80 task_type: task_type.into(),
81 payload: None,
82 metadata: TaskMetadata {
83 duration,
84 status: TaskStatus::Failed,
85 error: Some(AsyncTaskError { message: error }),
86 },
87 }
88 }
89
90 pub fn timeout(task_id: TaskId, task_type: impl Into<String>, duration: Duration) -> Self {
91 Self {
92 task_id,
93 task_type: task_type.into(),
94 payload: None,
95 metadata: TaskMetadata {
96 duration,
97 status: TaskStatus::Timeout,
98 error: Some(AsyncTaskError {
99 message: "Task timed out".to_string(),
100 }),
101 },
102 }
103 }
104}
105
106pub struct TaskMetadata {
108 pub duration: Duration,
110 pub status: TaskStatus,
112 pub error: Option<AsyncTaskError>,
114}
115
116#[derive(Debug, Clone, thiserror::Error)]
118#[error("Async task error: {message}")]
119pub struct AsyncTaskError {
120 pub message: String,
121}
122
123impl From<crate::error::SwarmError> for AsyncTaskError {
124 fn from(err: crate::error::SwarmError) -> Self {
125 Self {
126 message: err.message().to_string(),
127 }
128 }
129}
130
131impl From<AsyncTaskError> for crate::error::SwarmError {
132 fn from(err: AsyncTaskError) -> Self {
133 crate::error::SwarmError::AsyncTask {
134 message: err.message,
135 }
136 }
137}
138
139pub struct AsyncTaskSystem {
141 result_tx: Sender<AsyncTaskResult>,
143 result_rx: Receiver<AsyncTaskResult>,
145 runtime: Handle,
147 factories: HashMap<String, Box<dyn AsyncTaskFactory>>,
149}
150
151impl AsyncTaskSystem {
152 pub fn new(runtime: Handle) -> Self {
154 let (tx, rx) = crossbeam_channel::unbounded();
155 Self {
156 result_tx: tx,
157 result_rx: rx,
158 runtime,
159 factories: HashMap::new(),
160 }
161 }
162
163 pub fn spawn<T: AsyncTask>(&self, task: T) -> TaskId {
165 self.spawn_boxed(Box::new(task))
166 }
167
168 pub fn spawn_boxed(&self, task: Box<dyn AsyncTask>) -> TaskId {
170 let task_id = TaskId::new();
171 let tx = self.result_tx.clone();
172 let timeout_duration = task.timeout();
173 let task_type = task.task_type().to_string();
174
175 self.runtime.spawn(async move {
176 let start = std::time::Instant::now();
177 let result = tokio::time::timeout(timeout_duration, task.execute()).await;
178 let duration = start.elapsed();
179
180 let task_result = match result {
181 Ok(mut r) => {
182 r.task_id = task_id;
183 r.task_type = task_type;
184 r.metadata.duration = duration;
185 r
186 }
187 Err(_) => AsyncTaskResult::timeout(task_id, task_type, duration),
188 };
189 let _ = tx.send(task_result);
190 });
191
192 task_id
193 }
194
195 pub fn collect_results(&self) -> Vec<AsyncTaskResult> {
197 let mut results = Vec::new();
198 while let Ok(result) = self.result_rx.try_recv() {
199 results.push(result);
200 }
201 results
202 }
203
204 pub fn register_factory<F: AsyncTaskFactory + 'static>(&mut self, name: &str, factory: F) {
206 self.factories.insert(name.to_string(), Box::new(factory));
207 }
208
209 pub fn create_task(&self, name: &str, params: TaskParams) -> Option<Box<dyn AsyncTask>> {
211 self.factories.get(name).map(|f| f.create(params))
212 }
213}
214
215pub trait AsyncTaskFactory: Send + Sync {
217 fn create(&self, params: TaskParams) -> Box<dyn AsyncTask>;
218}
219
/// Free-form string key/value parameters passed to an `AsyncTaskFactory`.
#[derive(Default, Clone, Debug)]
pub struct TaskParams {
    /// Parameter values keyed by name.
    pub data: HashMap<String, String>,
}
225
/// Tuning values for the async task system.
///
/// NOTE(review): nothing in the visible code reads these values.
/// `AsyncTaskSystem` does not enforce `max_concurrent`, and
/// `AsyncTask::timeout` has its own 30-second default. Confirm whether callers
/// elsewhere consume this config.
#[derive(Clone, Debug)]
pub struct AsyncConfig {
    /// Intended cap on concurrently running tasks.
    pub max_concurrent: usize,
    /// Intended default per-task timeout, in seconds.
    pub default_timeout_secs: u64,
}

impl Default for AsyncConfig {
    /// Returns 100 concurrent tasks and a 30-second timeout.
    fn default() -> Self {
        let max_concurrent = 100;
        let default_timeout_secs = 30;
        AsyncConfig {
            max_concurrent,
            default_timeout_secs,
        }
    }
}
243
/// Example task that sleeps for `delay`, then succeeds with `result` as its payload.
pub struct DelayTask {
    /// How long to sleep before completing.
    pub delay: Duration,
    /// String returned as the success payload.
    pub result: String,
}
253
254impl AsyncTask for DelayTask {
255 fn task_type(&self) -> &str {
256 "delay"
257 }
258
259 fn execute(&self) -> Pin<Box<dyn Future<Output = AsyncTaskResult> + Send>> {
260 let delay = self.delay;
261 let result = self.result.clone();
262
263 Box::pin(async move {
264 tokio::time::sleep(delay).await;
265 AsyncTaskResult::success(TaskId::new(), "delay", result, delay)
266 })
267 }
268
269 fn timeout(&self) -> Duration {
270 self.delay + Duration::from_secs(1)
271 }
272}