Skip to main content

rustvello_postgres/
trigger.rs

1//! PostgreSQL trigger store implementation.
2
3use std::collections::HashMap;
4use std::sync::Arc;
5
6use async_trait::async_trait;
7use chrono::{DateTime, Utc};
8
9use rustvello_core::error::{RustvelloError, RustvelloResult};
10use rustvello_core::trigger::TriggerStore;
11use rustvello_proto::identifiers::TaskId;
12use rustvello_proto::trigger::{
13    ConditionContext, ConditionId, TriggerCondition, TriggerDefinitionDTO, TriggerDefinitionId,
14    TriggerLogic, TriggerRunId, ValidCondition,
15};
16
17use crate::db::{pg_err, Database};
18
19/// PostgreSQL-backed trigger store implementation.
20pub struct PostgresTriggerStore {
21    db: Arc<Database>,
22}
23
24impl PostgresTriggerStore {
25    pub fn new(db: Arc<Database>) -> Self {
26        Self { db }
27    }
28}
29
30fn parse_logic(s: &str) -> RustvelloResult<TriggerLogic> {
31    match s {
32        "AND" => Ok(TriggerLogic::And),
33        "OR" => Ok(TriggerLogic::Or),
34        other => Err(RustvelloError::state_backend(format!(
35            "unknown trigger logic value: {other:?}"
36        ))),
37    }
38}
39
40// parse_task_id removed — use TaskId::from_str (via `.parse()`) instead.
41
42/// Return the discriminant tag for a condition variant.
43fn condition_type_tag(condition: &TriggerCondition) -> &'static str {
44    match condition {
45        TriggerCondition::Cron(_) => "Cron",
46        TriggerCondition::Status(_) => "Status",
47        TriggerCondition::Event(_) => "Event",
48        TriggerCondition::Result(_) => "Result",
49        TriggerCondition::Exception(_) => "Exception",
50        TriggerCondition::Composite(_) => "Composite",
51        _ => "Unknown",
52    }
53}
54
55#[async_trait]
56impl TriggerStore for PostgresTriggerStore {
57    async fn register_condition(
58        &self,
59        condition: &TriggerCondition,
60    ) -> RustvelloResult<ConditionId> {
61        let cond_id = condition.condition_id();
62        let json = serde_json::to_string(condition).map_err(|e| RustvelloError::Serialization {
63            message: e.to_string(),
64        })?;
65        let cond_type = condition_type_tag(condition);
66        let event_code: Option<&str> = match condition {
67            TriggerCondition::Event(evt) => Some(&evt.event_code),
68            _ => None,
69        };
70
71        let mut client = self.db.conn().await?;
72        let tx = client.transaction().await.map_err(pg_err)?;
73
74        tx
75            .execute(
76                "INSERT INTO trg_conditions (condition_id, condition_type, condition_json, event_code) VALUES ($1, $2, $3, $4)
77                 ON CONFLICT (condition_id) DO UPDATE SET condition_type = $2, condition_json = $3, event_code = $4",
78                &[&cond_id.as_str(), &cond_type, &json, &event_code],
79            )
80            .await
81            .map_err(pg_err)?;
82
83        // Index by source task IDs
84        for task_id in condition.source_task_ids() {
85            tx.execute(
86                "INSERT INTO trg_source_task_conditions (task_id, condition_id) VALUES ($1, $2)
87                     ON CONFLICT DO NOTHING",
88                &[&task_id.to_string(), &cond_id.as_str()],
89            )
90            .await
91            .map_err(pg_err)?;
92        }
93
94        tx.commit().await.map_err(pg_err)?;
95
96        Ok(cond_id)
97    }
98
99    async fn get_condition(&self, id: &ConditionId) -> RustvelloResult<Option<TriggerCondition>> {
100        let client = self.db.conn().await?;
101
102        let row = client
103            .query_opt(
104                "SELECT condition_json FROM trg_conditions WHERE condition_id = $1",
105                &[&id.as_str()],
106            )
107            .await
108            .map_err(pg_err)?;
109
110        match row {
111            Some(r) => {
112                let json: String = r.get(0);
113                let cond: TriggerCondition =
114                    serde_json::from_str(&json).map_err(|e| RustvelloError::Serialization {
115                        message: e.to_string(),
116                    })?;
117                Ok(Some(cond))
118            }
119            None => Ok(None),
120        }
121    }
122
123    async fn get_conditions_for_task(
124        &self,
125        task_id: &TaskId,
126    ) -> RustvelloResult<Vec<(ConditionId, TriggerCondition)>> {
127        let client = self.db.conn().await?;
128
129        let rows = client
130            .query(
131                "SELECT c.condition_id, c.condition_json
132                 FROM trg_conditions c
133                 INNER JOIN trg_source_task_conditions stc ON c.condition_id = stc.condition_id
134                 WHERE stc.task_id = $1",
135                &[&task_id.to_string()],
136            )
137            .await
138            .map_err(pg_err)?;
139
140        let mut result = Vec::new();
141        for row in &rows {
142            let id: String = row.get(0);
143            let json: String = row.get(1);
144            let cond: TriggerCondition =
145                serde_json::from_str(&json).map_err(|e| RustvelloError::Serialization {
146                    message: e.to_string(),
147                })?;
148            result.push((ConditionId::from(id), cond));
149        }
150        Ok(result)
151    }
152
153    async fn get_cron_conditions(&self) -> RustvelloResult<Vec<(ConditionId, TriggerCondition)>> {
154        let client = self.db.conn().await?;
155
156        let rows = client
157            .query(
158                "SELECT condition_id, condition_json FROM trg_conditions WHERE condition_type = 'Cron'",
159                &[],
160            )
161            .await
162            .map_err(pg_err)?;
163
164        let mut result = Vec::new();
165        for row in &rows {
166            let id: String = row.get(0);
167            let json: String = row.get(1);
168            let cond: TriggerCondition =
169                serde_json::from_str(&json).map_err(|e| RustvelloError::Serialization {
170                    message: e.to_string(),
171                })?;
172            result.push((ConditionId::from(id), cond));
173        }
174        Ok(result)
175    }
176
177    async fn get_event_conditions(
178        &self,
179        event_code: &str,
180    ) -> RustvelloResult<Vec<(ConditionId, TriggerCondition)>> {
181        let client = self.db.conn().await?;
182
183        let rows = client
184            .query(
185                "SELECT condition_id, condition_json FROM trg_conditions \
186                 WHERE condition_type = 'Event' AND event_code = $1",
187                &[&event_code],
188            )
189            .await
190            .map_err(pg_err)?;
191
192        let mut result = Vec::new();
193        for row in &rows {
194            let id: String = row.get(0);
195            let json: String = row.get(1);
196            let cond: TriggerCondition =
197                serde_json::from_str(&json).map_err(|e| RustvelloError::Serialization {
198                    message: e.to_string(),
199                })?;
200            result.push((ConditionId::from(id), cond));
201        }
202        Ok(result)
203    }
204
205    async fn register_trigger(&self, trigger: &TriggerDefinitionDTO) -> RustvelloResult<()> {
206        let mut client = self.db.conn().await?;
207        let tx = client.transaction().await.map_err(pg_err)?;
208        let arg_tmpl = trigger.argument_template.as_ref().map(ToString::to_string);
209
210        tx
211            .execute(
212                "INSERT INTO trg_triggers (trigger_id, task_id, logic, argument_template) VALUES ($1, $2, $3, $4)
213                 ON CONFLICT (trigger_id) DO UPDATE SET task_id = $2, logic = $3, argument_template = $4",
214                &[
215                    &trigger.trigger_id.as_str(),
216                    &trigger.task_id.to_string(),
217                    &trigger.logic.to_string(),
218                    &arg_tmpl as &(dyn tokio_postgres::types::ToSql + Sync),
219                ],
220            )
221            .await
222            .map_err(pg_err)?;
223
224        // Insert condition mappings
225        for cid in &trigger.condition_ids {
226            tx.execute(
227                "INSERT INTO trg_condition_triggers (condition_id, trigger_id) VALUES ($1, $2)
228                     ON CONFLICT DO NOTHING",
229                &[&cid.as_str(), &trigger.trigger_id.as_str()],
230            )
231            .await
232            .map_err(pg_err)?;
233        }
234
235        tx.commit().await.map_err(pg_err)?;
236        Ok(())
237    }
238
239    async fn get_trigger(
240        &self,
241        id: &TriggerDefinitionId,
242    ) -> RustvelloResult<Option<TriggerDefinitionDTO>> {
243        let client = self.db.conn().await?;
244
245        let row = client
246            .query_opt(
247                "SELECT task_id, logic, argument_template FROM trg_triggers WHERE trigger_id = $1",
248                &[&id.as_str()],
249            )
250            .await
251            .map_err(pg_err)?;
252
253        match row {
254            Some(r) => {
255                let task_id_str: String = r.get(0);
256                let logic_str: String = r.get(1);
257                let arg_tmpl: Option<String> = r.get(2);
258
259                let task_id: TaskId = task_id_str.parse().map_err(|e| {
260                    RustvelloError::state_backend(format!("invalid task_id in database: {e}"))
261                })?;
262                let logic = parse_logic(&logic_str)?;
263                let argument_template = arg_tmpl
264                    .map(|s| serde_json::from_str(&s))
265                    .transpose()
266                    .map_err(|e| RustvelloError::Serialization {
267                        message: e.to_string(),
268                    })?;
269
270                let condition_ids = self
271                    .get_condition_ids_for_trigger(&client, id.as_str())
272                    .await?;
273
274                Ok(Some(TriggerDefinitionDTO {
275                    trigger_id: id.clone(),
276                    task_id,
277                    condition_ids,
278                    logic,
279                    argument_template,
280                }))
281            }
282            None => Ok(None),
283        }
284    }
285
286    async fn get_triggers_for_condition(
287        &self,
288        cond_id: &ConditionId,
289    ) -> RustvelloResult<Vec<TriggerDefinitionDTO>> {
290        let client = self.db.conn().await?;
291
292        let rows = client
293            .query(
294                "SELECT t.trigger_id, t.task_id, t.logic, t.argument_template
295                 FROM trg_triggers t
296                 INNER JOIN trg_condition_triggers ct ON t.trigger_id = ct.trigger_id
297                 WHERE ct.condition_id = $1",
298                &[&cond_id.as_str()],
299            )
300            .await
301            .map_err(pg_err)?;
302
303        // Collect trigger IDs for batch condition lookup
304        let trigger_ids: Vec<String> = rows.iter().map(|r| r.get::<_, String>(0)).collect();
305        let mut cond_map = self
306            .get_condition_ids_for_triggers(&client, &trigger_ids)
307            .await?;
308
309        let mut result = Vec::new();
310        for row in &rows {
311            let trigger_id: String = row.get(0);
312            let task_id_str: String = row.get(1);
313            let logic_str: String = row.get(2);
314            let arg_tmpl: Option<String> = row.get(3);
315
316            let task_id: TaskId = task_id_str.parse().map_err(|e| {
317                RustvelloError::state_backend(format!("invalid task_id in database: {e}"))
318            })?;
319            let logic = parse_logic(&logic_str)?;
320            let argument_template = arg_tmpl
321                .map(|s| serde_json::from_str(&s))
322                .transpose()
323                .map_err(|e| RustvelloError::Serialization {
324                    message: e.to_string(),
325                })?;
326            let condition_ids = cond_map.remove(&trigger_id).unwrap_or_default();
327
328            result.push(TriggerDefinitionDTO {
329                trigger_id: TriggerDefinitionId::from(trigger_id),
330                task_id,
331                condition_ids,
332                logic,
333                argument_template,
334            });
335        }
336        Ok(result)
337    }
338
339    async fn remove_triggers_for_task(&self, task_id: &TaskId) -> RustvelloResult<u32> {
340        let client = self.db.conn().await?;
341        let task_str = task_id.to_string();
342
343        // Batch delete condition mappings via subquery
344        client
345            .execute(
346                "DELETE FROM trg_condition_triggers WHERE trigger_id IN \
347                 (SELECT trigger_id FROM trg_triggers WHERE task_id = $1)",
348                &[&task_str],
349            )
350            .await
351            .map_err(pg_err)?;
352
353        // Batch delete triggers
354        let deleted = client
355            .execute("DELETE FROM trg_triggers WHERE task_id = $1", &[&task_str])
356            .await
357            .map_err(pg_err)?;
358
359        Ok(u32::try_from(deleted).unwrap_or(u32::MAX))
360    }
361
362    async fn record_valid_condition(&self, vc: &ValidCondition) -> RustvelloResult<()> {
363        let context_json =
364            serde_json::to_string(&vc.context).map_err(|e| RustvelloError::Serialization {
365                message: e.to_string(),
366            })?;
367
368        let client = self.db.conn().await?;
369        client
370            .execute(
371                "INSERT INTO trg_valid_conditions (valid_condition_id, condition_id, context_json) VALUES ($1, $2, $3)
372                 ON CONFLICT (valid_condition_id) DO UPDATE SET condition_id = $2, context_json = $3",
373                &[&vc.valid_condition_id, &vc.condition_id.as_str(), &context_json],
374            )
375            .await
376            .map_err(pg_err)?;
377        Ok(())
378    }
379
380    async fn get_valid_conditions(&self) -> RustvelloResult<Vec<ValidCondition>> {
381        let client = self.db.conn().await?;
382
383        let rows = client
384            .query(
385                "SELECT valid_condition_id, condition_id, context_json FROM trg_valid_conditions",
386                &[],
387            )
388            .await
389            .map_err(pg_err)?;
390
391        let mut result = Vec::new();
392        for row in &rows {
393            let vc_id: String = row.get(0);
394            let cond_id: String = row.get(1);
395            let ctx_json: String = row.get(2);
396            let context: ConditionContext =
397                serde_json::from_str(&ctx_json).map_err(|e| RustvelloError::Serialization {
398                    message: e.to_string(),
399                })?;
400            result.push(ValidCondition {
401                valid_condition_id: vc_id,
402                condition_id: ConditionId::from(cond_id),
403                context,
404            });
405        }
406        Ok(result)
407    }
408
409    async fn clear_valid_conditions(&self, ids: &[String]) -> RustvelloResult<()> {
410        if ids.is_empty() {
411            return Ok(());
412        }
413        let client = self.db.conn().await?;
414        client
415            .execute(
416                "DELETE FROM trg_valid_conditions WHERE valid_condition_id = ANY($1)",
417                &[&ids],
418            )
419            .await
420            .map_err(pg_err)?;
421        Ok(())
422    }
423
424    async fn get_last_cron_execution(
425        &self,
426        cond_id: &ConditionId,
427    ) -> RustvelloResult<Option<DateTime<Utc>>> {
428        let client = self.db.conn().await?;
429
430        let row = client
431            .query_opt(
432                "SELECT last_execution FROM trg_cron_executions WHERE condition_id = $1",
433                &[&cond_id.as_str()],
434            )
435            .await
436            .map_err(pg_err)?;
437
438        Ok(row.map(|r| r.get::<_, DateTime<Utc>>(0)))
439    }
440
441    async fn store_cron_execution(
442        &self,
443        cond_id: &ConditionId,
444        time: DateTime<Utc>,
445        expected_last: Option<DateTime<Utc>>,
446    ) -> RustvelloResult<bool> {
447        let client = self.db.conn().await?;
448
449        let changed = match expected_last {
450            None => {
451                // Only insert if no existing row
452                client
453                    .execute(
454                        "INSERT INTO trg_cron_executions (condition_id, last_execution) VALUES ($1, $2)
455                         ON CONFLICT DO NOTHING",
456                        &[&cond_id.as_str(), &time],
457                    )
458                    .await
459                    .map_err(pg_err)?
460            }
461            Some(expected) => client
462                .execute(
463                    "UPDATE trg_cron_executions SET last_execution = $1
464                         WHERE condition_id = $2 AND last_execution = $3",
465                    &[&time, &cond_id.as_str(), &expected],
466                )
467                .await
468                .map_err(pg_err)?,
469        };
470
471        Ok(changed > 0)
472    }
473
474    async fn claim_trigger_run(&self, run_id: &TriggerRunId) -> RustvelloResult<bool> {
475        let client = self.db.conn().await?;
476        let now = Utc::now();
477
478        let changed = client
479            .execute(
480                "INSERT INTO trg_trigger_run_claims (trigger_run_id, claimed_at) VALUES ($1, $2)
481                 ON CONFLICT DO NOTHING",
482                &[&run_id.as_str(), &now],
483            )
484            .await
485            .map_err(pg_err)?;
486
487        Ok(changed > 0)
488    }
489
490    async fn purge(&self) -> RustvelloResult<()> {
491        let client = self.db.conn().await?;
492        client
493            .batch_execute(
494                "DELETE FROM trg_trigger_run_claims;
495                 DELETE FROM trg_cron_executions;
496                 DELETE FROM trg_valid_conditions;
497                 DELETE FROM trg_source_task_conditions;
498                 DELETE FROM trg_condition_triggers;
499                 DELETE FROM trg_triggers;
500                 DELETE FROM trg_conditions;",
501            )
502            .await
503            .map_err(pg_err)?;
504        Ok(())
505    }
506
507    async fn get_all_conditions(&self) -> RustvelloResult<Vec<(ConditionId, TriggerCondition)>> {
508        let client = self.db.conn().await?;
509
510        let rows = client
511            .query(
512                "SELECT condition_id, condition_json FROM trg_conditions",
513                &[],
514            )
515            .await
516            .map_err(pg_err)?;
517
518        let mut result = Vec::new();
519        for row in &rows {
520            let id: String = row.get(0);
521            let json: String = row.get(1);
522            let cond: TriggerCondition =
523                serde_json::from_str(&json).map_err(|e| RustvelloError::Serialization {
524                    message: e.to_string(),
525                })?;
526            result.push((ConditionId::from(id), cond));
527        }
528        Ok(result)
529    }
530}
531
532// Helper methods
533impl PostgresTriggerStore {
534    async fn get_condition_ids_for_trigger(
535        &self,
536        client: &deadpool_postgres::Client,
537        trigger_id: &str,
538    ) -> RustvelloResult<Vec<ConditionId>> {
539        let rows = client
540            .query(
541                "SELECT condition_id FROM trg_condition_triggers WHERE trigger_id = $1",
542                &[&trigger_id],
543            )
544            .await
545            .map_err(pg_err)?;
546
547        Ok(rows
548            .iter()
549            .map(|r| ConditionId::from(r.get::<_, String>(0)))
550            .collect())
551    }
552
553    /// Batch-fetch condition IDs for multiple triggers in a single query.
554    async fn get_condition_ids_for_triggers(
555        &self,
556        client: &deadpool_postgres::Client,
557        trigger_ids: &[String],
558    ) -> RustvelloResult<HashMap<String, Vec<ConditionId>>> {
559        if trigger_ids.is_empty() {
560            return Ok(HashMap::new());
561        }
562        let rows = client
563            .query(
564                "SELECT trigger_id, condition_id FROM trg_condition_triggers \
565                 WHERE trigger_id = ANY($1)",
566                &[&trigger_ids],
567            )
568            .await
569            .map_err(pg_err)?;
570
571        let mut map: HashMap<String, Vec<ConditionId>> = HashMap::new();
572        for row in &rows {
573            let tid: String = row.get(0);
574            let cid: String = row.get(1);
575            map.entry(tid).or_default().push(ConditionId::from(cid));
576        }
577        Ok(map)
578    }
579}