1use super::Result;
2use crate::{Counter, Job, RedisPool, RetryOpts, UnitOfWork, WorkerRef};
3use async_trait::async_trait;
4use std::sync::Arc;
5use tokio::sync::RwLock;
6use tracing::error;
7
8#[async_trait]
9pub trait ServerMiddleware {
10 async fn call(
11 &self,
12 iter: ChainIter,
13 job: &Job,
14 worker: Arc<WorkerRef>,
15 redis: RedisPool,
16 ) -> Result<()>;
17}
18
19#[derive(Clone)]
22pub struct ChainIter {
23 stack: Arc<RwLock<Vec<Box<dyn ServerMiddleware + Send + Sync>>>>,
24 index: usize,
25}
26
27impl ChainIter {
28 #[inline]
29 pub async fn next(&self, job: &Job, worker: Arc<WorkerRef>, redis: RedisPool) -> Result<()> {
30 let stack = self.stack.read().await;
31
32 if let Some(middleware) = stack.get(self.index) {
33 middleware
34 .call(
35 ChainIter {
36 stack: self.stack.clone(),
37 index: self.index + 1,
38 },
39 job,
40 worker,
41 redis,
42 )
43 .await?;
44 }
45
46 Ok(())
47 }
48}
49
50#[derive(Clone)]
52pub(crate) struct Chain {
53 stack: Arc<RwLock<Vec<Box<dyn ServerMiddleware + Send + Sync>>>>,
54}
55
56impl Chain {
57 #[allow(dead_code)]
59 pub(crate) fn empty() -> Self {
60 Self {
61 stack: Arc::new(RwLock::new(vec![])),
62 }
63 }
64
65 pub(crate) fn new_with_stats(counter: Counter) -> Self {
66 Self {
67 stack: Arc::new(RwLock::new(vec![
68 Box::new(RetryMiddleware),
69 Box::new(StatsMiddleware::new(counter)),
70 Box::new(HandlerMiddleware),
71 ])),
72 }
73 }
74
75 pub(crate) async fn using(&mut self, middleware: Box<dyn ServerMiddleware + Send + Sync>) {
76 let mut stack = self.stack.write().await;
77 let index = if stack.is_empty() { 0 } else { stack.len() - 1 };
79
80 stack.insert(index, middleware);
81 }
82
83 #[inline]
84 pub(crate) fn iter(&self) -> ChainIter {
85 ChainIter {
86 stack: self.stack.clone(),
87 index: 0,
88 }
89 }
90
91 #[inline]
92 pub(crate) async fn call(
93 &mut self,
94 job: &Job,
95 worker: Arc<WorkerRef>,
96 redis: RedisPool,
97 ) -> Result<()> {
98 self.iter().next(job, worker, redis).await
103 }
104}
105
106pub struct StatsMiddleware {
107 busy_count: Counter,
108}
109
110impl StatsMiddleware {
111 fn new(busy_count: Counter) -> Self {
112 Self { busy_count }
113 }
114}
115
116#[async_trait]
117impl ServerMiddleware for StatsMiddleware {
118 #[inline]
119 async fn call(
120 &self,
121 chain: ChainIter,
122 job: &Job,
123 worker: Arc<WorkerRef>,
124 redis: RedisPool,
125 ) -> Result<()> {
126 self.busy_count.incrby(1);
127 let res = chain.next(job, worker, redis).await;
128 self.busy_count.decrby(1);
129 res
130 }
131}
132
133struct HandlerMiddleware;
134
135#[async_trait]
136impl ServerMiddleware for HandlerMiddleware {
137 #[inline]
138 async fn call(
139 &self,
140 _chain: ChainIter,
141 job: &Job,
142 worker: Arc<WorkerRef>,
143 _redis: RedisPool,
144 ) -> Result<()> {
145 worker.call(job.args.clone()).await
146 }
147}
148
149struct RetryMiddleware;
150
151#[async_trait]
152impl ServerMiddleware for RetryMiddleware {
153 #[inline]
154 async fn call(
155 &self,
156 chain: ChainIter,
157 job: &Job,
158 worker: Arc<WorkerRef>,
159 redis: RedisPool,
160 ) -> Result<()> {
161 let max_retries = if let RetryOpts::Max(max_retries) = job.retry {
164 max_retries
165 } else {
166 worker.max_retries()
167 };
168
169 let err = {
170 match chain.next(job, worker, redis.clone()).await {
171 Ok(()) => return Ok(()),
172 Err(err) => format!("{err:?}"),
173 }
174 };
175
176 let mut job = job.clone();
177
178 job.error_message = Some(err);
180 if job.retry_count.is_some() {
181 job.retried_at = Some(chrono::Utc::now().timestamp() as f64);
182 } else {
183 job.failed_at = Some(chrono::Utc::now().timestamp() as f64);
184 }
185 let retry_count = job.retry_count.unwrap_or(0) + 1;
186 job.retry_count = Some(retry_count);
187
188 if retry_count > max_retries || job.retry == RetryOpts::Never {
190 error!({
191 "status" = "fail",
192 "class" = &job.class,
193 "jid" = &job.jid,
194 "queue" = &job.queue,
195 "err" = &job.error_message
196 }, "Max retries exceeded, will not reschedule job");
197 } else {
198 error!({
199 "status" = "fail",
200 "class" = &job.class,
201 "jid" = &job.jid,
202 "queue" = &job.queue,
203 "retry_queue" = &job.retry_queue,
204 "err" = &job.error_message
205 }, "Scheduling job for retry in the future");
206
207 if let Some(ref retry_queue) = job.retry_queue {
209 job.queue = retry_queue.into();
210 }
211
212 UnitOfWork::from_job(job).reenqueue(&redis).await?;
213 }
214
215 Ok(())
216 }
217}
218
#[cfg(test)]
mod test {
    use super::*;
    use crate::{RedisConnectionManager, RedisPool, RetryOpts, Worker};
    use bb8::Pool;
    use tokio::sync::Mutex;

