1use async_trait::async_trait;
2use chrono::{DateTime, Utc};
3use dashmap::DashMap;
4use serde::{Deserialize, Serialize};
5use std::fmt::Debug;
6use std::sync::Arc;
7use uuid::Uuid;
8
9pub type TaskId = Uuid;
10
11#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct TaskMetadata {
13 pub id: TaskId,
14 pub name: String,
15 pub created_at: DateTime<Utc>,
16 pub attempts: u32,
17 pub max_retries: u32,
18 pub timeout_seconds: u64,
19}
20
21impl Default for TaskMetadata {
22 fn default() -> Self {
23 Self {
24 id: Uuid::new_v4(),
25 name: "unknown".to_string(),
26 created_at: Utc::now(),
27 attempts: 0,
28 max_retries: 3,
29 timeout_seconds: 300,
30 }
31 }
32}
33
34#[derive(Debug, Clone, Serialize, Deserialize)]
35pub struct TaskWrapper {
36 pub metadata: TaskMetadata,
37 pub payload: Vec<u8>,
38}
39
40#[async_trait]
42pub trait Task: Send + Sync + Serialize + for<'de> Deserialize<'de> + Debug {
43 async fn execute(&self) -> TaskResult;
45
46 fn name(&self) -> &str;
48
49 fn max_retries(&self) -> u32 {
51 3
52 }
53
54 fn timeout_seconds(&self) -> u64 {
56 300
57 }
58
59 fn priority(&self) -> TaskPriority {
61 TaskPriority::Normal
62 }
63
64 fn resource_requirements(&self) -> TaskResourceRequirements {
66 TaskResourceRequirements::default()
67 }
68
69 fn retry_delay_strategy(&self) -> RetryStrategy {
71 RetryStrategy::ExponentialBackoff {
72 base_delay_ms: 1000,
73 max_delay_ms: 60000,
74 multiplier: 2.0,
75 }
76 }
77}
78
79#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
81pub enum TaskPriority {
82 Low = 0,
83 Normal = 1,
84 High = 2,
85 Critical = 3,
86}
87
88#[derive(Debug, Clone, Serialize, Deserialize)]
90pub struct TaskResourceRequirements {
91 pub memory_bytes: Option<u64>,
93 pub cpu_intensity: Option<f32>,
95 pub io_ops_per_second: Option<u32>,
97 pub network_bandwidth_bytes: Option<u64>,
99}
100
101impl Default for TaskResourceRequirements {
102 fn default() -> Self {
103 Self {
104 memory_bytes: None,
105 cpu_intensity: Some(0.1), io_ops_per_second: None,
107 network_bandwidth_bytes: None,
108 }
109 }
110}
111
112#[derive(Debug, Clone, Serialize, Deserialize)]
114pub enum RetryStrategy {
115 FixedDelay { delay_ms: u64 },
117 ExponentialBackoff {
119 base_delay_ms: u64,
120 max_delay_ms: u64,
121 multiplier: f64,
122 },
123 CustomIntervals { intervals_ms: Vec<u64> },
125 NoRetry,
127}
128
/// Outcome of running a task: output bytes on success, or a boxed error that
/// can be sent across threads.
pub type TaskResult = Result<Vec<u8>, Box<dyn std::error::Error + Send + Sync>>;

/// A pinned, heap-allocated, `Send` future that resolves to a [`TaskResult`].
pub type TaskFuture = std::pin::Pin<Box<dyn std::future::Future<Output = TaskResult> + Send>>;

