1use std::sync::Arc;
2
3use async_trait::async_trait;
4use chrono::{DateTime, Utc};
5use redis::AsyncCommands;
6use tracing;
7
8use rustvello_core::error::{RustvelloError, RustvelloResult};
9use rustvello_core::trigger::TriggerStore;
10use rustvello_proto::identifiers::TaskId;
11use rustvello_proto::trigger::{
12 ConditionId, TriggerCondition, TriggerDefinitionDTO, TriggerDefinitionId, TriggerRunId,
13 ValidCondition,
14};
15
16use crate::connection::{redis_err, scan_keys, RedisPool};
17
18async fn batch_get_conditions(
20 conn: &mut redis::aio::MultiplexedConnection,
21 member_ids: &[String],
22 cond_prefix: &str,
23) -> RustvelloResult<Vec<(ConditionId, TriggerCondition)>> {
24 if member_ids.is_empty() {
25 return Ok(Vec::new());
26 }
27 let keys: Vec<String> = member_ids
28 .iter()
29 .map(|id| {
30 let mut k = String::with_capacity(cond_prefix.len() + id.len());
31 k.push_str(cond_prefix);
32 k.push_str(id);
33 k
34 })
35 .collect();
36 let values: Vec<Option<String>> = redis::cmd("MGET")
37 .arg(&keys)
38 .query_async(conn)
39 .await
40 .map_err(redis_err)?;
41 let mut result = Vec::with_capacity(member_ids.len());
42 for (cid_str, val) in member_ids.iter().zip(values) {
43 if let Some(json) = val {
44 match serde_json::from_str::<TriggerCondition>(&json) {
45 Ok(cond) => result.push((ConditionId::from(cid_str.clone()), cond)),
46 Err(e) => {
47 tracing::warn!("Failed to deserialize condition {}: {}", cid_str, e);
48 }
49 }
50 }
51 }
52 Ok(result)
53}
54
55#[non_exhaustive]
57pub struct RedisTriggerStore {
58 pool: Arc<RedisPool>,
59 cond_prefix: String,
60 cond_task_prefix: String,
61 trigger_prefix: String,
62 cond_trigger_prefix: String,
63 valid_cond_prefix: String,
64 cron_exec_prefix: String,
65 run_prefix: String,
66 trigger_task_prefix: String,
67 cron_index: String,
68 event_index_prefix: String,
69}
70
71impl RedisTriggerStore {
72 pub fn new(pool: Arc<RedisPool>) -> Self {
73 let p = pool.prefix();
74 Self {
75 cond_prefix: format!("{p}trg:cond:"),
76 cond_task_prefix: format!("{p}trg:cond_task:"),
77 trigger_prefix: format!("{p}trg:def:"),
78 cond_trigger_prefix: format!("{p}trg:cond_trg:"),
79 valid_cond_prefix: format!("{p}trg:valid:"),
80 cron_exec_prefix: format!("{p}trg:cron_exec:"),
81 run_prefix: format!("{p}trg:run:"),
82 trigger_task_prefix: format!("{p}trg:trg_task:"),
83 cron_index: format!("{p}trg:cron_ids"),
84 event_index_prefix: format!("{p}trg:event:"),
85 pool,
86 }
87 }
88}
89
90#[async_trait]
91impl TriggerStore for RedisTriggerStore {
92 async fn register_condition(
93 &self,
94 condition: &TriggerCondition,
95 ) -> RustvelloResult<ConditionId> {
96 let cond_id = condition.condition_id();
97 let mut conn = self.pool.conn().await?;
98 let json = serde_json::to_string(condition).map_err(|e| RustvelloError::Serialization {
99 message: e.to_string(),
100 })?;
101 conn.set::<_, _, ()>(format!("{}{}", &self.cond_prefix, cond_id.as_str()), &json)
102 .await
103 .map_err(redis_err)?;
104
105 for task_id in condition.source_task_ids() {
107 conn.sadd::<_, _, ()>(
108 format!("{}{}", &self.cond_task_prefix, task_id),
109 cond_id.as_str().to_owned(),
110 )
111 .await
112 .map_err(redis_err)?;
113 }
114
115 if matches!(condition, TriggerCondition::Cron(_)) {
117 conn.sadd::<_, _, ()>(&self.cron_index, cond_id.as_str().to_owned())
118 .await
119 .map_err(redis_err)?;
120 }
121 if let TriggerCondition::Event(ev) = condition {
122 conn.sadd::<_, _, ()>(
123 format!("{}{}", &self.event_index_prefix, ev.event_code),
124 cond_id.as_str().to_owned(),
125 )
126 .await
127 .map_err(redis_err)?;
128 }
129
130 Ok(cond_id)
131 }
132
133 async fn get_condition(&self, id: &ConditionId) -> RustvelloResult<Option<TriggerCondition>> {
134 let mut conn = self.pool.conn().await?;
135 let val: Option<String> = conn
136 .get(format!("{}{}", &self.cond_prefix, id.as_str()))
137 .await
138 .map_err(redis_err)?;
139 match val {
140 Some(s) => {
141 let c: TriggerCondition =
142 serde_json::from_str(&s).map_err(|e| RustvelloError::Serialization {
143 message: e.to_string(),
144 })?;
145 Ok(Some(c))
146 }
147 None => Ok(None),
148 }
149 }
150
151 async fn get_conditions_for_task(
152 &self,
153 task_id: &TaskId,
154 ) -> RustvelloResult<Vec<(ConditionId, TriggerCondition)>> {
155 let mut conn = self.pool.conn().await?;
156 let members: Vec<String> = conn
157 .smembers(format!("{}{}", &self.cond_task_prefix, task_id))
158 .await
159 .map_err(redis_err)?;
160
161 batch_get_conditions(&mut conn, &members, &self.cond_prefix).await
162 }
163
164 async fn get_cron_conditions(&self) -> RustvelloResult<Vec<(ConditionId, TriggerCondition)>> {
165 let mut conn = self.pool.conn().await?;
166 let members: Vec<String> = conn.smembers(&self.cron_index).await.map_err(redis_err)?;
168
169 batch_get_conditions(&mut conn, &members, &self.cond_prefix).await
170 }
171
172 async fn get_event_conditions(
173 &self,
174 event_code: &str,
175 ) -> RustvelloResult<Vec<(ConditionId, TriggerCondition)>> {
176 let mut conn = self.pool.conn().await?;
177 let members: Vec<String> = conn
179 .smembers(format!("{}{}", &self.event_index_prefix, event_code))
180 .await
181 .map_err(redis_err)?;
182
183 batch_get_conditions(&mut conn, &members, &self.cond_prefix).await
184 }
185
186 async fn register_trigger(&self, trigger: &TriggerDefinitionDTO) -> RustvelloResult<()> {
187 let mut conn = self.pool.conn().await?;
188 let json = serde_json::to_string(trigger).map_err(|e| RustvelloError::Serialization {
189 message: e.to_string(),
190 })?;
191 conn.set::<_, _, ()>(
192 format!("{}{}", &self.trigger_prefix, trigger.trigger_id.as_str()),
193 &json,
194 )
195 .await
196 .map_err(redis_err)?;
197
198 for cid in &trigger.condition_ids {
200 conn.sadd::<_, _, ()>(
201 format!("{}{}", &self.cond_trigger_prefix, cid.as_str()),
202 trigger.trigger_id.as_str().to_owned(),
203 )
204 .await
205 .map_err(redis_err)?;
206 }
207
208 conn.sadd::<_, _, ()>(
210 format!("{}{}", &self.trigger_task_prefix, trigger.task_id),
211 trigger.trigger_id.as_str().to_owned(),
212 )
213 .await
214 .map_err(redis_err)?;
215
216 Ok(())
217 }
218
219 async fn get_trigger(
220 &self,
221 id: &TriggerDefinitionId,
222 ) -> RustvelloResult<Option<TriggerDefinitionDTO>> {
223 let mut conn = self.pool.conn().await?;
224 let val: Option<String> = conn
225 .get(format!("{}{}", &self.trigger_prefix, id.as_str()))
226 .await
227 .map_err(redis_err)?;
228 match val {
229 Some(s) => {
230 let t: TriggerDefinitionDTO =
231 serde_json::from_str(&s).map_err(|e| RustvelloError::Serialization {
232 message: e.to_string(),
233 })?;
234 Ok(Some(t))
235 }
236 None => Ok(None),
237 }
238 }
239
240 async fn get_triggers_for_condition(
241 &self,
242 cond_id: &ConditionId,
243 ) -> RustvelloResult<Vec<TriggerDefinitionDTO>> {
244 let mut conn = self.pool.conn().await?;
245 let members: Vec<String> = conn
246 .smembers(format!("{}{}", &self.cond_trigger_prefix, cond_id.as_str()))
247 .await
248 .map_err(redis_err)?;
249
250 if members.is_empty() {
251 return Ok(Vec::new());
252 }
253
254 let keys: Vec<String> = members
255 .iter()
256 .map(|tid| format!("{}{}", &self.trigger_prefix, tid))
257 .collect();
258 let values: Vec<Option<String>> = redis::cmd("MGET")
259 .arg(&keys)
260 .query_async(&mut conn)
261 .await
262 .map_err(redis_err)?;
263
264 let mut result = Vec::new();
265 for val in values.into_iter().flatten() {
266 let t: TriggerDefinitionDTO =
267 serde_json::from_str(&val).map_err(|e| RustvelloError::Serialization {
268 message: e.to_string(),
269 })?;
270 result.push(t);
271 }
272 Ok(result)
273 }
274
275 async fn remove_triggers_for_task(&self, task_id: &TaskId) -> RustvelloResult<u32> {
276 let mut conn = self.pool.conn().await?;
277 let members: Vec<String> = conn
278 .smembers(format!("{}{}", &self.trigger_task_prefix, task_id))
279 .await
280 .map_err(redis_err)?;
281
282 let count = u32::try_from(members.len()).unwrap_or(u32::MAX);
283 for tid_str in &members {
284 let val: Option<String> = conn
287 .get(format!("{}{}", &self.trigger_prefix, tid_str))
288 .await
289 .map_err(redis_err)?;
290 if let Some(json) = val {
291 if let Ok(trigger) = serde_json::from_str::<TriggerDefinitionDTO>(&json) {
292 for cid in &trigger.condition_ids {
293 conn.srem::<_, _, ()>(
294 format!("{}{}", &self.cond_trigger_prefix, cid.as_str()),
295 tid_str.as_str(),
296 )
297 .await
298 .map_err(redis_err)?;
299 }
300 }
301 }
302 conn.del::<_, ()>(format!("{}{}", &self.trigger_prefix, tid_str))
303 .await
304 .map_err(redis_err)?;
305 }
306 conn.del::<_, ()>(format!("{}{}", &self.trigger_task_prefix, task_id))
307 .await
308 .map_err(redis_err)?;
309 Ok(count)
310 }
311
312 async fn record_valid_condition(&self, vc: &ValidCondition) -> RustvelloResult<()> {
313 let mut conn = self.pool.conn().await?;
314 let json = serde_json::to_string(vc).map_err(|e| RustvelloError::Serialization {
315 message: e.to_string(),
316 })?;
317 let key = format!("{}{}", &self.valid_cond_prefix, vc.valid_condition_id);
318 conn.set::<_, _, ()>(&key, &json).await.map_err(redis_err)
319 }
320
321 async fn get_valid_conditions(&self) -> RustvelloResult<Vec<ValidCondition>> {
322 let mut conn = self.pool.conn().await?;
323 let keys = scan_keys(&mut conn, &format!("{}*", &self.valid_cond_prefix)).await?;
324
325 if keys.is_empty() {
326 return Ok(Vec::new());
327 }
328
329 let values: Vec<Option<String>> = redis::cmd("MGET")
330 .arg(&keys)
331 .query_async(&mut conn)
332 .await
333 .map_err(redis_err)?;
334
335 let mut result = Vec::new();
336 for val in values.into_iter().flatten() {
337 let vc: ValidCondition =
338 serde_json::from_str(&val).map_err(|e| RustvelloError::Serialization {
339 message: e.to_string(),
340 })?;
341 result.push(vc);
342 }
343 Ok(result)
344 }
345
346 async fn clear_valid_conditions(&self, ids: &[String]) -> RustvelloResult<()> {
347 if ids.is_empty() {
348 return Ok(());
349 }
350 let mut conn = self.pool.conn().await?;
351 let keys: Vec<String> = ids
352 .iter()
353 .map(|id| format!("{}{}", &self.valid_cond_prefix, id))
354 .collect();
355 conn.del::<_, ()>(keys).await.map_err(redis_err)
356 }
357
358 async fn get_last_cron_execution(
359 &self,
360 cond_id: &ConditionId,
361 ) -> RustvelloResult<Option<DateTime<Utc>>> {
362 let mut conn = self.pool.conn().await?;
363 let val: Option<String> = conn
364 .get(format!("{}{}", &self.cron_exec_prefix, cond_id.as_str()))
365 .await
366 .map_err(redis_err)?;
367 match val {
368 Some(s) => {
369 let dt = DateTime::parse_from_rfc3339(&s)
370 .map(|d| d.with_timezone(&Utc))
371 .map_err(|e| RustvelloError::Serialization {
372 message: format!("cron timestamp: {}", e),
373 })?;
374 Ok(Some(dt))
375 }
376 None => Ok(None),
377 }
378 }
379
380 async fn store_cron_execution(
381 &self,
382 cond_id: &ConditionId,
383 time: DateTime<Utc>,
384 expected_last: Option<DateTime<Utc>>,
385 ) -> RustvelloResult<bool> {
386 let key = format!("{}{}", &self.cron_exec_prefix, cond_id.as_str());
387 let mut conn = self.pool.conn().await?;
388
389 let expected_val = match expected_last {
391 Some(dt) => dt.to_rfc3339(),
392 None => String::new(), };
394 let new_val = time.to_rfc3339();
395
396 let script = redis::Script::new(
399 r"
400 local current = redis.call('GET', KEYS[1])
401 local expected = ARGV[1]
402 if expected == '' then
403 if current == false then
404 redis.call('SET', KEYS[1], ARGV[2])
405 return 1
406 else
407 return 0
408 end
409 else
410 if current == expected then
411 redis.call('SET', KEYS[1], ARGV[2])
412 return 1
413 else
414 return 0
415 end
416 end
417 ",
418 );
419 let result: i32 = script
420 .key(&key)
421 .arg(&expected_val)
422 .arg(&new_val)
423 .invoke_async(&mut conn)
424 .await
425 .map_err(redis_err)?;
426 Ok(result == 1)
427 }
428
429 async fn claim_trigger_run(&self, run_id: &TriggerRunId) -> RustvelloResult<bool> {
430 let key = format!("{}{}", &self.run_prefix, run_id.as_str());
431 let mut conn = self.pool.conn().await?;
432 let set: bool = conn.set_nx(&key, "1").await.map_err(redis_err)?;
434 if set {
435 conn.expire::<_, ()>(&key, 3600).await.map_err(redis_err)?;
437 }
438 Ok(set)
439 }
440
441 async fn purge(&self) -> RustvelloResult<()> {
442 let prefixes = [
443 &self.cond_prefix,
444 &self.cond_task_prefix,
445 &self.trigger_prefix,
446 &self.cond_trigger_prefix,
447 &self.valid_cond_prefix,
448 &self.cron_exec_prefix,
449 &self.run_prefix,
450 &self.trigger_task_prefix,
451 &self.event_index_prefix,
452 ];
453 let mut conn = self.pool.conn().await?;
454 for prefix in prefixes {
455 let keys = scan_keys(&mut conn, &format!("{}*", prefix)).await?;
456 if !keys.is_empty() {
457 conn.del::<_, ()>(keys).await.map_err(redis_err)?;
458 }
459 }
460 conn.del::<_, ()>(&self.cron_index)
462 .await
463 .map_err(redis_err)?;
464 Ok(())
465 }
466
467 async fn get_all_conditions(&self) -> RustvelloResult<Vec<(ConditionId, TriggerCondition)>> {
468 let mut conn = self.pool.conn().await?;
469 let keys = scan_keys(&mut conn, &format!("{}*", &self.cond_prefix)).await?;
470 let ids: Vec<String> = keys
471 .iter()
472 .filter_map(|k| k.strip_prefix(&self.cond_prefix).map(String::from))
473 .collect();
474 batch_get_conditions(&mut conn, &ids, &self.cond_prefix).await
475 }
476}