    /// Builds a connection pool for the Redis server expected on localhost.
    async fn redis() -> RedisPool {
        let url = "redis://127.0.0.1/";
        let manager = RedisConnectionManager::new(url).unwrap();
        let built = Pool::builder().build(manager).await;
        built.unwrap()
    }

    /// A fresh `TestWorker` job with a single argument and `RetryOpts::Yes`.
    fn job() -> Job {
        let jid = crate::new_jid();

        Job {
            jid,
            class: "TestWorker".into(),
            queue: "default".into(),
            args: vec![1337].into(),
            retry: RetryOpts::Yes,
            created_at: 1337.0,
            enqueued_at: None,
            failed_at: None,
            error_message: None,
            error_class: None,
            retry_count: None,
            retried_at: None,
            retry_queue: None,
            unique_for: None,
        }
    }

    /// Worker that records whether `perform` was invoked.
    #[derive(Clone)]
    struct TestWorker {
        touched: Arc<Mutex<bool>>,
    }

    #[async_trait]
    impl Worker<()> for TestWorker {
        async fn perform(&self, _args: ()) -> Result<()> {
            let mut touched = self.touched.lock().await;
            *touched = true;
            Ok(())
        }
    }

    #[tokio::test]
    async fn calls_through_a_middleware_stack() {
        let flag = Arc::new(Mutex::new(false));
        let test_worker = Arc::new(TestWorker {
            touched: Arc::clone(&flag),
        });
        let worker = Arc::new(WorkerRef::wrap(test_worker));

        let job = job();
        let mut chain = Chain::empty();
        chain.using(Box::new(HandlerMiddleware)).await;

        let pool = redis().await;
        chain.call(&job, Arc::clone(&worker), pool).await.unwrap();

        let touched = *flag.lock().await;
        assert!(touched, "The job was processed by the middleware");
    }
}