/// A shareable, type-erased function that takes payload bytes and returns the
/// future that processes them.
pub type TaskExecutor = Arc<dyn Fn(Vec<u8>) -> TaskFuture + Send + Sync>;
137
/// Registration record for one task type, gathered through the `inventory`
/// crate when the `auto-register` feature is enabled.
#[cfg(feature = "auto-register")]
pub struct TaskRegistration {
    /// Name of the task type; used in the log and error messages produced
    /// during auto-registration.
    pub type_name: &'static str,
    /// Callback that registers the task type with the given registry.
    pub register_fn: fn(&TaskRegistry) -> Result<(), Box<dyn std::error::Error + Send + Sync>>,
}
144
// Declares `TaskRegistration` as an `inventory` collection so that
// `inventory::iter::<TaskRegistration>` can enumerate the submitted entries.
#[cfg(feature = "auto-register")]
inventory::collect!(TaskRegistration);
147
148pub struct TaskRegistry {
150 executors: DashMap<String, TaskExecutor>,
151}
152
153impl TaskRegistry {
154 pub fn new() -> Self {
155 Self {
156 executors: DashMap::new(),
157 }
158 }
159
160 #[cfg(feature = "auto-register")]
162 pub fn with_auto_registered() -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
163 let registry = Self::new();
164 registry.auto_register_tasks()?;
165 Ok(registry)
166 }
167
168 #[cfg(feature = "auto-register")]
170 pub fn with_auto_registered_and_config(
171 config: Option<&crate::config::AutoRegisterConfig>,
172 ) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
173 let registry = Self::new();
174 registry.auto_register_tasks_with_config(config)?;
175 Ok(registry)
176 }
177
178 #[cfg(feature = "auto-register")]
180 pub fn auto_register_tasks(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
181 self.auto_register_tasks_with_config(None)
182 }
183
184 #[cfg(feature = "auto-register")]
186 pub fn auto_register_tasks_with_config(
187 &self,
188 _config: Option<&crate::config::AutoRegisterConfig>,
189 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
190 #[cfg(feature = "tracing")]
191 tracing::info!("Auto-registering tasks...");
192
193 let mut registered_count = 0;
194 let mut errors = Vec::new();
195
196 for registration in inventory::iter::<TaskRegistration> {
198 #[cfg(feature = "tracing")]
199 tracing::debug!("Auto-registering task type: {}", registration.type_name);
200
201 match (registration.register_fn)(self) {
202 Ok(()) => {
203 registered_count += 1;
204 #[cfg(feature = "tracing")]
205 tracing::debug!(
206 "Successfully registered task type: {}",
207 registration.type_name
208 );
209 }
210 Err(e) => {
211 #[cfg(feature = "tracing")]
212 tracing::error!(
213 "Failed to register task type {}: {}",
214 registration.type_name,
215 e
216 );
217 errors.push(format!(
218 "Failed to register {}: {}",
219 registration.type_name, e
220 ));
221 }
222 }
223 }
224
225 if !errors.is_empty() {
226 return Err(format!("Task registration errors: {}", errors.join(", ")).into());
227 }
228
229 #[cfg(feature = "tracing")]
230 tracing::info!("Auto-registered {} task types", registered_count);
231
232 Ok(())
233 }
234
235 pub fn register_with_name<T>(
237 &self,
238 task_name: &str,
239 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>>
240 where
241 T: Task + 'static,
242 {
243 let executor: TaskExecutor = Arc::new(move |payload| {
244 Box::pin(async move {
245 match rmp_serde::from_slice::<T>(&payload) {
246 Ok(task) => task.execute().await,
247 Err(e) => Err(format!("Failed to deserialize task: {}", e).into()),
248 }
249 })
250 });
251
252 self.executors.insert(task_name.to_string(), executor);
253
254 Ok(())
255 }
256
257 pub async fn execute(&self, task_name: &str, payload: Vec<u8>) -> TaskResult {
259 let executor = self.executors.get(task_name).map(|e| e.clone());
260
261 if let Some(executor) = executor {
262 executor(payload).await
263 } else {
264 Err(format!("Unknown task type: {}", task_name).into())
265 }
266 }
267
268 pub fn registered_tasks(&self) -> Vec<String> {
270 self.executors
271 .iter()
272 .map(|entry| entry.key().clone())
273 .collect()
274 }
275}
276
277impl Default for TaskRegistry {
278 fn default() -> Self {
279 Self::new()
280 }
281}
282
/// Registers `$task_type` with `$registry` under the name reported by a
/// default-constructed instance of the type.
///
/// Requirements at the call site:
/// * `$task_type` implements `Default` and whatever bound the registry's
///   `register_with_name` places on its type parameter;
/// * the trait providing `name()` is in scope, because method lookup inside a
///   `macro_rules!` expansion happens where the macro is invoked.
///
/// Evaluates to the `Result` returned by `register_with_name`.
#[macro_export]
macro_rules! manual_register_task {
    ($registry:expr, $task_type:ty) => {{
        // A throwaway instance is built only to ask the type for its name.
        let temp_instance = <$task_type as Default>::default();
        let task_name = temp_instance.name().to_string();
        $registry.register_with_name::<$task_type>(&task_name)
    }};
}
294
/// Registers several task types with `$registry` in one call (a trailing
/// comma is accepted).
///
/// Each type is registered through `manual_register_task!`, so the same
/// call-site requirements apply. Every registration is attempted, in the
/// order written, before any outcome is inspected.
///
/// The macro is an ordinary expression of type
/// `Result<(), Box<dyn std::error::Error + Send + Sync>>`: `Ok(())` when all
/// registrations succeeded, otherwise the first error in declaration order.
/// It never executes `return` on the caller's behalf, so it can be used in
/// any function — propagate the value with `?` or handle it explicitly.
///
/// `$registry` is evaluated exactly once.
#[macro_export]
macro_rules! register_tasks {
    ($registry:expr, $($task_type:ty),+ $(,)?) => {{
        // Bind the registry expression once so side effects in it do not
        // repeat for every task type.
        let registry = &$registry;

        let mut results: Vec<Result<(), Box<dyn std::error::Error + Send + Sync>>> = Vec::new();
        $(
            results.push($crate::manual_register_task!(registry, $task_type));
        )+

        // Collapse to the first error, if any; all registrations above have
        // already run by this point.
        results.into_iter().collect::<Result<(), _>>()
    }};
}
315
/// Registers `$task_type` with `$registry` under the explicitly supplied
/// `$name`, without constructing an instance of the type.
///
/// Expands to `$registry.register_with_name::<$task_type>($name)` and
/// evaluates to that call's `Result`.
#[macro_export]
macro_rules! register_task_with_name {
    ($registry:expr, $task_type:ty, $name:expr) => {
        $registry.register_with_name::<$task_type>($name)
    };
}
323
#[cfg(test)]
mod tests {
    use super::*;
    use serde::{Deserialize, Serialize};

