sidekiq/middleware.rs

1use super::Result;
2use crate::{Counter, Job, RedisPool, RetryOpts, UnitOfWork, WorkerRef};
3use async_trait::async_trait;
4use std::sync::Arc;
5use tokio::sync::RwLock;
6use tracing::error;
7
8#[async_trait]
9pub trait ServerMiddleware {
10    async fn call(
11        &self,
12        iter: ChainIter,
13        job: &Job,
14        worker: Arc<WorkerRef>,
15        redis: RedisPool,
16    ) -> Result<()>;
17}
18
19/// A pseudo iterator used to know which middleware should be called next.
20/// This is created by the Chain type.
21#[derive(Clone)]
22pub struct ChainIter {
23    stack: Arc<RwLock<Vec<Box<dyn ServerMiddleware + Send + Sync>>>>,
24    index: usize,
25}
26
27impl ChainIter {
28    #[inline]
29    pub async fn next(&self, job: &Job, worker: Arc<WorkerRef>, redis: RedisPool) -> Result<()> {
30        let stack = self.stack.read().await;
31
32        if let Some(middleware) = stack.get(self.index) {
33            middleware
34                .call(
35                    ChainIter {
36                        stack: self.stack.clone(),
37                        index: self.index + 1,
38                    },
39                    job,
40                    worker,
41                    redis,
42                )
43                .await?;
44        }
45
46        Ok(())
47    }
48}
49
50/// A chain of middlewares that will be called in order by different server middlewares.
51#[derive(Clone)]
52pub(crate) struct Chain {
53    stack: Arc<RwLock<Vec<Box<dyn ServerMiddleware + Send + Sync>>>>,
54}
55
56impl Chain {
57    // Testing helper to get an empty chain.
58    #[allow(dead_code)]
59    pub(crate) fn empty() -> Self {
60        Self {
61            stack: Arc::new(RwLock::new(vec![])),
62        }
63    }
64
65    pub(crate) fn new_with_stats(counter: Counter) -> Self {
66        Self {
67            stack: Arc::new(RwLock::new(vec![
68                Box::new(RetryMiddleware),
69                Box::new(StatsMiddleware::new(counter)),
70                Box::new(HandlerMiddleware),
71            ])),
72        }
73    }
74
75    pub(crate) async fn using(&mut self, middleware: Box<dyn ServerMiddleware + Send + Sync>) {
76        let mut stack = self.stack.write().await;
77        // HACK: Insert after retry middleware but before the handler middleware.
78        let index = if stack.is_empty() { 0 } else { stack.len() - 1 };
79
80        stack.insert(index, middleware);
81    }
82
83    #[inline]
84    pub(crate) fn iter(&self) -> ChainIter {
85        ChainIter {
86            stack: self.stack.clone(),
87            index: 0,
88        }
89    }
90
91    #[inline]
92    pub(crate) async fn call(
93        &mut self,
94        job: &Job,
95        worker: Arc<WorkerRef>,
96        redis: RedisPool,
97    ) -> Result<()> {
98        // The middleware must call bottom of the stack to the top.
99        // Each middleware should receive a lambda to the next middleware
100        // up the stack. Each middleware can short-circuit the stack by
101        // not calling the "next" middleware.
102        self.iter().next(job, worker, redis).await
103    }
104}
105
106pub struct StatsMiddleware {
107    busy_count: Counter,
108}
109
110impl StatsMiddleware {
111    fn new(busy_count: Counter) -> Self {
112        Self { busy_count }
113    }
114}
115
116#[async_trait]
117impl ServerMiddleware for StatsMiddleware {
118    #[inline]
119    async fn call(
120        &self,
121        chain: ChainIter,
122        job: &Job,
123        worker: Arc<WorkerRef>,
124        redis: RedisPool,
125    ) -> Result<()> {
126        self.busy_count.incrby(1);
127        let res = chain.next(job, worker, redis).await;
128        self.busy_count.decrby(1);
129        res
130    }
131}
132
133struct HandlerMiddleware;
134
135#[async_trait]
136impl ServerMiddleware for HandlerMiddleware {
137    #[inline]
138    async fn call(
139        &self,
140        _chain: ChainIter,
141        job: &Job,
142        worker: Arc<WorkerRef>,
143        _redis: RedisPool,
144    ) -> Result<()> {
145        worker.call(job.args.clone()).await
146    }
147}
148
149struct RetryMiddleware;
150
151#[async_trait]
152impl ServerMiddleware for RetryMiddleware {
153    #[inline]
154    async fn call(
155        &self,
156        chain: ChainIter,
157        job: &Job,
158        worker: Arc<WorkerRef>,
159        redis: RedisPool,
160    ) -> Result<()> {
161        // Check the job for a max retries N in the retry field and then fall
162        // back to the worker default max retries.
163        let max_retries = if let RetryOpts::Max(max_retries) = job.retry {
164            max_retries
165        } else {
166            worker.max_retries()
167        };
168
169        let err = {
170            match chain.next(job, worker, redis.clone()).await {
171                Ok(()) => return Ok(()),
172                Err(err) => format!("{err:?}"),
173            }
174        };
175
176        let mut job = job.clone();
177
178        // Update error fields on the job.
179        job.error_message = Some(err);
180        if job.retry_count.is_some() {
181            job.retried_at = Some(chrono::Utc::now().timestamp() as f64);
182        } else {
183            job.failed_at = Some(chrono::Utc::now().timestamp() as f64);
184        }
185        let retry_count = job.retry_count.unwrap_or(0) + 1;
186        job.retry_count = Some(retry_count);
187
188        // Attempt the retry.
189        if retry_count > max_retries || job.retry == RetryOpts::Never {
190            error!({
191                "status" = "fail",
192                "class" = &job.class,
193                "jid" = &job.jid,
194                "queue" = &job.queue,
195                "err" = &job.error_message
196            }, "Max retries exceeded, will not reschedule job");
197        } else {
198            error!({
199                "status" = "fail",
200                "class" = &job.class,
201                "jid" = &job.jid,
202                "queue" = &job.queue,
203                "retry_queue" = &job.retry_queue,
204                "err" = &job.error_message
205            }, "Scheduling job for retry in the future");
206
207            // We will now make sure we use the new retry_queue option if set.
208            if let Some(ref retry_queue) = job.retry_queue {
209                job.queue = retry_queue.into();
210            }
211
212            UnitOfWork::from_job(job).reenqueue(&redis).await?;
213        }
214
215        Ok(())
216    }
217}
218
#[cfg(test)]
mod test {
    use super::*;
    use crate::{RedisConnectionManager, RedisPool, RetryOpts, Worker};
    use bb8::Pool;
    use tokio::sync::Mutex;

