use std::time::Duration;

use async_trait::async_trait;
use redis::{
    aio::Connection,
    streams::{StreamReadOptions, StreamReadReply},
    AsyncCommands, Client, IntoConnectionInfo, Value,
};
use serde::{de::DeserializeOwned, Serialize};
use uuid::Uuid;

use crate::{error::JobQueueError, Job, JobQueueBackend};

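/// Field name under which the serialized job payload is stored in each stream entry.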
const KEY_DATA: &str = "data";

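/// Tuning knobs for `RedisJobQueueBackend`, set through the builder-style methods below.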
#[derive(Clone)]
pub struct RedisJobQueueBackendOptions {
    min_idle_time: Duration,
    new_delivery_fetch_timeout: Duration,
    polling_interval: Duration,
}

impl RedisJobQueueBackendOptions {
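    /// Minimum idle time an entry must reach in the group's pending list
    /// before `consume` reclaims it with `XAUTOCLAIM`.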
    pub fn min_idle_time(mut self, d: Duration) -> Self {
        self.min_idle_time = d;
        self
    }

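    /// How long `consume` blocks on `XREADGROUP` while waiting for a new
    /// delivery before retrying.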
    pub fn new_delivery_fetch_timeout(mut self, d: Duration) -> Self {
        self.new_delivery_fetch_timeout = d;
        self
    }

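    /// How long `consume` sleeps after fetching a newly delivered job whose
    /// `should_process` check fails, before rescanning the pending entries.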
    pub fn polling_interval(mut self, d: Duration) -> Self {
        self.polling_interval = d;
        self
    }
}

impl Default for RedisJobQueueBackendOptions {
    fn default() -> Self {
        Self {
            min_idle_time: Duration::from_secs(60),
            new_delivery_fetch_timeout: Duration::from_secs(5),
            polling_interval: Duration::from_secs(5),
        }
    }
}

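/// Job queue backend built on a Redis stream plus a consumer group.
///
/// The stream and the consumer group share the same `name`; each backend
/// instance identifies itself to the group with a random `consumer_id`.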
#[derive(Clone)]
pub struct RedisJobQueueBackend {
    client: Client,
    name: String,
    consumer_id: Uuid,
    options: RedisJobQueueBackendOptions,
}

impl RedisJobQueueBackend {
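    /// Creates a backend for the stream and consumer group named `name`,
    /// generating a fresh UUID as this instance's consumer name.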
    pub fn new<I: IntoConnectionInfo>(
        connection_info: I,
        name: String,
        options: RedisJobQueueBackendOptions,
    ) -> Result<Self, JobQueueError> {
        Ok(Self {
            client: Client::open(connection_info)?,
            name,
            consumer_id: Uuid::new_v4(),
            options,
        })
    }
}

impl RedisJobQueueBackend {
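    /// Reads at most one entry from the stream via `XREADGROUP`.
    ///
    /// `id` selects this consumer's pending entries (`"0"` or a concrete
    /// entry ID) or new deliveries (`">"`). A positive `block` value is
    /// passed through as the `BLOCK` timeout in milliseconds; any other
    /// value makes the call non-blocking.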
    async fn read_job<T>(
        &self,
        conn: &mut Connection,
        id: &str,
        block: i64,
    ) -> Result<Option<(Job<T>, RedisJobContext)>, JobQueueError>
    where
        T: DeserializeOwned,
    {
        let mut options = StreamReadOptions::default()
            .group(&self.name, &self.consumer_id.to_string())
            .count(1);

        if block > 0 {
            options = options.block(block as usize);
        }

        conn.xread_options::<_, _, StreamReadReply>(&[&self.name], &[id], &options)
            .await?
            .keys
            .first()
            .and_then(|k| k.ids.first())
            .map(|v| {
                let ctx = RedisJobContext { id: v.id.clone() };
                match v.get(KEY_DATA) {
                    Some(job) => Ok((job, ctx)),
                    None => Err(JobQueueError::MalformedJob),
                }
            })
            .transpose()
    }
}

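/// Per-job context returned by `consume`; carries the stream entry ID so the
/// entry can be acknowledged and deleted once the job is resolved.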
pub struct RedisJobContext {
    id: String,
}

#[async_trait]
impl<T> JobQueueBackend<T> for RedisJobQueueBackend
where
    T: Serialize + DeserializeOwned + Send + Sync + 'static,
{
    type Context = RedisJobContext;

    async fn setup(&self) -> Result<(), JobQueueError> {
        let mut conn = self.client.get_async_connection().await?;
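        // Create the stream and the consumer group in one step (MKSTREAM);
        // a BUSYGROUP error just means the group already exists and is ignored.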
        match conn
            .xgroup_create_mkstream::<_, _, _, String>(&self.name, &self.name, 0)
            .await
        {
            Ok(_) => (),
            Err(err) => match err.code() {
                Some(code) if code == "BUSYGROUP" => (),
                _ => return Err(JobQueueError::from(err)),
            },
        }

        Ok(())
    }

    async fn produce(&self, job: Job<T>) -> Result<(), JobQueueError> {
        let mut conn = self.client.get_async_connection().await?;
        // XADD returns the generated entry ID; annotate the return type so it
        // can still be inferred even though the value is discarded.
        conn.xadd::<_, _, _, _, String>(&self.name, "*", &[(KEY_DATA, job)])
            .await?;
        Ok(())
    }

    async fn consume(&self) -> Result<(Job<T>, Self::Context), JobQueueError> {
        let mut conn = self.client.get_async_connection().await?;

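        // Claim entries that have been idle in the group's pending list for at
        // least `min_idle_time`, transferring them to this consumer. JUSTID
        // keeps the reply small; the claimed entries are re-read below.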
        redis::cmd("XAUTOCLAIM")
            .arg(&self.name)
            .arg(&self.name)
            .arg(&self.consumer_id.to_string())
            .arg(self.options.min_idle_time.as_millis() as i64)
            .arg(0)
            .arg("COUNT")
            .arg(1)
            .arg("JUSTID")
            .query_async::<_, Value>(&mut conn)
            .await?;

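        // Drain this consumer's pending entries (including any just claimed)
        // first, then fall back to blocking for new deliveries.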
        let mut pending_id = "0".to_string();
        loop {
            let result = self.read_job::<T>(&mut conn, &pending_id, -1).await?;

            match result {
                Some((job, ctx)) if !job.should_process() => {
                    // Not ready yet: keep scanning pending entries after this one.
                    pending_id = ctx.id;
                    continue;
                }
                Some((job, ctx)) => {
                    break Ok((job, ctx));
                }
                None => {
                    // No pending entries left; block briefly for a new delivery.
                    match self
                        .read_job::<T>(
                            &mut conn,
                            ">",
                            self.options.new_delivery_fetch_timeout.as_millis() as i64,
                        )
                        .await?
                    {
                        Some((job, _)) if !job.should_process() => {
                            // The new entry now sits in this consumer's pending list;
                            // wait and rescan the pending entries from the start.
                            tokio::time::sleep(self.options.polling_interval).await;
                            pending_id = "0".to_string();
                            continue;
                        }
                        Some((job, ctx)) => {
                            break Ok((job, ctx));
                        }
                        None => {
                            pending_id = "0".to_string();
                        }
                    }
                }
            }
        }
    }

    async fn done(&self, _: Job<T>, ctx: Self::Context) {
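        // Acknowledge the entry in the consumer group, then delete it from the
        // stream so it is never reclaimed or re-read.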
        match self.client.get_async_connection().await {
            Ok(mut conn) => match conn
                .xack::<_, _, _, Value>(&self.name, &self.name, &[&ctx.id])
                .await
            {
                Ok(_) => match conn.xdel::<_, _, Value>(&self.name, &[ctx.id]).await {
                    Ok(_) => (),
                    Err(_) => todo!("handle done notification failure"),
                },
                Err(_) => todo!("handle done notification failure"),
            },
            Err(_) => todo!("handle done notification failure"),
        }
    }

    async fn failed(&self, _: Job<T>, _: Self::Context) {
        todo!("Handle job failures")
    }
}