    /// Minimal task used by the tests: echoes `data` back inside a response,
    /// or fails on demand when `should_fail` is set.
    #[derive(Debug, Serialize, Deserialize, Clone, Default)]
    struct TestTask {
        pub data: String,
        pub should_fail: bool,
    }

    #[async_trait]
    impl Task for TestTask {
        async fn execute(&self) -> TaskResult {
            if self.should_fail {
                return Err("Task intentionally failed".into());
            }

            // Short pause so the task actually yields to the runtime.
            tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;

            #[derive(Serialize)]
            struct Response {
                status: String,
                processed_data: String,
            }

            let response = Response {
                status: "completed".to_string(),
                processed_data: format!("Processed: {}", self.data),
            };

            Ok(rmp_serde::to_vec(&response)?)
        }

        fn name(&self) -> &str {
            "test_task"
        }

        fn max_retries(&self) -> u32 {
            2
        }

        fn timeout_seconds(&self) -> u64 {
            30
        }
    }

    /// A new registry starts out empty.
    #[tokio::test]
    async fn test_task_registry_creation() {
        let registry = TaskRegistry::new();
        assert_eq!(registry.registered_tasks().len(), 0);
    }

    /// Registering a type makes its name visible through `registered_tasks`.
    #[tokio::test]
    async fn test_task_registration() {
        let registry = TaskRegistry::new();

        registry
            .register_with_name::<TestTask>("test_task")
            .expect("Failed to register task");

        let tasks = registry.registered_tasks();
        assert_eq!(tasks.len(), 1);
        assert!(tasks.contains(&"test_task".to_string()));
    }

    /// A registered task decodes its payload, runs, and returns its response bytes.
    #[tokio::test]
    async fn test_task_execution() {
        let registry = TaskRegistry::new();
        registry
            .register_with_name::<TestTask>("test_task")
            .expect("Failed to register task");

        let task = TestTask {
            data: "Hello, World!".to_string(),
            should_fail: false,
        };

        let payload = rmp_serde::to_vec(&task).expect("Failed to serialize task");
        let result = registry.execute("test_task", payload).await;

        assert!(result.is_ok());
        let response_data = result.unwrap();
        assert!(!response_data.is_empty());

        #[derive(serde::Deserialize)]
        struct Response {
            status: String,
            processed_data: String,
        }

        let response: Response =
            rmp_serde::from_slice(&response_data).expect("Failed to deserialize response");
        assert_eq!(response.status, "completed");
        assert!(response.processed_data.contains("Hello, World!"));
    }

