Skip to main content

sidekiq/
middleware.rs

1use super::Result;
2use crate::{Counter, Job, RedisPool, RetryOpts, UnitOfWork, WorkerRef};
3use async_trait::async_trait;
4use std::sync::Arc;
5use tokio::sync::RwLock;
6use tracing::error;
7
8#[async_trait]
9pub trait ServerMiddleware {
10    async fn call(
11        &self,
12        iter: ChainIter,
13        job: &Job,
14        worker: Arc<WorkerRef>,
15        redis: RedisPool,
16    ) -> Result<()>;
17}
18
19/// A pseudo iterator used to know which middleware should be called next.
20/// This is created by the Chain type.
21#[derive(Clone)]
22pub struct ChainIter {
23    stack: Arc<RwLock<Vec<Box<dyn ServerMiddleware + Send + Sync>>>>,
24    index: usize,
25}
26
27impl ChainIter {
28    #[inline]
29    pub async fn next(&self, job: &Job, worker: Arc<WorkerRef>, redis: RedisPool) -> Result<()> {
30        let stack = self.stack.read().await;
31
32        if let Some(middleware) = stack.get(self.index) {
33            middleware
34                .call(
35                    ChainIter {
36                        stack: self.stack.clone(),
37                        index: self.index + 1,
38                    },
39                    job,
40                    worker,
41                    redis,
42                )
43                .await?;
44        }
45
46        Ok(())
47    }
48}
49
50/// A chain of middlewares that will be called in order by different server middlewares.
51#[derive(Clone)]
52pub(crate) struct Chain {
53    stack: Arc<RwLock<Vec<Box<dyn ServerMiddleware + Send + Sync>>>>,
54}
55
56impl Chain {
57    // Testing helper to get an empty chain.
58    #[allow(dead_code)]
59    pub(crate) fn empty() -> Self {
60        Self {
61            stack: Arc::new(RwLock::new(vec![])),
62        }
63    }
64
65    pub(crate) fn new_with_stats(counter: Counter) -> Self {
66        Self {
67            stack: Arc::new(RwLock::new(vec![
68                Box::new(RetryMiddleware),
69                Box::new(StatsMiddleware::new(counter)),
70                Box::new(HandlerMiddleware),
71            ])),
72        }
73    }
74
75    pub(crate) async fn using(&mut self, middleware: Box<dyn ServerMiddleware + Send + Sync>) {
76        let mut stack = self.stack.write().await;
77        // HACK: Insert after retry middleware but before the handler middleware.
78        let index = if stack.is_empty() { 0 } else { stack.len() - 1 };
79
80        stack.insert(index, middleware);
81    }
82
83    #[inline]
84    pub(crate) fn iter(&self) -> ChainIter {
85        ChainIter {
86            stack: self.stack.clone(),
87            index: 0,
88        }
89    }
90
91    #[inline]
92    pub(crate) async fn call(
93        &mut self,
94        job: &Job,
95        worker: Arc<WorkerRef>,
96        redis: RedisPool,
97    ) -> Result<()> {
98        // The middleware must call bottom of the stack to the top.
99        // Each middleware should receive a lambda to the next middleware
100        // up the stack. Each middleware can short-circuit the stack by
101        // not calling the "next" middleware.
102        self.iter().next(job, worker, redis).await
103    }
104}
105
106pub struct StatsMiddleware {
107    busy_count: Counter,
108}
109
110impl StatsMiddleware {
111    fn new(busy_count: Counter) -> Self {
112        Self { busy_count }
113    }
114}
115
116#[async_trait]
117impl ServerMiddleware for StatsMiddleware {
118    #[inline]
119    async fn call(
120        &self,
121        chain: ChainIter,
122        job: &Job,
123        worker: Arc<WorkerRef>,
124        redis: RedisPool,
125    ) -> Result<()> {
126        self.busy_count.incrby(1);
127        let res = chain.next(job, worker, redis).await;
128        self.busy_count.decrby(1);
129        res
130    }
131}
132
133struct HandlerMiddleware;
134
135#[async_trait]
136impl ServerMiddleware for HandlerMiddleware {
137    #[inline]
138    async fn call(
139        &self,
140        _chain: ChainIter,
141        job: &Job,
142        worker: Arc<WorkerRef>,
143        _redis: RedisPool,
144    ) -> Result<()> {
145        worker.call(job.args.clone()).await
146    }
147}
148
149struct RetryMiddleware;
150
151#[async_trait]
152impl ServerMiddleware for RetryMiddleware {
153    #[inline]
154    async fn call(
155        &self,
156        chain: ChainIter,
157        job: &Job,
158        worker: Arc<WorkerRef>,
159        redis: RedisPool,
160    ) -> Result<()> {
161        // Check the job for a max retries N in the retry field and then fall
162        // back to the worker default max retries.
163        let max_retries = if let RetryOpts::Max(max_retries) = job.retry {
164            max_retries
165        } else {
166            worker.max_retries()
167        };
168
169        let err = {
170            match chain.next(job, worker, redis.clone()).await {
171                Ok(()) => return Ok(()),
172                Err(err) => err,
173            }
174        };
175
176        let mut job = job.clone();
177
178        // Update error fields on the job.
179        job.error_message = Some(format!("{err:?}"));
180        job.error_class = Some(match &err {
181            crate::Error::Message(_) => "RuntimeError".to_string(),
182            crate::Error::Json(_) => "JSON::ParserError".to_string(),
183            crate::Error::Redis(_) | crate::Error::BB8(_) => "Redis::BaseError".to_string(),
184            _ => "StandardError".to_string(),
185        });
186        if job.retry_count.is_some() {
187            job.retried_at = Some(chrono::Utc::now().timestamp_millis() as f64);
188        } else {
189            job.failed_at = Some(chrono::Utc::now().timestamp_millis() as f64);
190        }
191        // Match Ruby Sidekiq: retry_count starts at 0 on first failure.
192        let retry_count = if job.retry_count.is_some() {
193            job.retry_count.unwrap_or(0) + 1
194        } else {
195            0
196        };
197        job.retry_count = Some(retry_count);
198
199        // Attempt the retry.
200        if retry_count >= max_retries || job.retry == RetryOpts::Never {
201            error!({
202                "status" = "dead",
203                "class" = &job.class,
204                "jid" = &job.jid,
205                "queue" = &job.queue,
206                "err" = &job.error_message
207            }, "Max retries exceeded, moving job to dead set");
208
209            // Add to the dead set so the job is visible in Sidekiq web UI.
210            // Score is float seconds (matching Ruby's Time.now.to_f).
211            let now = chrono::Utc::now().timestamp_millis() as f64 / 1000.0;
212            if let Err(err) = redis
213                .get()
214                .await?
215                .zadd("dead".to_string(), serde_json::to_string(&job)?, now)
216                .await
217            {
218                error!("Failed to add job to dead set: {:?}", err);
219            }
220        } else {
221            error!({
222                "status" = "fail",
223                "class" = &job.class,
224                "jid" = &job.jid,
225                "queue" = &job.queue,
226                "retry_queue" = &job.retry_queue,
227                "err" = &job.error_message
228            }, "Scheduling job for retry in the future");
229
230            // We will now make sure we use the new retry_queue option if set.
231            if let Some(ref retry_queue) = job.retry_queue {
232                job.queue = retry_queue.into();
233            }
234
235            UnitOfWork::from_job(job).reenqueue(&redis).await?;
236        }
237
238        Ok(())
239    }
240}
241
#[cfg(test)]
mod test {
    use super::*;
    use crate::{RedisConnectionManager, RedisPool, RetryOpts, Worker};
    use bb8::Pool;
    use tokio::sync::Mutex;

