use crate::{RedisBroker, Task, TaskId, TaskQueueError};
use chrono::{DateTime, Duration, Utc};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::task::JoinHandle;

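/// Moves delayed tasks from the `scheduled_tasks` sorted set in Redis onto
/// their destination queues once their scheduled execution time has arrived.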
pub struct TaskScheduler {
    broker: Arc<RedisBroker>,
    tick_interval_seconds: u64,
}

impl TaskScheduler {
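    /// Creates a scheduler that checks for due tasks once per second.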
    pub fn new(broker: Arc<RedisBroker>) -> Self {
        Self {
            broker,
            tick_interval_seconds: 1,
        }
    }

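    /// Creates a scheduler with a custom tick interval, in seconds.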
    pub fn with_tick_interval(broker: Arc<RedisBroker>, interval_seconds: u64) -> Self {
        Self {
            broker,
            tick_interval_seconds: interval_seconds,
        }
    }

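    /// Spawns a background Tokio task that calls `process_scheduled_tasks` on
    /// every tick (1 second by default) and keeps running even if a tick fails.
    ///
    /// A minimal usage sketch (assumes a connected `RedisBroker` is already
    /// available as `broker`; not compiled here):
    ///
    /// ```ignore
    /// let handle = TaskScheduler::start(broker.clone())?;
    /// // Stop the scheduler later by aborting the spawned task.
    /// handle.abort();
    /// ```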
    pub fn start(broker: Arc<RedisBroker>) -> Result<JoinHandle<()>, TaskQueueError> {
        let scheduler = Self::new(broker);
        let tick_interval = scheduler.tick_interval_seconds;
        let broker = scheduler.broker;

        let handle = tokio::spawn(async move {
            loop {
                if let Err(e) = Self::process_scheduled_tasks(&broker).await {
                    #[cfg(feature = "tracing")]
                    tracing::error!("Scheduler error: {}", e);
                }

                tokio::time::sleep(tokio::time::Duration::from_secs(tick_interval)).await;
            }
        });

        #[cfg(feature = "tracing")]
        tracing::info!(
            "Task scheduler started with {} second tick interval",
            tick_interval
        );

        Ok(handle)
    }

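    /// Serializes the task and adds it to the `scheduled_tasks` sorted set,
    /// scored by its execution timestamp, so a running scheduler can move it
    /// to `queue` once `delay` has elapsed.
    ///
    /// A minimal usage sketch (assumes `MyTask` implements `Task`; the type
    /// name is illustrative, not part of this crate):
    ///
    /// ```ignore
    /// let task_id = scheduler
    ///     .schedule_task(MyTask::default(), "emails", Duration::minutes(5))
    ///     .await?;
    /// ```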
    pub async fn schedule_task<T: Task>(
        &self,
        task: T,
        queue: &str,
        delay: Duration,
    ) -> Result<TaskId, TaskQueueError> {
        let schedule_start = std::time::Instant::now();
        let execute_at = Utc::now() + delay;
        let task_id = uuid::Uuid::new_v4();
        let task_name = task.name();

        #[cfg(feature = "tracing")]
        tracing::info!(
            task_id = %task_id,
            task_name = task_name,
            queue = queue,
            delay_seconds = delay.num_seconds(),
            execute_at = %execute_at,
            "Scheduling task for delayed execution"
        );

        let scheduled_task = ScheduledTask {
            id: task_id,
            queue: queue.to_string(),
            execute_at,
            task_name: task.name().to_string(),
            max_retries: task.max_retries(),
            timeout_seconds: task.timeout_seconds(),
            payload: rmp_serde::to_vec(&task)?,
        };

        let mut conn = self.broker.pool.get().await.map_err(|e| {
            TaskQueueError::Connection(format!("Failed to get Redis connection: {}", e))
        })?;
        let serialized = rmp_serde::to_vec(&scheduled_task)?;

        redis::AsyncCommands::zadd::<_, _, _, ()>(
            &mut *conn,
            "scheduled_tasks",
            serialized.clone(),
            execute_at.timestamp(),
        )
        .await?;

        #[cfg(feature = "tracing")]
        tracing::info!(
            task_id = %task_id,
            task_name = task_name,
            queue = queue,
            execute_at = %execute_at,
            delay_seconds = delay.num_seconds(),
            duration_ms = schedule_start.elapsed().as_millis(),
            payload_size_bytes = serialized.len(),
            redis_score = execute_at.timestamp(),
            "Task scheduled successfully"
        );

        Ok(task_id)
    }

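    /// Fetches up to 100 tasks whose scheduled time has passed, pushes each
    /// onto its destination queue, and then removes it from the scheduled set.
    ///
    /// The push and removal are two separate Redis commands, so if the removal
    /// fails after a successful push the task may be delivered again on a
    /// later tick. Entries that fail to deserialize are counted as failures
    /// but left in the set.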
    pub async fn process_scheduled_tasks(broker: &RedisBroker) -> Result<(), TaskQueueError> {
        let process_start = std::time::Instant::now();
        let mut conn = broker.pool.get().await.map_err(|e| {
            TaskQueueError::Connection(format!("Failed to get Redis connection: {}", e))
        })?;
        let now = Utc::now().timestamp();

        #[cfg(feature = "tracing")]
        tracing::debug!(
            timestamp_threshold = now,
            "Processing scheduled tasks ready for execution"
        );

        let tasks: Vec<Vec<u8>> = redis::AsyncCommands::zrangebyscore_limit(
            &mut *conn,
            "scheduled_tasks",
            0,
            now,
            0,
            100,
        )
        .await?;

        let batch_size = tasks.len();

        if batch_size == 0 {
            #[cfg(feature = "tracing")]
            tracing::trace!(
                duration_ms = process_start.elapsed().as_millis(),
                "No scheduled tasks ready for execution"
            );
            return Ok(());
        }

        #[cfg(feature = "tracing")]
        tracing::info!(
            batch_size = batch_size,
            timestamp_threshold = now,
            "Found scheduled tasks ready for execution"
        );

        let mut processed_count = 0;
        let mut failed_count = 0;
        let mut queue_distribution = HashMap::new();

        for serialized_task in tasks {
            if let Ok(scheduled_task) = rmp_serde::from_slice::<ScheduledTask>(&serialized_task) {
                *queue_distribution
                    .entry(scheduled_task.queue.clone())
                    .or_insert(0) += 1;

                #[cfg(feature = "tracing")]
                tracing::debug!(
                    task_id = %scheduled_task.id,
                    task_name = %scheduled_task.task_name,
                    queue = %scheduled_task.queue,
                    scheduled_execute_at = %scheduled_task.execute_at,
                    delay_from_scheduled = (now - scheduled_task.execute_at.timestamp()),
                    "Moving scheduled task to regular queue"
                );

                let task_wrapper = crate::TaskWrapper {
                    metadata: crate::TaskMetadata {
                        id: scheduled_task.id,
                        name: scheduled_task.task_name.clone(),
                        created_at: Utc::now(),
                        attempts: 0,
                        max_retries: scheduled_task.max_retries,
                        timeout_seconds: scheduled_task.timeout_seconds,
                    },
                    payload: scheduled_task.payload,
                };

                let serialized_wrapper = rmp_serde::to_vec(&task_wrapper)?;

                match redis::AsyncCommands::lpush::<_, _, ()>(
                    &mut *conn,
                    &scheduled_task.queue,
                    &serialized_wrapper,
                )
                .await
                {
                    Ok(_) => {
                        match redis::AsyncCommands::zrem::<_, _, ()>(
                            &mut *conn,
                            "scheduled_tasks",
                            &serialized_task,
                        )
                        .await
                        {
                            Ok(_) => {
                                processed_count += 1;

                                #[cfg(feature = "tracing")]
                                tracing::info!(
                                    task_id = %scheduled_task.id,
                                    task_name = %scheduled_task.task_name,
                                    queue = %scheduled_task.queue,
                                    delay_from_scheduled_seconds = (now - scheduled_task.execute_at.timestamp()),
                                    payload_size_bytes = serialized_wrapper.len(),
                                    "Scheduled task moved to regular queue successfully"
                                );
                            }
                            Err(e) => {
                                failed_count += 1;
                                #[cfg(feature = "tracing")]
                                tracing::error!(
                                    task_id = %scheduled_task.id,
                                    task_name = %scheduled_task.task_name,
                                    error = %e,
                                    "Failed to remove task from scheduled tasks set"
                                );
                            }
                        }
                    }
                    Err(e) => {
                        failed_count += 1;
                        #[cfg(feature = "tracing")]
                        tracing::error!(
                            task_id = %scheduled_task.id,
                            task_name = %scheduled_task.task_name,
                            queue = %scheduled_task.queue,
                            error = %e,
                            "Failed to push scheduled task to regular queue"
                        );
                    }
                }
            } else {
                failed_count += 1;
                #[cfg(feature = "tracing")]
                tracing::error!(
                    payload_size_bytes = serialized_task.len(),
                    "Failed to deserialize scheduled task"
                );
            }
        }

        #[cfg(feature = "tracing")]
        tracing::info!(
            total_processed = processed_count,
            failed_count = failed_count,
            batch_size = batch_size,
            duration_ms = process_start.elapsed().as_millis(),
            queue_distribution = ?queue_distribution,
            success_rate = if batch_size > 0 { processed_count as f64 / batch_size as f64 } else { 0.0 },
            "Scheduled task processing batch completed"
        );

        Ok(())
    }
}

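/// Internal record for a delayed task, stored in the `scheduled_tasks` sorted
/// set as a MessagePack-serialized blob scored by its execution timestamp.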
#[derive(serde::Serialize, serde::Deserialize)]
struct ScheduledTask {
    id: TaskId,
    queue: String,
    execute_at: DateTime<Utc>,
    task_name: String,
    max_retries: u32,
    timeout_seconds: u64,
    payload: Vec<u8>,
}