    /// An error returned by the task surfaces unchanged from `execute`.
    #[tokio::test]
    async fn test_task_execution_failure() {
        let registry = TaskRegistry::new();
        registry
            .register_with_name::<TestTask>("test_task")
            .expect("Failed to register task");

        let task = TestTask {
            data: "This will fail".to_string(),
            should_fail: true,
        };

        let payload = rmp_serde::to_vec(&task).expect("Failed to serialize task");
        let result = registry.execute("test_task", payload).await;

        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("intentionally failed"));
    }

    /// Executing a name that was never registered yields an "Unknown task type" error.
    #[tokio::test]
    async fn test_unknown_task_execution() {
        let registry = TaskRegistry::new();

        let result = registry.execute("unknown_task", vec![1, 2, 3]).await;

        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("Unknown task type"));
    }

    /// `TaskMetadata::default()` uses the documented placeholder values.
    #[tokio::test]
    async fn test_task_metadata_default() {
        let metadata = TaskMetadata::default();

        assert_eq!(metadata.name, "unknown");
        assert_eq!(metadata.attempts, 0);
        assert_eq!(metadata.max_retries, 3);
        assert_eq!(metadata.timeout_seconds, 300);
    }

    /// A `TaskWrapper` survives a MessagePack round trip.
    #[tokio::test]
    async fn test_task_wrapper_serialization() {
        let metadata = TaskMetadata {
            id: TaskId::new_v4(),
            name: "test_task".to_string(),
            created_at: chrono::Utc::now(),
            attempts: 1,
            max_retries: 3,
            timeout_seconds: 300,
        };

        let wrapper = TaskWrapper {
            metadata: metadata.clone(),
            payload: vec![1, 2, 3, 4],
        };

        let serialized = rmp_serde::to_vec(&wrapper).expect("Failed to serialize wrapper");
        assert!(!serialized.is_empty());

        let deserialized: TaskWrapper =
            rmp_serde::from_slice(&serialized).expect("Failed to deserialize wrapper");

        assert_eq!(deserialized.metadata.id, metadata.id);
        assert_eq!(deserialized.metadata.name, metadata.name);
        assert_eq!(deserialized.payload, vec![1, 2, 3, 4]);
    }

    /// The same type can be registered under several distinct names.
    #[tokio::test]
    async fn test_multiple_task_registration() {
        let registry = TaskRegistry::new();

        registry
            .register_with_name::<TestTask>("task1")
            .expect("Failed to register task1");
        registry
            .register_with_name::<TestTask>("task2")
            .expect("Failed to register task2");

        let tasks = registry.registered_tasks();
        assert_eq!(tasks.len(), 2);
        assert!(tasks.contains(&"task1".to_string()));
        assert!(tasks.contains(&"task2".to_string()));
    }

    /// Ten spawned tasks can execute through one shared registry concurrently.
    #[tokio::test]
    async fn test_task_registry_concurrent_access() {
        let registry = Arc::new(TaskRegistry::new());
        registry
            .register_with_name::<TestTask>("test_task")
            .expect("Failed to register task");

        let task = TestTask {
            data: "Concurrent test".to_string(),
            should_fail: false,
        };
        let payload = rmp_serde::to_vec(&task).expect("Failed to serialize task");

        let mut handles = Vec::new();
        for i in 0..10 {
            let registry_clone = Arc::clone(&registry);
            let payload_clone = payload.clone();
            let handle = tokio::spawn(async move {
                let result = registry_clone.execute("test_task", payload_clone).await;
                assert!(result.is_ok(), "Task {} failed", i);
            });
            handles.push(handle);
        }

        // Awaiting each handle propagates any assertion panic from a spawned task.
        for handle in handles {
            handle.await.expect("Task execution failed");
        }
    }
}