simple_job_queue/redis.rs

1use std::time::Duration;
2
3use async_trait::async_trait;
4use redis::{
5    aio::Connection,
6    streams::{StreamReadOptions, StreamReadReply},
7    AsyncCommands, Client, IntoConnectionInfo, Value,
8};
9use serde::{de::DeserializeOwned, Serialize};
10use uuid::Uuid;
11
12use crate::{error::JobQueueError, Job, JobQueueBackend};
13
/// Name of the single stream-entry field that carries the job payload.
///
/// `produce` writes the job under this field and `read_job` reads it back.
const KEY_DATA: &str = "data";
15
/// Timing configuration for [`RedisJobQueueBackend`].
///
/// Start from [`Default::default`] and override individual values with the
/// builder-style setters, e.g.
/// `RedisJobQueueBackendOptions::default().polling_interval(Duration::from_secs(1))`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct RedisJobQueueBackendOptions {
    /// Minimum time a delivered-but-unacknowledged entry must have been idle before
    /// `consume` claims it for this consumer with `XAUTOCLAIM`. Default: 60 seconds.
    min_idle_time: Duration,
    /// Upper bound on how long a single `consume` read blocks waiting for a
    /// never-delivered entry. Default: 5 seconds.
    new_delivery_fetch_timeout: Duration,
    /// How long `consume` sleeps before rescanning when a newly delivered entry is
    /// not ready to be processed yet. Default: 5 seconds.
    polling_interval: Duration,
}

impl RedisJobQueueBackendOptions {
    /// Sets the minimum idle time after which a pending entry may be reclaimed.
    #[must_use]
    pub fn min_idle_time(mut self, d: Duration) -> Self {
        self.min_idle_time = d;
        self
    }

    /// Sets how long one read may block waiting for a never-delivered entry.
    #[must_use]
    pub fn new_delivery_fetch_timeout(mut self, d: Duration) -> Self {
        self.new_delivery_fetch_timeout = d;
        self
    }

    /// Sets the back-off used when a delivered entry is not ready to be processed.
    #[must_use]
    pub fn polling_interval(mut self, d: Duration) -> Self {
        self.polling_interval = d;
        self
    }
}

