// rust_task_queue/task.rs

1use async_trait::async_trait;
2use chrono::{DateTime, Utc};
3use dashmap::DashMap;
4use serde::{Deserialize, Serialize};
5use std::fmt::Debug;
6use std::sync::Arc;
7use uuid::Uuid;
8
9pub type TaskId = Uuid;
10
11#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct TaskMetadata {
13    pub id: TaskId,
14    pub name: String,
15    pub created_at: DateTime<Utc>,
16    pub attempts: u32,
17    pub max_retries: u32,
18    pub timeout_seconds: u64,
19}
20
21impl Default for TaskMetadata {
22    fn default() -> Self {
23        Self {
24            id: Uuid::new_v4(),
25            name: "unknown".to_string(),
26            created_at: Utc::now(),
27            attempts: 0,
28            max_retries: 3,
29            timeout_seconds: 300,
30        }
31    }
32}
33
34#[derive(Debug, Clone, Serialize, Deserialize)]
35pub struct TaskWrapper {
36    pub metadata: TaskMetadata,
37    pub payload: Vec<u8>,
38}
39
40/// Enhanced Task trait with improved type safety and async characteristics
41#[async_trait]
42pub trait Task: Send + Sync + Serialize + for<'de> Deserialize<'de> + Debug {
43    /// Execute the task with comprehensive error handling
44    async fn execute(&self) -> TaskResult;
45
46    /// Task identifier for registration and execution
47    fn name(&self) -> &str;
48
49    /// Maximum retry attempts (default: 3)
50    fn max_retries(&self) -> u32 {
51        3
52    }
53
54    /// Task timeout in seconds (default: 300s/5min)
55    fn timeout_seconds(&self) -> u64 {
56        300
57    }
58
59    /// Task priority level (default: Normal)
60    fn priority(&self) -> TaskPriority {
61        TaskPriority::Normal
62    }
63
64    /// Estimate task resource requirements for better scheduling
65    fn resource_requirements(&self) -> TaskResourceRequirements {
66        TaskResourceRequirements::default()
67    }
68
69    /// Custom retry delay strategy
70    fn retry_delay_strategy(&self) -> RetryStrategy {
71        RetryStrategy::ExponentialBackoff {
72            base_delay_ms: 1000,
73            max_delay_ms: 60000,
74            multiplier: 2.0,
75        }
76    }
77}
78
79/// Task priority levels for queue ordering
80#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
81pub enum TaskPriority {
82    Low = 0,
83    Normal = 1,
84    High = 2,
85    Critical = 3,
86}
87
88/// Resource requirements for task execution planning
89#[derive(Debug, Clone, Serialize, Deserialize)]
90pub struct TaskResourceRequirements {
91    /// Expected memory usage in bytes
92    pub memory_bytes: Option<u64>,
93    /// Expected CPU intensity (0.0 to 1.0)
94    pub cpu_intensity: Option<f32>,
95    /// Expected I/O operations per second
96    pub io_ops_per_second: Option<u32>,
97    /// Network bandwidth requirements in bytes/sec
98    pub network_bandwidth_bytes: Option<u64>,
99}
100
101impl Default for TaskResourceRequirements {
102    fn default() -> Self {
103        Self {
104            memory_bytes: None,
105            cpu_intensity: Some(0.1), // Low CPU by default
106            io_ops_per_second: None,
107            network_bandwidth_bytes: None,
108        }
109    }
110}
111
112/// Retry strategy configuration
113#[derive(Debug, Clone, Serialize, Deserialize)]
114pub enum RetryStrategy {
115    /// Fixed delay between retries
116    FixedDelay { delay_ms: u64 },
117    /// Exponential backoff with jitter
118    ExponentialBackoff {
119        base_delay_ms: u64,
120        max_delay_ms: u64,
121        multiplier: f64,
122    },
123    /// Custom retry intervals
124    CustomIntervals { intervals_ms: Vec<u64> },
125    /// No retries
126    NoRetry,
127}
128
/// Outcome of running a task: serialized output bytes on success, or a boxed
/// error that can cross threads on failure.
pub type TaskResult = Result<Vec<u8>, Box<dyn std::error::Error + Send + Sync>>;

/// Pinned, heap-allocated, `Send` future that resolves to a [`TaskResult`].
pub type TaskFuture =
    std::pin::Pin<Box<dyn std::future::Future<Output = TaskResult> + Send>>;

