// rust_task_queue/scheduler.rs

1use crate::{RedisBroker, Task, TaskId, TaskQueueError};
2use chrono::{DateTime, Duration, Utc};
3use std::collections::HashMap;
4use std::sync::Arc;
5use tokio::task::JoinHandle;
6
7pub struct TaskScheduler {
8    broker: Arc<RedisBroker>,
9    tick_interval_seconds: u64,
10}
11
12impl TaskScheduler {
13    pub fn new(broker: Arc<RedisBroker>) -> Self {
14        Self {
15            broker,
16            tick_interval_seconds: 1, // Default to 1 second for better responsiveness
17        }
18    }
19
20    pub fn with_tick_interval(broker: Arc<RedisBroker>, interval_seconds: u64) -> Self {
21        Self {
22            broker,
23            tick_interval_seconds: interval_seconds,
24        }
25    }
26
27    pub fn start(broker: Arc<RedisBroker>) -> Result<JoinHandle<()>, TaskQueueError> {
28        let scheduler = Self::new(broker);
29        let tick_interval = scheduler.tick_interval_seconds;
30        let broker = scheduler.broker;
31
32        let handle = tokio::spawn(async move {
33            loop {
34                if let Err(e) = Self::process_scheduled_tasks(&broker).await {
35                    #[cfg(feature = "tracing")]
36                    tracing::error!("Scheduler error: {}", e);
37                }
38
39                tokio::time::sleep(tokio::time::Duration::from_secs(tick_interval)).await;
40            }
41        });
42
43        #[cfg(feature = "tracing")]
44        tracing::info!(
45            "Task scheduler started with {} second tick interval",
46            tick_interval
47        );
48
49        Ok(handle)
50    }
51
52    pub async fn schedule_task<T: Task>(
53        &self,
54        task: T,
55        queue: &str,
56        delay: Duration,
57    ) -> Result<TaskId, TaskQueueError> {
58        let schedule_start = std::time::Instant::now();
59        let execute_at = Utc::now() + delay;
60        let task_id = uuid::Uuid::new_v4();
61        let task_name = task.name();
62
63        #[cfg(feature = "tracing")]
64        tracing::info!(
65            task_id = %task_id,
66            task_name = task_name,
67            queue = queue,
68            delay_seconds = delay.num_seconds(),
69            execute_at = %execute_at,
70            "Scheduling task for delayed execution"
71        );
72
73        let scheduled_task = ScheduledTask {
74            id: task_id,
75            queue: queue.to_string(),
76            execute_at,
77            task_name: task.name().to_string(),
78            max_retries: task.max_retries(),
79            timeout_seconds: task.timeout_seconds(),
80            payload: rmp_serde::to_vec(&task)?,
81        };
82
83        // Store in Redis sorted set with timestamp as score
84        let mut conn = self.broker.pool.get().await.map_err(|e| {
85            TaskQueueError::Connection(format!("Failed to get Redis connection: {}", e))
86        })?;
87        let serialized = rmp_serde::to_vec(&scheduled_task)?;
88
89        redis::AsyncCommands::zadd::<_, _, _, ()>(
90            &mut *conn,
91            "scheduled_tasks",
92            serialized.clone(),
93            execute_at.timestamp(),
94        )
95        .await?;
96
97        #[cfg(feature = "tracing")]
98        tracing::info!(
99            task_id = %task_id,
100            task_name = task_name,
101            queue = queue,
102            execute_at = %execute_at,
103            delay_seconds = delay.num_seconds(),
104            duration_ms = schedule_start.elapsed().as_millis(),
105            payload_size_bytes = serialized.len(),
106            redis_score = execute_at.timestamp(),
107            "Task scheduled successfully"
108        );
109
110        Ok(task_id)
111    }
112
113    pub async fn process_scheduled_tasks(broker: &RedisBroker) -> Result<(), TaskQueueError> {
114        let process_start = std::time::Instant::now();
115        let mut conn = broker.pool.get().await.map_err(|e| {
116            TaskQueueError::Connection(format!("Failed to get Redis connection: {}", e))
117        })?;
118        let now = Utc::now().timestamp();
119
120        #[cfg(feature = "tracing")]
121        tracing::debug!(
122            timestamp_threshold = now,
123            "Processing scheduled tasks ready for execution"
124        );
125
126        // Get tasks that should execute now
127        let tasks: Vec<Vec<u8>> = redis::AsyncCommands::zrangebyscore_limit(
128            &mut *conn,
129            "scheduled_tasks",
130            0,
131            now,
132            0,
133            100,
134        )
135        .await?;
136
137        let batch_size = tasks.len();
138
139        if batch_size == 0 {
140            #[cfg(feature = "tracing")]
141            tracing::trace!(
142                duration_ms = process_start.elapsed().as_millis(),
143                "No scheduled tasks ready for execution"
144            );
145            return Ok(());
146        }
147
148        #[cfg(feature = "tracing")]
149        tracing::info!(
150            batch_size = batch_size,
151            timestamp_threshold = now,
152            "Found scheduled tasks ready for execution"
153        );
154
155        let mut processed_count = 0;
156        let mut failed_count = 0;
157        let mut queue_distribution = HashMap::new();
158
159        for serialized_task in tasks {
160            if let Ok(scheduled_task) = rmp_serde::from_slice::<ScheduledTask>(&serialized_task) {
161                // Track queue distribution
162                *queue_distribution
163                    .entry(scheduled_task.queue.clone())
164                    .or_insert(0) += 1;
165
166                #[cfg(feature = "tracing")]
167                tracing::debug!(
168                    task_id = %scheduled_task.id,
169                    task_name = %scheduled_task.task_name,
170                    queue = %scheduled_task.queue,
171                    scheduled_execute_at = %scheduled_task.execute_at,
172                    delay_from_scheduled = (now - scheduled_task.execute_at.timestamp()),
173                    "Moving scheduled task to regular queue"
174                );
175
176                // Move task to regular queue
177                let task_wrapper = crate::TaskWrapper {
178                    metadata: crate::TaskMetadata {
179                        id: scheduled_task.id,
180                        name: scheduled_task.task_name.clone(),
181                        created_at: Utc::now(),
182                        attempts: 0,
183                        max_retries: scheduled_task.max_retries,
184                        timeout_seconds: scheduled_task.timeout_seconds,
185                    },
186                    payload: scheduled_task.payload,
187                };
188
189                let serialized_wrapper = rmp_serde::to_vec(&task_wrapper)?;
190
191                match redis::AsyncCommands::lpush::<_, _, ()>(
192                    &mut *conn,
193                    &scheduled_task.queue,
194                    &serialized_wrapper,
195                )
196                .await
197                {
198                    Ok(_) => {
199                        // Remove from scheduled tasks
200                        match redis::AsyncCommands::zrem::<_, _, ()>(
201                            &mut *conn,
202                            "scheduled_tasks",
203                            &serialized_task,
204                        )
205                        .await
206                        {
207                            Ok(_) => {
208                                processed_count += 1;
209
210                                #[cfg(feature = "tracing")]
211                                tracing::info!(
212                                    task_id = %scheduled_task.id,
213                                    task_name = %scheduled_task.task_name,
214                                    queue = %scheduled_task.queue,
215                                    delay_from_scheduled_seconds = (now - scheduled_task.execute_at.timestamp()),
216                                    payload_size_bytes = serialized_wrapper.len(),
217                                    "Scheduled task moved to regular queue successfully"
218                                );
219                            }
220                            Err(e) => {
221                                failed_count += 1;
222                                #[cfg(feature = "tracing")]
223                                tracing::error!(
224                                    task_id = %scheduled_task.id,
225                                    task_name = %scheduled_task.task_name,
226                                    error = %e,
227                                    "Failed to remove task from scheduled tasks set"
228                                );
229                            }
230                        }
231                    }
232                    Err(e) => {
233                        failed_count += 1;
234                        #[cfg(feature = "tracing")]
235                        tracing::error!(
236                            task_id = %scheduled_task.id,
237                            task_name = %scheduled_task.task_name,
238                            queue = %scheduled_task.queue,
239                            error = %e,
240                            "Failed to push scheduled task to regular queue"
241                        );
242                    }
243                }
244            } else {
245                failed_count += 1;
246                #[cfg(feature = "tracing")]
247                tracing::error!(
248                    payload_size_bytes = serialized_task.len(),
249                    "Failed to deserialize scheduled task"
250                );
251            }
252        }
253
254        #[cfg(feature = "tracing")]
255        tracing::info!(
256            total_processed = processed_count,
257            failed_count = failed_count,
258            batch_size = batch_size,
259            duration_ms = process_start.elapsed().as_millis(),
260            queue_distribution = ?queue_distribution,
261            success_rate = if batch_size > 0 { processed_count as f64 / batch_size as f64 } else { 0.0 },
262            "Scheduled task processing batch completed"
263        );
264
265        Ok(())
266    }
267}
268
269#[derive(serde::Serialize, serde::Deserialize)]
270struct ScheduledTask {
271    id: TaskId,
272    queue: String,
273    execute_at: DateTime<Utc>,
274    task_name: String,
275    max_retries: u32,
276    timeout_seconds: u64,
277    payload: Vec<u8>,
278}