impl Default for RedisJobQueueBackendOptions {
    /// 60 s minimum idle time, 5 s fetch timeout, 5 s polling interval.
    fn default() -> Self {
        Self {
            min_idle_time: Duration::from_secs(60),
            new_delivery_fetch_timeout: Duration::from_secs(5),
            polling_interval: Duration::from_secs(5),
        }
    }
}
49
50#[derive(Clone)]
51pub struct RedisJobQueueBackend {
52    client: Client,
53    name: String,
54    consumer_id: Uuid,
55    options: RedisJobQueueBackendOptions,
56}
57
58impl RedisJobQueueBackend {
59    pub fn new<I: IntoConnectionInfo>(
60        connection_info: I,
61        name: String,
62        options: RedisJobQueueBackendOptions,
63    ) -> Result<Self, JobQueueError> {
64        Ok(Self {
65            client: Client::open(connection_info)?,
66            name,
67            consumer_id: Uuid::new_v4(),
68            options,
69        })
70    }
71}
72
73impl RedisJobQueueBackend {
74    async fn read_job<T>(
75        &self,
76        conn: &mut Connection,
77        id: &str,
78        block: i64,
79    ) -> Result<Option<(Job<T>, RedisJobContext)>, JobQueueError>
80    where
81        T: DeserializeOwned,
82    {
83        let mut options = StreamReadOptions::default()
84            .group(&self.name, &self.consumer_id.to_string())
85            .count(1);
86
87        if block > 0 {
88            options = options.block(block as usize);
89        }
90
91        conn.xread_options::<_, _, StreamReadReply>(&[&self.name], &[id], &options)
92            .await?
93            .keys
94            .get(0)
95            .map(|k| k.ids.get(0))
96            .flatten()
97            .map(|v| {
98                let ctx = RedisJobContext { id: v.id.clone() };
99                match v.get(KEY_DATA) {
100                    Some(job) => Ok((job, ctx)),
101                    None => Err(JobQueueError::MalformedJob),
102                }
103            })
104            .transpose()
105    }
106}
107
/// Per-delivery handle that `consume` returns alongside a job and that is handed back to
/// `done` / `failed`.
#[derive(Debug)]
pub struct RedisJobContext {
    /// Id of the stream entry that carried the job; `done` acknowledges and deletes it.
    id: String,
}
111
112#[async_trait]
113impl<T> JobQueueBackend<T> for RedisJobQueueBackend
114where
115    T: Serialize + DeserializeOwned + Send + Sync + 'static,
116{
117    type Context = RedisJobContext;
118
119    async fn setup(&self) -> Result<(), JobQueueError> {
120        let mut conn = self.client.get_async_connection().await?;
121        match conn
122            .xgroup_create_mkstream::<_, _, _, String>(&self.name, &self.name, 0)
123            .await
124        {
125            Ok(_) => (),
126            Err(err) => match err.code() {
127                Some(code) if code == "BUSYGROUP" => (),
128                _ => return Err(JobQueueError::from(err)),
129            },
130        }
131
132        Ok(())
133    }
134
135    async fn produce(&self, job: Job<T>) -> Result<(), JobQueueError> {
136        let mut conn = self.client.get_async_connection().await?;
137        conn.xadd(&self.name, "*", &[(KEY_DATA, job)]).await?;
138        Ok(())
139    }
140
141    async fn consume(&self) -> Result<(Job<T>, Self::Context), JobQueueError> {
142        let mut conn = self.client.get_async_connection().await?;
143
144        redis::cmd("XAUTOCLAIM")
145            .arg(&self.name)
146            .arg(&self.name)
147            .arg(&self.consumer_id.to_string())
148            .arg(self.options.min_idle_time.as_millis() as i64)
149            .arg(0)
150            .arg("COUNT")
151            .arg(1)
152            .arg("JUSTID")
153            .query_async::<_, Value>(&mut conn)
154            .await?;
155
156        let mut pending_id = "0".to_string();
157        loop {
158            let result = self.read_job::<T>(&mut conn, &pending_id, -1).await?;
159
160            match result {
161                Some((job, ctx)) if !job.should_process() => {
162                    pending_id = ctx.id;
163                    continue;
164                }
165                Some((job, ctx)) => {
166                    break Ok((job, ctx));
167                }
168                None => {
169                    match self
170                        .read_job::<T>(
171                            &mut conn,
172                            ">",
173                            self.options.new_delivery_fetch_timeout.as_millis() as i64,
174                        )
175                        .await?
176                    {
177                        Some((job, _)) if !job.should_process() => {
178                            tokio::time::sleep(self.options.polling_interval).await;
179                            pending_id = "0".to_string();
180                            continue;
181                        }
182                        Some((job, ctx)) => {
183                            break Ok((job, ctx));
184                        }
185                        None => {
186                            pending_id = "0".to_string();
187                        }
188                    }
189                }
190            }
191        }
192    }
193
194    async fn done(&self, _: Job<T>, ctx: Self::Context) {
195        match self.client.get_async_connection().await {
196            Ok(mut conn) => match conn
197                .xack::<_, _, _, Value>(&self.name, &self.name, &[&ctx.id])
198                .await
199            {
200                Ok(_) => match conn.xdel::<_, _, Value>(&self.name, &[ctx.id]).await {
201                    Ok(_) => (),
202                    Err(_) => todo!("handle done notification failure"),
203                },
204                Err(_) => todo!("handle done notification failure"),
205            },
206            Err(_) => todo!("handle done notification failure"),
207        }
208    }
209
210    async fn failed(&self, _: Job<T>, _: Self::Context) {
211        todo!("Handle job failures")
212    }
213}