/// Type-erased executor: given a serialized task payload, returns the future that
/// runs it. Wrapped in `Arc` so it can be cloned cheaply and shared between threads.
pub type TaskExecutor = Arc<dyn Fn(Vec<u8>) -> TaskFuture + Send + Sync>;
137
/// Entry collected through `inventory` so a task type can be registered
/// automatically by `TaskRegistry::auto_register_tasks`.
#[cfg(feature = "auto-register")]
pub struct TaskRegistration {
    /// Type name used in auto-registration log and error messages.
    pub type_name: &'static str,
    /// Callback that registers the task type on the supplied registry.
    pub register_fn: fn(&TaskRegistry) -> Result<(), Box<dyn std::error::Error + Send + Sync>>,
}

// Make `TaskRegistration` an inventory-collected type so that
// `inventory::iter::<TaskRegistration>` can enumerate every submitted entry.
#[cfg(feature = "auto-register")]
inventory::collect!(TaskRegistration);
147
148/// Task registry for mapping task names to executors
149pub struct TaskRegistry {
150    executors: DashMap<String, TaskExecutor>,
151}
152
153impl TaskRegistry {
154    pub fn new() -> Self {
155        Self {
156            executors: DashMap::new(),
157        }
158    }
159
160    /// Create a new registry with all automatically registered tasks
161    #[cfg(feature = "auto-register")]
162    pub fn with_auto_registered() -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
163        let registry = Self::new();
164        registry.auto_register_tasks()?;
165        Ok(registry)
166    }
167
168    /// Create a new registry with all automatically registered tasks and configuration
169    #[cfg(feature = "auto-register")]
170    pub fn with_auto_registered_and_config(
171        config: Option<&crate::config::AutoRegisterConfig>,
172    ) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
173        let registry = Self::new();
174        registry.auto_register_tasks_with_config(config)?;
175        Ok(registry)
176    }
177
178    /// Register all tasks that have been submitted via the inventory pattern
179    #[cfg(feature = "auto-register")]
180    pub fn auto_register_tasks(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
181        self.auto_register_tasks_with_config(None)
182    }
183
184    /// Register all tasks with optional configuration
185    #[cfg(feature = "auto-register")]
186    pub fn auto_register_tasks_with_config(
187        &self,
188        _config: Option<&crate::config::AutoRegisterConfig>,
189    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
190        #[cfg(feature = "tracing")]
191        tracing::info!("Auto-registering tasks...");
192
193        let mut registered_count = 0;
194        let mut errors = Vec::new();
195
196        // Register tasks from the inventory pattern (compile-time registered)
197        for registration in inventory::iter::<TaskRegistration> {
198            #[cfg(feature = "tracing")]
199            tracing::debug!("Auto-registering task type: {}", registration.type_name);
200
201            match (registration.register_fn)(self) {
202                Ok(()) => {
203                    registered_count += 1;
204                    #[cfg(feature = "tracing")]
205                    tracing::debug!(
206                        "Successfully registered task type: {}",
207                        registration.type_name
208                    );
209                }
210                Err(e) => {
211                    #[cfg(feature = "tracing")]
212                    tracing::error!(
213                        "Failed to register task type {}: {}",
214                        registration.type_name,
215                        e
216                    );
217                    errors.push(format!(
218                        "Failed to register {}: {}",
219                        registration.type_name, e
220                    ));
221                }
222            }
223        }
224
225        if !errors.is_empty() {
226            return Err(format!("Task registration errors: {}", errors.join(", ")).into());
227        }
228
229        #[cfg(feature = "tracing")]
230        tracing::info!("Auto-registered {} task types", registered_count);
231
232        Ok(())
233    }
234
235    /// Register a task type with explicit name
236    pub fn register_with_name<T>(
237        &self,
238        task_name: &str,
239    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>>
240    where
241        T: Task + 'static,
242    {
243        let executor: TaskExecutor = Arc::new(move |payload| {
244            Box::pin(async move {
245                match rmp_serde::from_slice::<T>(&payload) {
246                    Ok(task) => task.execute().await,
247                    Err(e) => Err(format!("Failed to deserialize task: {}", e).into()),
248                }
249            })
250        });
251
252        self.executors.insert(task_name.to_string(), executor);
253
254        Ok(())
255    }
256
257    /// Execute a task by name
258    pub async fn execute(&self, task_name: &str, payload: Vec<u8>) -> TaskResult {
259        let executor = self.executors.get(task_name).map(|e| e.clone());
260
261        if let Some(executor) = executor {
262            executor(payload).await
263        } else {
264            Err(format!("Unknown task type: {}", task_name).into())
265        }
266    }
267
268    /// Get list of registered task names
269    pub fn registered_tasks(&self) -> Vec<String> {
270        self.executors
271            .iter()
272            .map(|entry| entry.key().clone())
273            .collect()
274    }
275}
276
277impl Default for TaskRegistry {
278    fn default() -> Self {
279        Self::new()
280    }
281}
282
/// Registers `$task_type` on `$registry` under the name its `Task::name` reports.
///
/// The name comes from a throwaway instance built with `Default::default()`, so
/// `$task_type` must implement `Default`. The `Task` trait must also be in scope at the
/// call site, because `.name()` is invoked as a method. Expands to whatever
/// `$registry.register_with_name::<$task_type>(..)` returns.
#[macro_export]
macro_rules! manual_register_task {
    ($registry:expr, $task_type:ty) => {{
        let probe: $task_type = Default::default();
        let resolved_name = String::from(probe.name());
        $registry.register_with_name::<$task_type>(&resolved_name)
    }};
}
294
/// Registers several task types on one registry through `manual_register_task!`.
///
/// Each `$task_type` must satisfy `manual_register_task!`'s requirements (`Default`,
/// with `Task` in scope). `$registry` is evaluated exactly once, and every type is
/// registered even if an earlier one fails.
///
/// The macro is an *expression* of type
/// `Result<(), Box<dyn std::error::Error + Send + Sync>>`. It evaluates to `Ok(())` when
/// all registrations succeed, and to the first error otherwise. It never `return`s from
/// the enclosing function, so callers can use `?`, `match`, or ignore the result as they
/// see fit. (The previous version `return`ed from the caller, and did not compile in
/// functions with a different return type.)
#[macro_export]
macro_rules! register_tasks {
    ($registry:expr, $($task_type:ty),+ $(,)?) => {{
        let registry_ref = &$registry;
        // Fully qualified so a caller-side `type Result<T> = ...` alias cannot break it.
        let mut outcome: ::std::result::Result<
            (),
            ::std::boxed::Box<dyn ::std::error::Error + Send + Sync>,
        > = Ok(());
        $(
            // `and` keeps the first error, but its argument is evaluated eagerly,
            // so every task type is still registered.
            outcome = outcome.and($crate::manual_register_task!(registry_ref, $task_type));
        )+
        outcome
    }};
}
315
/// Registers `$task_type` on `$registry` under the explicit name `$name`.
///
/// Unlike `manual_register_task!`, this does not require `Default`. It expands directly
/// to `$registry.register_with_name::<$task_type>($name)`.
#[macro_export]
macro_rules! register_task_with_name {
    ($registry:expr, $task_type:ty, $name:expr) => (
        $registry.register_with_name::<$task_type>($name)
    );
}
323
#[cfg(test)]
mod tests {
    use super::*;
    use serde::{Deserialize, Serialize};