    /// Build a connection pool for a Redis server on localhost.
    async fn redis() -> RedisPool {
        let manager = RedisConnectionManager::new("redis://127.0.0.1/").unwrap();
        let pool = Pool::builder().build(manager).await.unwrap();
        pool
    }

    /// A fresh, never-failed `TestWorker` job on the default queue.
    fn job() -> Job {
        Job {
            jid: crate::new_jid(),
            class: "TestWorker".into(),
            queue: "default".into(),
            args: vec![1337].into(),
            retry: RetryOpts::Yes,
            created_at: 1337.0,
            enqueued_at: None,
            failed_at: None,
            retried_at: None,
            retry_count: None,
            retry_queue: None,
            error_message: None,
            error_class: None,
            unique_for: None,
        }
    }

    /// Worker that records whether `perform` was ever invoked.
    #[derive(Clone)]
    struct TestWorker {
        touched: Arc<Mutex<bool>>,
    }

    #[async_trait]
    impl Worker<()> for TestWorker {
        async fn perform(&self, _args: ()) -> Result<()> {
            let mut flag = self.touched.lock().await;
            *flag = true;
            Ok(())
        }
    }

    #[tokio::test]
    async fn calls_through_a_middleware_stack() {
        let touched = Arc::new(Mutex::new(false));
        let inner = Arc::new(TestWorker {
            touched: Arc::clone(&touched),
        });
        let worker = Arc::new(WorkerRef::wrap(Arc::clone(&inner)));
        let job = job();

        let mut chain = Chain::empty();
        chain.using(Box::new(HandlerMiddleware)).await;

        let pool = redis().await;
        chain.call(&job, Arc::clone(&worker), pool).await.unwrap();

        let was_touched = *touched.lock().await;
        assert!(was_touched, "The job was processed by the middleware");
    }
}