1use std::sync::Arc;
4use std::time::Duration;
5use tokio::sync::Semaphore;
6use tracing::{error, info, warn};
7
8use crate::backend::QueueBackend;
9use crate::job::{Job, JobResult, JobStatus};
10
/// Tuning knobs for a [`WorkerPool`].
///
/// `Debug` and `PartialEq`/`Eq` are derived so configurations can be logged
/// and compared (e.g. in tests or when reloading settings).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct WorkerConfig {
    /// Maximum number of jobs executed at the same time. With `0`, no job is
    /// ever started.
    pub max_concurrency: usize,
    /// How long the pool sleeps after finding the queue empty before polling again.
    pub poll_interval: Duration,
}
16
17impl Default for WorkerConfig {
18 fn default() -> Self {
19 Self {
20 max_concurrency: 5,
21 poll_interval: Duration::from_millis(100),
22 }
23 }
24}
25
26use serde::de::DeserializeOwned;
27
28use std::sync::RwLock;
29
30pub struct WorkerPool<B: QueueBackend + ?Sized> {
31 pub backend: Arc<B>,
32 config: WorkerConfig,
33 registry: Arc<JobRegistry>,
34}
35
36type JobFactory =
37 Box<dyn Fn(serde_json::Value) -> Result<Box<dyn Job>, serde_json::Error> + Send + Sync>;
38
39struct JobRegistry {
40 factories: RwLock<std::collections::HashMap<String, JobFactory>>,
41}
42
43impl<B: QueueBackend + 'static> WorkerPool<B> {
44 pub fn new(backend: B, config: WorkerConfig) -> Self {
45 Self::new_with_arc(Arc::new(backend), config)
46 }
47}
48
49impl<B: QueueBackend + ?Sized + 'static> WorkerPool<B> {
50 pub fn new_with_arc(backend: Arc<B>, config: WorkerConfig) -> Self {
52 Self {
53 backend,
54 config,
55 registry: Arc::new(JobRegistry {
56 factories: RwLock::new(std::collections::HashMap::new()),
57 }),
58 }
59 }
60
61 pub fn register_job_type<J: Job + DeserializeOwned + 'static>(&self, name: &str) {
63 let factory = Box::new(|payload: serde_json::Value| {
64 let job: J = serde_json::from_value(payload)?;
65 Ok(Box::new(job) as Box<dyn Job>)
66 });
67
68 self.registry
69 .factories
70 .write()
71 .expect("Job registry RwLock poisoned")
72 .insert(name.to_string(), factory);
73 }
74
75 pub fn register_job_factory<F>(&self, name: &str, factory: F)
77 where
78 F: Fn(serde_json::Value) -> Box<dyn Job> + Send + Sync + 'static,
79 {
80 self.registry
81 .factories
82 .write()
83 .expect("Job registry RwLock poisoned")
84 .insert(
85 name.to_string(),
86 Box::new(move |payload| Ok(factory(payload))),
87 );
88 }
89
90 pub async fn start(&self) {
91 let semaphore = Arc::new(Semaphore::new(self.config.max_concurrency));
92
93 info!(
94 "Worker pool started with concurrency {}",
95 self.config.max_concurrency
96 );
97
98 loop {
99 if semaphore.available_permits() > 0 {
100 match self.backend.dequeue().await {
101 Ok(Some(entry)) => {
102 let permit = semaphore
103 .clone()
104 .acquire_owned()
105 .await
106 .expect("Worker semaphore closed unexpectedly");
107 let backend = self.backend.clone();
108 let registry = self.registry.clone();
109
110 tokio::spawn(async move {
111 let job_opt = {
112 let factories = registry
113 .factories
114 .read()
115 .expect("Job registry RwLock poisoned");
116 factories
117 .get(&entry.job_type)
118 .map(|f| f(entry.payload.clone()))
119 };
120
121 match job_opt {
122 Some(Ok(mut job)) => {
123 info!("Processing job {} ({})", entry.id, entry.job_type);
124
125 let result = job.execute().await;
126
127 match result {
128 JobResult::Success(value) => {
129 if let Some(val) = value {
130 let _ = backend.set_result(entry.id, val).await;
131 }
132 let _ = backend
133 .update_status(
134 entry.id,
135 JobStatus::Completed,
136 None,
137 None,
138 )
139 .await;
140 }
141 JobResult::Retry(e) => {
142 let delay =
143 job.backoff_strategy().delay(entry.attempts);
144 let delay_secs = delay.as_secs();
145
146 info!(
147 job_id = %entry.id,
148 attempt = entry.attempts + 1,
149 delay_secs = delay_secs,
150 "Job failed, scheduling retry with backoff"
151 );
152
153 let _ = backend
154 .update_status(
155 entry.id,
156 JobStatus::Failed(entry.attempts + 1),
157 Some(e),
158 Some(delay_secs),
159 )
160 .await;
161 }
162 JobResult::Fatal(e) => {
163 let _ = backend
164 .update_status(
165 entry.id,
166 JobStatus::DeadLetter,
167 Some(e),
168 None,
169 )
170 .await;
171 }
172 }
173 }
174 Some(Err(e)) => {
175 error!(job_id = %entry.id, error = %e, "Job payload deserialization failed");
176 let _ = backend
177 .update_status(
178 entry.id,
179 JobStatus::DeadLetter,
180 Some(e.to_string()),
181 None,
182 )
183 .await;
184 }
185 None => {
186 warn!("No handler registered for job type: {}", entry.job_type);
187 let _ = backend
188 .update_status(
189 entry.id,
190 JobStatus::DeadLetter,
191 Some(format!("No handler for {}", entry.job_type)),
192 None,
193 )
194 .await;
195 }
196 }
197
198 drop(permit);
199 });
200 }
201 Ok(None) => {
202 tokio::time::sleep(self.config.poll_interval).await;
204 }
205 Err(e) => {
206 error!("Queue error: {}", e);
207 tokio::time::sleep(Duration::from_secs(1)).await;
208 }
209 }
210 } else {
211 tokio::time::sleep(Duration::from_millis(50)).await;
213 }
214 }
215 }
216}