    /// Task fixture: reports `data` back inside a [`Response`], or fails when
    /// `should_fail` is set.
    #[derive(Debug, Serialize, Deserialize, Clone, Default)]
    struct TestTask {
        pub data: String,
        pub should_fail: bool,
    }

    /// MessagePack body produced by a successful `TestTask` run.
    #[derive(Serialize, Deserialize)]
    struct Response {
        status: String,
        processed_data: String,
    }

    #[async_trait]
    impl Task for TestTask {
        async fn execute(&self) -> TaskResult {
            if self.should_fail {
                return Err("Task intentionally failed".into());
            }

            // Stand-in for real work so the future genuinely suspends once.
            tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;

            let body = Response {
                status: "completed".to_string(),
                processed_data: format!("Processed: {}", self.data),
            };
            Ok(rmp_serde::to_vec(&body)?)
        }

        fn name(&self) -> &str {
            "test_task"
        }

        fn max_retries(&self) -> u32 {
            2
        }

        fn timeout_seconds(&self) -> u64 {
            30
        }
    }

    /// Builds a registry with `TestTask` registered as `"test_task"`.
    fn registry_with_test_task() -> TaskRegistry {
        let registry = TaskRegistry::new();
        registry
            .register_with_name::<TestTask>("test_task")
            .expect("Failed to register task");
        registry
    }