    /// Connection pool against a local Redis instance.
    async fn redis() -> RedisPool {
        let manager = RedisConnectionManager::new("redis://127.0.0.1/").unwrap();
        let pool = Pool::builder().build(manager).await.unwrap();
        pool
    }

    /// A fresh, never-failed `TestWorker` job on the default queue.
    fn job() -> Job {
        Job {
            class: "TestWorker".into(),
            queue: "default".into(),
            args: vec![1337].into(),
            retry: RetryOpts::Yes,
            jid: crate::new_jid(),
            created_at: 1337.0,
            enqueued_at: None,
            failed_at: None,
            error_message: None,
            error_class: None,
            retry_count: None,
            retried_at: None,
            retry_queue: None,
            unique_for: None,
        }
    }

    /// Worker that records whether `perform` ran.
    #[derive(Clone)]
    struct TestWorker {
        touched: Arc<Mutex<bool>>,
    }

    #[async_trait]
    impl Worker<()> for TestWorker {
        async fn perform(&self, _args: ()) -> Result<()> {
            let mut touched = self.touched.lock().await;
            *touched = true;
            Ok(())
        }
    }

    #[tokio::test]
    async fn calls_through_a_middleware_stack() {
        let test_worker = Arc::new(TestWorker {
            touched: Arc::new(Mutex::new(false)),
        });
        let worker_ref = Arc::new(WorkerRef::wrap(Arc::clone(&test_worker)));

        let job = job();
        let mut chain = Chain::empty();
        chain.using(Box::new(HandlerMiddleware)).await;

        let pool = redis().await;
        let outcome = chain.call(&job, Arc::clone(&worker_ref), pool).await;
        outcome.unwrap();

        let was_touched = *test_worker.touched.lock().await;
        assert!(was_touched, "The job was processed by the middleware",);
    }
}