1use std::sync::Arc;
4use std::time::Duration;
5use tokio::sync::Semaphore;
6use tracing::{error, info, warn};
7
8use crate::backend::QueueBackend;
9use crate::job::{Job, JobResult, JobStatus};
10
/// Tuning parameters for a worker pool.
///
/// `Debug` is derived so configurations can be logged/inspected; the type is
/// small and `Copy`, so it is passed by value.
#[derive(Clone, Copy, Debug)]
pub struct WorkerConfig {
    /// Maximum number of jobs executing concurrently (semaphore permits).
    pub max_concurrency: usize,
    /// How long to sleep between dequeue attempts when the queue is empty.
    pub poll_interval: Duration,
}
16
17impl Default for WorkerConfig {
18 fn default() -> Self {
19 Self {
20 max_concurrency: 5,
21 poll_interval: Duration::from_millis(100),
22 }
23 }
24}
25
26use serde::de::DeserializeOwned;
27
28use std::sync::RwLock;
29
/// Pulls entries from a [`QueueBackend`] and dispatches them to registered
/// job handlers with bounded concurrency.
///
/// `B` may be unsized (`?Sized`) so the pool can hold `Arc<dyn QueueBackend>`.
pub struct WorkerPool<B: QueueBackend + ?Sized> {
    /// Shared handle to the queue backend (publicly accessible).
    pub backend: Arc<B>,
    // Concurrency/polling settings captured at construction time.
    config: WorkerConfig,
    // Maps job-type names to factories producing `Box<dyn Job>`; shared
    // with every spawned worker task.
    registry: Arc<JobRegistry>,
}
35
/// Builds a boxed job instance from its JSON payload.
type JobFactory = Box<dyn Fn(serde_json::Value) -> Box<dyn Job> + Send + Sync>;

/// Registry mapping job-type names to their factories.
struct JobRegistry {
    // RwLock: worker tasks take short read locks per dequeued entry;
    // write locks are only taken during registration.
    factories: RwLock<std::collections::HashMap<String, JobFactory>>,
}
41
42impl<B: QueueBackend + 'static> WorkerPool<B> {
43 pub fn new(backend: B, config: WorkerConfig) -> Self {
44 Self::new_with_arc(Arc::new(backend), config)
45 }
46}
47
48impl<B: QueueBackend + ?Sized + 'static> WorkerPool<B> {
49 pub fn new_with_arc(backend: Arc<B>, config: WorkerConfig) -> Self {
51 Self {
52 backend,
53 config,
54 registry: Arc::new(JobRegistry {
55 factories: RwLock::new(std::collections::HashMap::new()),
56 }),
57 }
58 }
59
60 pub fn register_job_type<J: Job + DeserializeOwned + 'static>(&self, name: &str) {
62 let factory = Box::new(|payload: serde_json::Value| {
63 let job: J =
64 serde_json::from_value(payload).expect("Job payload deserialization failed");
65 Box::new(job) as Box<dyn Job>
66 });
67
68 self.registry
69 .factories
70 .write()
71 .unwrap()
72 .insert(name.to_string(), factory);
73 }
74
75 pub fn register_job_factory<F>(&self, name: &str, factory: F)
77 where
78 F: Fn(serde_json::Value) -> Box<dyn Job> + Send + Sync + 'static,
79 {
80 self.registry
81 .factories
82 .write()
83 .unwrap()
84 .insert(name.to_string(), Box::new(factory));
85 }
86
87 pub async fn start(&self) {
88 let semaphore = Arc::new(Semaphore::new(self.config.max_concurrency));
89
90 info!(
91 "Worker pool started with concurrency {}",
92 self.config.max_concurrency
93 );
94
95 loop {
96 if semaphore.available_permits() > 0 {
97 match self.backend.dequeue().await {
98 Ok(Some(entry)) => {
99 let permit = semaphore.clone().acquire_owned().await.unwrap();
100 let backend = self.backend.clone();
101 let registry = self.registry.clone();
102
103 tokio::spawn(async move {
104 let job_opt = {
105 let factories = registry.factories.read().unwrap();
106 factories
107 .get(&entry.job_type)
108 .map(|f| f(entry.payload.clone()))
109 };
110
111 if let Some(mut job) = job_opt {
112 info!("Processing job {} ({})", entry.id, entry.job_type);
113
114 let result = job.execute().await;
115
116 match result {
117 JobResult::Success => {
118 let _ = backend
119 .update_status(
120 entry.id,
121 JobStatus::Completed,
122 None,
123 None,
124 )
125 .await;
126 }
127 JobResult::Retry(e) => {
128 let delay = job.backoff_strategy().delay(entry.attempts);
130 let delay_secs = delay.as_secs();
131
132 info!(
133 job_id = %entry.id,
134 attempt = entry.attempts + 1,
135 delay_secs = delay_secs,
136 "Job failed, scheduling retry with backoff"
137 );
138
139 let _ = backend
140 .update_status(
141 entry.id,
142 JobStatus::Failed(entry.attempts + 1),
143 Some(e),
144 Some(delay_secs),
145 )
146 .await;
147 }
148 JobResult::Fatal(e) => {
149 let _ = backend
150 .update_status(
151 entry.id,
152 JobStatus::DeadLetter,
153 Some(e),
154 None,
155 )
156 .await;
157 }
158 }
159 } else {
160 warn!("No handler registered for job type: {}", entry.job_type);
161 let _ = backend
162 .update_status(
163 entry.id,
164 JobStatus::DeadLetter,
165 Some(format!("No handler for {}", entry.job_type)),
166 None,
167 )
168 .await;
169 }
170
171 drop(permit);
172 });
173 }
174 Ok(None) => {
175 tokio::time::sleep(self.config.poll_interval).await;
177 }
178 Err(e) => {
179 error!("Queue error: {}", e);
180 tokio::time::sleep(Duration::from_secs(1)).await;
181 }
182 }
183 } else {
184 tokio::time::sleep(Duration::from_millis(50)).await;
186 }
187 }
188 }
189}