Skip to main content

rustvello_sqlite/trigger/
store.rs

1use std::sync::Arc;
2
3use async_trait::async_trait;
4use chrono::{DateTime, Utc};
5use rusqlite::OptionalExtension;
6
7use rustvello_core::error::{RustvelloError, RustvelloResult};
8use rustvello_core::trigger::TriggerStore;
9use rustvello_proto::identifiers::TaskId;
10use rustvello_proto::trigger::{
11    ConditionContext, ConditionId, TriggerCondition, TriggerDefinitionDTO, TriggerDefinitionId,
12    TriggerRunId, ValidCondition,
13};
14
15use crate::db::{blocking, lock_err, sql_err};
16
17use super::{condition_type_tag, get_condition_ids_for_trigger, parse_logic, SqliteTriggerStore};
18
19#[async_trait]
20impl TriggerStore for SqliteTriggerStore {
21    async fn register_condition(
22        &self,
23        condition: &TriggerCondition,
24    ) -> RustvelloResult<ConditionId> {
25        let db = Arc::clone(&self.db);
26        let condition = condition.clone();
27        blocking(move || {
28
29            let cond_id = condition.condition_id();
30            let json = serde_json::to_string(&condition).map_err(|e| RustvelloError::Serialization {
31                message: e.to_string(),
32            })?;
33            let cond_type = condition_type_tag(&condition);
34            let event_code = match &condition {
35                TriggerCondition::Event(evt) => Some(evt.event_code.clone()),
36                _ => None,
37            };
38
39            let conn = db.conn.lock().map_err(lock_err)?;
40            let tx = conn.unchecked_transaction().map_err(sql_err)?;
41
42            tx.execute(
43                "INSERT OR REPLACE INTO trg_conditions (condition_id, condition_type, event_code, condition_json) VALUES (?1, ?2, ?3, ?4)",
44                rusqlite::params![cond_id.as_str(), cond_type, &event_code, &json],
45            )
46            .map_err(sql_err)?;
47
48            for task_id in condition.source_task_ids() {
49                tx.execute(
50                    "INSERT OR IGNORE INTO trg_source_task_conditions (task_id, condition_id) VALUES (?1, ?2)",
51                    rusqlite::params![&task_id.to_string(), cond_id.as_str()],
52                )
53                .map_err(sql_err)?;
54            }
55
56            tx.commit().map_err(sql_err)?;
57            Ok(cond_id)
58
59        })
60        .await
61    }
62
63    async fn get_condition(&self, id: &ConditionId) -> RustvelloResult<Option<TriggerCondition>> {
64        let db = Arc::clone(&self.db);
65        let id = id.clone();
66        blocking(move || {
67            let conn = db.conn.lock().map_err(lock_err)?;
68            let mut stmt = conn
69                .prepare("SELECT condition_json FROM trg_conditions WHERE condition_id = ?1")
70                .map_err(sql_err)?;
71
72            let result = stmt
73                .query_row(rusqlite::params![&id.as_str()], |row| {
74                    let json: String = row.get(0)?;
75                    Ok(json)
76                })
77                .optional()
78                .map_err(sql_err)?;
79
80            match result {
81                Some(json) => {
82                    let cond: TriggerCondition =
83                        serde_json::from_str(&json).map_err(|e| RustvelloError::Serialization {
84                            message: e.to_string(),
85                        })?;
86                    Ok(Some(cond))
87                }
88                None => Ok(None),
89            }
90        })
91        .await
92    }
93
94    async fn get_conditions_for_task(
95        &self,
96        task_id: &TaskId,
97    ) -> RustvelloResult<Vec<(ConditionId, TriggerCondition)>> {
98        let db = Arc::clone(&self.db);
99        let task_id = task_id.clone();
100        blocking(move || {
101            let conn = db.conn.lock().map_err(lock_err)?;
102            let mut stmt = conn
103                .prepare(
104                    "SELECT c.condition_id, c.condition_json
105                     FROM trg_conditions c
106                     INNER JOIN trg_source_task_conditions stc ON c.condition_id = stc.condition_id
107                     WHERE stc.task_id = ?1",
108                )
109                .map_err(sql_err)?;
110
111            let rows = stmt
112                .query_map(rusqlite::params![&task_id.to_string()], |row| {
113                    let id: String = row.get(0)?;
114                    let json: String = row.get(1)?;
115                    Ok((id, json))
116                })
117                .map_err(sql_err)?;
118
119            let mut result = Vec::new();
120            for row in rows {
121                let (id, json) = row.map_err(sql_err)?;
122                let cond: TriggerCondition =
123                    serde_json::from_str(&json).map_err(|e| RustvelloError::Serialization {
124                        message: e.to_string(),
125                    })?;
126                result.push((ConditionId::from(id), cond));
127            }
128            Ok(result)
129        })
130        .await
131    }
132
133    async fn get_cron_conditions(&self) -> RustvelloResult<Vec<(ConditionId, TriggerCondition)>> {
134        let db = Arc::clone(&self.db);
135        blocking(move || {
136
137            let conn = db.conn.lock().map_err(lock_err)?;
138            let mut stmt = conn
139                .prepare("SELECT condition_id, condition_json FROM trg_conditions WHERE condition_type = 'Cron'")
140                .map_err(sql_err)?;
141
142            let rows = stmt
143                .query_map([], |row| {
144                    let id: String = row.get(0)?;
145                    let json: String = row.get(1)?;
146                    Ok((id, json))
147                })
148                .map_err(sql_err)?;
149
150            let mut result = Vec::new();
151            for row in rows {
152                let (id, json) = row.map_err(sql_err)?;
153                let cond: TriggerCondition =
154                    serde_json::from_str(&json).map_err(|e| RustvelloError::Serialization {
155                        message: e.to_string(),
156                    })?;
157                result.push((ConditionId::from(id), cond));
158            }
159            Ok(result)
160
161        })
162        .await
163    }
164
165    async fn get_all_conditions(&self) -> RustvelloResult<Vec<(ConditionId, TriggerCondition)>> {
166        let db = Arc::clone(&self.db);
167        blocking(move || {
168            let conn = db.conn.lock().map_err(lock_err)?;
169            let mut stmt = conn
170                .prepare("SELECT condition_id, condition_json FROM trg_conditions")
171                .map_err(sql_err)?;
172
173            let rows = stmt
174                .query_map([], |row| {
175                    let id: String = row.get(0)?;
176                    let json: String = row.get(1)?;
177                    Ok((id, json))
178                })
179                .map_err(sql_err)?;
180
181            let mut result = Vec::new();
182            for row in rows {
183                let (id, json) = row.map_err(sql_err)?;
184                let cond: TriggerCondition =
185                    serde_json::from_str(&json).map_err(|e| RustvelloError::Serialization {
186                        message: e.to_string(),
187                    })?;
188                result.push((ConditionId::from(id), cond));
189            }
190            Ok(result)
191        })
192        .await
193    }
194
195    async fn get_event_conditions(
196        &self,
197        event_code: &str,
198    ) -> RustvelloResult<Vec<(ConditionId, TriggerCondition)>> {
199        let db = Arc::clone(&self.db);
200        let event_code = event_code.to_owned();
201        blocking(move || {
202
203            let conn = db.conn.lock().map_err(lock_err)?;
204            let mut stmt = conn
205                .prepare("SELECT condition_id, condition_json FROM trg_conditions WHERE condition_type = 'Event' AND event_code = ?1")
206                .map_err(sql_err)?;
207
208            let rows = stmt
209                .query_map(rusqlite::params![event_code], |row| {
210                    let id: String = row.get(0)?;
211                    let json: String = row.get(1)?;
212                    Ok((id, json))
213                })
214                .map_err(sql_err)?;
215
216            let mut result = Vec::new();
217            for row in rows {
218                let (id, json) = row.map_err(sql_err)?;
219                let cond: TriggerCondition =
220                    serde_json::from_str(&json).map_err(|e| RustvelloError::Serialization {
221                        message: e.to_string(),
222                    })?;
223                result.push((ConditionId::from(id), cond));
224            }
225            Ok(result)
226
227        })
228        .await
229    }
230
231    async fn register_trigger(&self, trigger: &TriggerDefinitionDTO) -> RustvelloResult<()> {
232        let db = Arc::clone(&self.db);
233        let trigger = trigger.clone();
234        blocking(move || {
235
236            let conn = db.conn.lock().map_err(lock_err)?;
237            let tx = conn.unchecked_transaction().map_err(sql_err)?;
238            let arg_tmpl = trigger.argument_template.as_ref().map(ToString::to_string);
239
240            tx.execute(
241                "INSERT OR REPLACE INTO trg_triggers (trigger_id, task_id, logic, argument_template) VALUES (?1, ?2, ?3, ?4)",
242                rusqlite::params![
243                    &trigger.trigger_id.as_str(),
244                    &trigger.task_id.to_string(),
245                    &trigger.logic.to_string(),
246                    &arg_tmpl,
247                ],
248            )
249            .map_err(sql_err)?;
250
251            for cid in &trigger.condition_ids {
252                tx.execute(
253                    "INSERT OR IGNORE INTO trg_condition_triggers (condition_id, trigger_id) VALUES (?1, ?2)",
254                    rusqlite::params![cid.as_str(), &trigger.trigger_id.as_str()],
255                )
256                .map_err(sql_err)?;
257            }
258
259            tx.commit().map_err(sql_err)?;
260            Ok(())
261
262        })
263        .await
264    }
265
266    async fn get_trigger(
267        &self,
268        id: &TriggerDefinitionId,
269    ) -> RustvelloResult<Option<TriggerDefinitionDTO>> {
270        let db = Arc::clone(&self.db);
271        let id = id.clone();
272        blocking(move || {
273
274            let conn = db.conn.lock().map_err(lock_err)?;
275            let mut stmt = conn
276                .prepare(
277                    "SELECT task_id, logic, argument_template FROM trg_triggers WHERE trigger_id = ?1",
278                )
279                .map_err(sql_err)?;
280
281            let result = stmt
282                .query_row(rusqlite::params![&id.as_str()], |row| {
283                    let task_id_str: String = row.get(0)?;
284                    let logic_str: String = row.get(1)?;
285                    let arg_tmpl: Option<String> = row.get(2)?;
286                    Ok((task_id_str, logic_str, arg_tmpl))
287                })
288                .optional()
289                .map_err(sql_err)?;
290
291            match result {
292                Some((task_id_str, logic_str, arg_tmpl)) => {
293                    let task_id: TaskId =
294                        task_id_str
295                            .parse()
296                            .map_err(|e| RustvelloError::state_backend(format!("invalid task_id in database: {e}")))?;
297                    let logic = parse_logic(&logic_str)?;
298                    let argument_template = arg_tmpl
299                        .map(|s| serde_json::from_str(&s))
300                        .transpose()
301                        .map_err(|e| RustvelloError::Serialization {
302                            message: e.to_string(),
303                        })?;
304
305                    let condition_ids = get_condition_ids_for_trigger(&conn, id.as_str())?;
306
307                    Ok(Some(TriggerDefinitionDTO {
308                        trigger_id: id.clone(),
309                        task_id,
310                        condition_ids,
311                        logic,
312                        argument_template,
313                    }))
314                }
315                None => Ok(None),
316            }
317
318        })
319        .await
320    }
321
322    async fn get_triggers_for_condition(
323        &self,
324        cond_id: &ConditionId,
325    ) -> RustvelloResult<Vec<TriggerDefinitionDTO>> {
326        let db = Arc::clone(&self.db);
327        let cond_id = cond_id.clone();
328        blocking(move || {
329            let conn = db.conn.lock().map_err(lock_err)?;
330            let mut stmt = conn
331                .prepare(
332                    "SELECT t.trigger_id, t.task_id, t.logic, t.argument_template
333                     FROM trg_triggers t
334                     INNER JOIN trg_condition_triggers ct ON t.trigger_id = ct.trigger_id
335                     WHERE ct.condition_id = ?1",
336                )
337                .map_err(sql_err)?;
338
339            let rows = stmt
340                .query_map(rusqlite::params![cond_id.as_str()], |row| {
341                    let trigger_id: String = row.get(0)?;
342                    let task_id_str: String = row.get(1)?;
343                    let logic_str: String = row.get(2)?;
344                    let arg_tmpl: Option<String> = row.get(3)?;
345                    Ok((trigger_id, task_id_str, logic_str, arg_tmpl))
346                })
347                .map_err(sql_err)?;
348
349            let mut result = Vec::new();
350            for row in rows {
351                let (trigger_id, task_id_str, logic_str, arg_tmpl) = row.map_err(sql_err)?;
352                let task_id: TaskId = task_id_str.parse().map_err(|e| {
353                    RustvelloError::state_backend(format!("invalid task_id in database: {e}"))
354                })?;
355                let logic = parse_logic(&logic_str)?;
356                let argument_template = arg_tmpl
357                    .map(|s| serde_json::from_str(&s))
358                    .transpose()
359                    .map_err(|e| RustvelloError::Serialization {
360                        message: e.to_string(),
361                    })?;
362                let condition_ids = get_condition_ids_for_trigger(&conn, &trigger_id)?;
363
364                result.push(TriggerDefinitionDTO {
365                    trigger_id: TriggerDefinitionId::from(trigger_id),
366                    task_id,
367                    condition_ids,
368                    logic,
369                    argument_template,
370                });
371            }
372            Ok(result)
373        })
374        .await
375    }
376
377    async fn remove_triggers_for_task(&self, task_id: &TaskId) -> RustvelloResult<u32> {
378        let db = Arc::clone(&self.db);
379        let task_id = task_id.clone();
380        blocking(move || {
381            let conn = db.conn.lock().map_err(lock_err)?;
382            let tx = conn.unchecked_transaction().map_err(sql_err)?;
383            let task_str = task_id.to_string();
384
385            let mut stmt = tx
386                .prepare("SELECT trigger_id FROM trg_triggers WHERE task_id = ?1")
387                .map_err(sql_err)?;
388            let ids: Vec<String> = stmt
389                .query_map(rusqlite::params![&task_str], |row| row.get(0))
390                .map_err(sql_err)?
391                .collect::<Result<Vec<_>, _>>()
392                .map_err(sql_err)?;
393            drop(stmt);
394
395            let count = u32::try_from(ids.len()).unwrap_or(u32::MAX);
396            for id in &ids {
397                tx.execute(
398                    "DELETE FROM trg_condition_triggers WHERE trigger_id = ?1",
399                    rusqlite::params![id],
400                )
401                .map_err(sql_err)?;
402                tx.execute(
403                    "DELETE FROM trg_triggers WHERE trigger_id = ?1",
404                    rusqlite::params![id],
405                )
406                .map_err(sql_err)?;
407            }
408
409            tx.commit().map_err(sql_err)?;
410            Ok(count)
411        })
412        .await
413    }
414
415    async fn record_valid_condition(&self, vc: &ValidCondition) -> RustvelloResult<()> {
416        let db = Arc::clone(&self.db);
417        let vc = vc.clone();
418        blocking(move || {
419
420            let context_json =
421                serde_json::to_string(&vc.context).map_err(|e| RustvelloError::Serialization {
422                    message: e.to_string(),
423                })?;
424
425            let conn = db.conn.lock().map_err(lock_err)?;
426            conn.execute(
427                "INSERT OR REPLACE INTO trg_valid_conditions (valid_condition_id, condition_id, context_json) VALUES (?1, ?2, ?3)",
428                rusqlite::params![&vc.valid_condition_id, &vc.condition_id.as_str(), &context_json],
429            )
430            .map_err(sql_err)?;
431            Ok(())
432
433        })
434        .await
435    }
436
437    async fn get_valid_conditions(&self) -> RustvelloResult<Vec<ValidCondition>> {
438        let db = Arc::clone(&self.db);
439        blocking(move || {
440
441            let conn = db.conn.lock().map_err(lock_err)?;
442            let mut stmt = conn
443                .prepare(
444                    "SELECT valid_condition_id, condition_id, context_json FROM trg_valid_conditions",
445                )
446                .map_err(sql_err)?;
447
448            let rows = stmt
449                .query_map([], |row| {
450                    let vc_id: String = row.get(0)?;
451                    let cond_id: String = row.get(1)?;
452                    let ctx_json: String = row.get(2)?;
453                    Ok((vc_id, cond_id, ctx_json))
454                })
455                .map_err(sql_err)?;
456
457            let mut result = Vec::new();
458            for row in rows {
459                let (vc_id, cond_id, ctx_json) = row.map_err(sql_err)?;
460                let context: ConditionContext =
461                    serde_json::from_str(&ctx_json).map_err(|e| RustvelloError::Serialization {
462                        message: e.to_string(),
463                    })?;
464                result.push(ValidCondition {
465                    valid_condition_id: vc_id,
466                    condition_id: ConditionId::from(cond_id),
467                    context,
468                });
469            }
470            Ok(result)
471
472        })
473        .await
474    }
475
476    async fn clear_valid_conditions(&self, ids: &[String]) -> RustvelloResult<()> {
477        let db = Arc::clone(&self.db);
478        let ids = ids.to_vec();
479        blocking(move || {
480            let conn = db.conn.lock().map_err(lock_err)?;
481            for id in ids {
482                conn.execute(
483                    "DELETE FROM trg_valid_conditions WHERE valid_condition_id = ?1",
484                    rusqlite::params![id],
485                )
486                .map_err(sql_err)?;
487            }
488            Ok(())
489        })
490        .await
491    }
492
493    async fn get_last_cron_execution(
494        &self,
495        cond_id: &ConditionId,
496    ) -> RustvelloResult<Option<DateTime<Utc>>> {
497        let db = Arc::clone(&self.db);
498        let cond_id = cond_id.clone();
499        blocking(move || {
500            let conn = db.conn.lock().map_err(lock_err)?;
501            let mut stmt = conn
502                .prepare("SELECT last_execution FROM trg_cron_executions WHERE condition_id = ?1")
503                .map_err(sql_err)?;
504
505            let result = stmt
506                .query_row(rusqlite::params![cond_id.as_str()], |row| {
507                    let ts: String = row.get(0)?;
508                    Ok(ts)
509                })
510                .optional()
511                .map_err(sql_err)?;
512
513            match result {
514                Some(ts) => {
515                    let dt = DateTime::parse_from_rfc3339(&ts)
516                        .map(|dt| dt.with_timezone(&Utc))
517                        .map_err(|e| {
518                            RustvelloError::state_backend(format!("invalid timestamp: {}", e))
519                        })?;
520                    Ok(Some(dt))
521                }
522                None => Ok(None),
523            }
524        })
525        .await
526    }
527
528    async fn store_cron_execution(
529        &self,
530        cond_id: &ConditionId,
531        time: DateTime<Utc>,
532        expected_last: Option<DateTime<Utc>>,
533    ) -> RustvelloResult<bool> {
534        let db = Arc::clone(&self.db);
535        let cond_id = cond_id.clone();
536        blocking(move || {
537
538            let conn = db.conn.lock().map_err(lock_err)?;
539            let time_str = time.to_rfc3339();
540
541            let changed = match expected_last {
542                None => {
543                    conn.execute(
544                        "INSERT OR IGNORE INTO trg_cron_executions (condition_id, last_execution) VALUES (?1, ?2)",
545                        rusqlite::params![cond_id.as_str(), &time_str],
546                    )
547                    .map_err(sql_err)?
548                }
549                Some(expected) => {
550                    let expected_str = expected.to_rfc3339();
551                    conn.execute(
552                        "UPDATE trg_cron_executions SET last_execution = ?1 WHERE condition_id = ?2 AND last_execution = ?3",
553                        rusqlite::params![&time_str, cond_id.as_str(), &expected_str],
554                    )
555                    .map_err(sql_err)?
556                }
557            };
558
559            Ok(changed > 0)
560
561        })
562        .await
563    }
564
565    async fn claim_trigger_run(&self, run_id: &TriggerRunId) -> RustvelloResult<bool> {
566        let db = Arc::clone(&self.db);
567        let run_id = run_id.clone();
568        blocking(move || {
569
570            let conn = db.conn.lock().map_err(lock_err)?;
571            let now = Utc::now().to_rfc3339();
572
573            let changed = conn
574                .execute(
575                    "INSERT OR IGNORE INTO trg_trigger_run_claims (trigger_run_id, claimed_at) VALUES (?1, ?2)",
576                    rusqlite::params![run_id.as_str(), &now],
577                )
578                .map_err(sql_err)?;
579
580            Ok(changed > 0)
581
582        })
583        .await
584    }
585
586    async fn purge(&self) -> RustvelloResult<()> {
587        let db = Arc::clone(&self.db);
588        blocking(move || {
589            let conn = db.conn.lock().map_err(lock_err)?;
590            conn.execute_batch(
591                "DELETE FROM trg_trigger_run_claims;
592                 DELETE FROM trg_cron_executions;
593                 DELETE FROM trg_valid_conditions;
594                 DELETE FROM trg_source_task_conditions;
595                 DELETE FROM trg_condition_triggers;
596                 DELETE FROM trg_triggers;
597                 DELETE FROM trg_conditions;",
598            )
599            .map_err(sql_err)?;
600            Ok(())
601        })
602        .await
603    }
604}