    /// MessagePack-encodes a `TestTask` with the given fields.
    fn encode_task(data: &str, should_fail: bool) -> Vec<u8> {
        let task = TestTask {
            data: data.to_string(),
            should_fail,
        };
        rmp_serde::to_vec(&task).expect("Failed to serialize task")
    }

    #[tokio::test]
    async fn test_task_registry_creation() {
        assert!(TaskRegistry::new().registered_tasks().is_empty());
    }

    #[tokio::test]
    async fn test_task_registration() {
        let names = registry_with_test_task().registered_tasks();
        assert_eq!(names, vec!["test_task".to_string()]);
    }

    #[tokio::test]
    async fn test_task_execution() {
        let registry = registry_with_test_task();
        let outcome = registry
            .execute("test_task", encode_task("Hello, World!", false))
            .await;

        let bytes = match outcome {
            Ok(bytes) => bytes,
            Err(e) => panic!("expected success, got error: {}", e),
        };
        assert!(!bytes.is_empty());

        // The output is MessagePack, so decode it rather than inspecting raw bytes.
        let response: Response =
            rmp_serde::from_slice(&bytes).expect("Failed to deserialize response");
        assert_eq!(response.status, "completed");
        assert!(response.processed_data.contains("Hello, World!"));
    }

    #[tokio::test]
    async fn test_task_execution_failure() {
        let registry = registry_with_test_task();
        let err = registry
            .execute("test_task", encode_task("This will fail", true))
            .await
            .expect_err("task was expected to fail");
        assert!(err.to_string().contains("intentionally failed"));
    }

    #[tokio::test]
    async fn test_unknown_task_execution() {
        let err = TaskRegistry::new()
            .execute("unknown_task", vec![1, 2, 3])
            .await
            .expect_err("unknown task names must be rejected");
        assert!(err.to_string().contains("Unknown task type"));
    }

    #[tokio::test]
    async fn test_task_metadata_default() {
        let TaskMetadata {
            name,
            attempts,
            max_retries,
            timeout_seconds,
            ..
        } = TaskMetadata::default();

        assert_eq!(name, "unknown");
        assert_eq!((attempts, max_retries, timeout_seconds), (0, 3, 300));
    }

    #[tokio::test]
    async fn test_task_wrapper_serialization() {
        let metadata = TaskMetadata {
            id: TaskId::new_v4(),
            name: "test_task".to_string(),
            created_at: Utc::now(),
            attempts: 1,
            max_retries: 3,
            timeout_seconds: 300,
        };
        let original = TaskWrapper {
            metadata: metadata.clone(),
            payload: vec![1, 2, 3, 4],
        };

        let bytes = rmp_serde::to_vec(&original).expect("Failed to serialize wrapper");
        assert!(!bytes.is_empty());

        let restored: TaskWrapper =
            rmp_serde::from_slice(&bytes).expect("Failed to deserialize wrapper");
        assert_eq!(restored.metadata.id, metadata.id);
        assert_eq!(restored.metadata.name, metadata.name);
        assert_eq!(restored.payload, vec![1, 2, 3, 4]);
    }

    #[tokio::test]
    async fn test_multiple_task_registration() {
        let registry = TaskRegistry::new();
        registry
            .register_with_name::<TestTask>("task1")
            .expect("Failed to register task1");
        registry
            .register_with_name::<TestTask>("task2")
            .expect("Failed to register task2");

        // Registry iteration order is unspecified, so sort before comparing.
        let mut names = registry.registered_tasks();
        names.sort();
        assert_eq!(names, ["task1", "task2"]);
    }

    #[tokio::test]
    async fn test_task_registry_concurrent_access() {
        let registry = Arc::new(registry_with_test_task());
        let payload = encode_task("Concurrent test", false);

        // Fire ten executions of the same task at once against the shared registry.
        let handles: Vec<_> = (0..10)
            .map(|i| {
                let registry = Arc::clone(&registry);
                let payload = payload.clone();
                tokio::spawn(async move {
                    let result = registry.execute("test_task", payload).await;
                    assert!(result.is_ok(), "Task {} failed", i);
                })
            })
            .collect();

        for handle in handles {
            handle.await.expect("Task execution failed");
        }
    }
}