1use std::sync::Arc;
2
3use async_trait::async_trait;
4use chrono::{DateTime, Utc};
5use mongodb::bson::doc;
6use mongodb::error::{ErrorKind, WriteFailure};
7
8use rustvello_core::error::{RustvelloError, RustvelloResult};
9use rustvello_core::trigger::TriggerStore;
10use rustvello_proto::identifiers::TaskId;
11use rustvello_proto::trigger::{
12 ConditionId, TriggerCondition, TriggerDefinitionDTO, TriggerDefinitionId, TriggerRunId,
13 ValidCondition,
14};
15
16use crate::connection::{mongo_err, MongoPool};
17
18const COND_COL: &str = "trg_conditions";
19const TRIGGER_COL: &str = "trg_definitions";
20const VALID_COL: &str = "trg_valid_conditions";
21const CRON_EXEC_COL: &str = "trg_cron_executions";
22const RUN_COL: &str = "trg_runs";
23
/// MongoDB implementation of [`TriggerStore`].
///
/// Data is spread over five collections: trigger conditions (`trg_conditions`),
/// trigger definitions (`trg_definitions`), valid conditions
/// (`trg_valid_conditions`), per-condition cron execution markers
/// (`trg_cron_executions`) and trigger-run claims (`trg_runs`).
#[non_exhaustive]
pub struct MongoTriggerStore {
    /// Shared connection pool; every store operation fetches its database
    /// handle from it via `pool.db()`.
    pool: Arc<MongoPool>,
}
29
30impl MongoTriggerStore {
31 pub fn new(pool: Arc<MongoPool>) -> Self {
32 Self { pool }
33 }
34}
35
36#[async_trait]
37impl TriggerStore for MongoTriggerStore {
38 async fn register_condition(
39 &self,
40 condition: &TriggerCondition,
41 ) -> RustvelloResult<ConditionId> {
42 let cond_id = condition.condition_id();
43 let db = self.pool.db().await?;
44 let col = db.collection::<mongodb::bson::Document>(COND_COL);
45 let json = serde_json::to_string(condition).map_err(|e| RustvelloError::Serialization {
46 message: e.to_string(),
47 })?;
48
49 let task_ids: Vec<String> = condition
50 .source_task_ids()
51 .iter()
52 .map(ToString::to_string)
53 .collect();
54
55 let condition_type = match condition {
57 TriggerCondition::Cron(_) => "Cron",
58 TriggerCondition::Event(_) => "Event",
59 TriggerCondition::Status(_) => "Status",
60 TriggerCondition::Result(_) => "Result",
61 TriggerCondition::Exception(_) => "Exception",
62 TriggerCondition::Composite(_) => "Composite",
63 _ => "Other",
64 };
65
66 let mut set_fields = doc! {
67 "data": &json,
68 "task_ids": &task_ids,
69 "condition_type": condition_type,
70 };
71
72 if let TriggerCondition::Event(ev) = condition {
74 set_fields.insert("event_code", ev.event_code.clone());
75 }
76
77 let update_doc = doc! { "$set": set_fields };
78
79 let filter = doc! { "_id": cond_id.as_str().to_owned() };
80 col.update_one(filter, update_doc)
81 .upsert(true)
82 .await
83 .map_err(mongo_err)?;
84
85 Ok(cond_id)
86 }
87
88 async fn get_condition(&self, id: &ConditionId) -> RustvelloResult<Option<TriggerCondition>> {
89 let db = self.pool.db().await?;
90 let col = db.collection::<mongodb::bson::Document>(COND_COL);
91 let filter = doc! { "_id": &id.as_str() };
92 let result = col.find_one(filter).await.map_err(mongo_err)?;
93 match result {
94 Some(d) => {
95 let s = d
96 .get_str("data")
97 .map_err(|e| RustvelloError::state_backend(e.to_string()))?;
98 let c: TriggerCondition =
99 serde_json::from_str(s).map_err(|e| RustvelloError::Serialization {
100 message: e.to_string(),
101 })?;
102 Ok(Some(c))
103 }
104 None => Ok(None),
105 }
106 }
107
108 async fn get_conditions_for_task(
109 &self,
110 task_id: &TaskId,
111 ) -> RustvelloResult<Vec<(ConditionId, TriggerCondition)>> {
112 let db = self.pool.db().await?;
113 let col = db.collection::<mongodb::bson::Document>(COND_COL);
114 let filter = doc! { "task_ids": task_id.to_string() };
115 let mut cursor = col.find(filter).await.map_err(mongo_err)?;
116
117 let mut result = Vec::new();
118 use futures_util::StreamExt;
119 while let Some(doc_result) = StreamExt::next(&mut cursor).await {
120 let d = doc_result.map_err(mongo_err)?;
121 if let (Ok(id), Ok(data)) = (d.get_str("_id"), d.get_str("data")) {
122 let cond: TriggerCondition =
123 serde_json::from_str(data).map_err(|e| RustvelloError::Serialization {
124 message: e.to_string(),
125 })?;
126 result.push((ConditionId::from(id.to_string()), cond));
127 }
128 }
129 Ok(result)
130 }
131
132 async fn get_cron_conditions(&self) -> RustvelloResult<Vec<(ConditionId, TriggerCondition)>> {
133 let db = self.pool.db().await?;
134 let col = db.collection::<mongodb::bson::Document>(COND_COL);
135 let mut cursor = col
137 .find(doc! { "condition_type": "Cron" })
138 .await
139 .map_err(mongo_err)?;
140
141 let mut result = Vec::new();
142 use futures_util::StreamExt;
143 while let Some(doc_result) = StreamExt::next(&mut cursor).await {
144 let d = doc_result.map_err(mongo_err)?;
145 if let (Ok(id), Ok(data)) = (d.get_str("_id"), d.get_str("data")) {
146 let cond: TriggerCondition =
147 serde_json::from_str(data).map_err(|e| RustvelloError::Serialization {
148 message: e.to_string(),
149 })?;
150 result.push((ConditionId::from(id.to_string()), cond));
151 }
152 }
153 Ok(result)
154 }
155
156 async fn get_event_conditions(
157 &self,
158 event_code: &str,
159 ) -> RustvelloResult<Vec<(ConditionId, TriggerCondition)>> {
160 let db = self.pool.db().await?;
161 let col = db.collection::<mongodb::bson::Document>(COND_COL);
162 let mut cursor = col
164 .find(doc! { "condition_type": "Event", "event_code": event_code })
165 .await
166 .map_err(mongo_err)?;
167
168 let mut result = Vec::new();
169 use futures_util::StreamExt;
170 while let Some(doc_result) = StreamExt::next(&mut cursor).await {
171 let d = doc_result.map_err(mongo_err)?;
172 if let (Ok(id), Ok(data)) = (d.get_str("_id"), d.get_str("data")) {
173 let cond: TriggerCondition =
174 serde_json::from_str(data).map_err(|e| RustvelloError::Serialization {
175 message: e.to_string(),
176 })?;
177 result.push((ConditionId::from(id.to_string()), cond));
178 }
179 }
180 Ok(result)
181 }
182
183 async fn register_trigger(&self, trigger: &TriggerDefinitionDTO) -> RustvelloResult<()> {
184 let db = self.pool.db().await?;
185 let col = db.collection::<mongodb::bson::Document>(TRIGGER_COL);
186 let json = serde_json::to_string(trigger).map_err(|e| RustvelloError::Serialization {
187 message: e.to_string(),
188 })?;
189
190 let cond_ids: Vec<String> = trigger
191 .condition_ids
192 .iter()
193 .map(|c| c.as_str().to_owned())
194 .collect();
195
196 let filter = doc! { "_id": &trigger.trigger_id.as_str() };
197 let update = doc! {
198 "$set": {
199 "data": &json,
200 "task_id": trigger.task_id.to_string(),
201 "condition_ids": &cond_ids,
202 }
203 };
204 col.update_one(filter, update)
205 .upsert(true)
206 .await
207 .map_err(mongo_err)?;
208 Ok(())
209 }
210
211 async fn get_trigger(
212 &self,
213 id: &TriggerDefinitionId,
214 ) -> RustvelloResult<Option<TriggerDefinitionDTO>> {
215 let db = self.pool.db().await?;
216 let col = db.collection::<mongodb::bson::Document>(TRIGGER_COL);
217 let filter = doc! { "_id": &id.as_str() };
218 let result = col.find_one(filter).await.map_err(mongo_err)?;
219 match result {
220 Some(d) => {
221 let s = d
222 .get_str("data")
223 .map_err(|e| RustvelloError::state_backend(e.to_string()))?;
224 let t: TriggerDefinitionDTO =
225 serde_json::from_str(s).map_err(|e| RustvelloError::Serialization {
226 message: e.to_string(),
227 })?;
228 Ok(Some(t))
229 }
230 None => Ok(None),
231 }
232 }
233
234 async fn get_triggers_for_condition(
235 &self,
236 cond_id: &ConditionId,
237 ) -> RustvelloResult<Vec<TriggerDefinitionDTO>> {
238 let db = self.pool.db().await?;
239 let col = db.collection::<mongodb::bson::Document>(TRIGGER_COL);
240 let filter = doc! { "condition_ids": cond_id.as_str() };
241 let mut cursor = col.find(filter).await.map_err(mongo_err)?;
242
243 let mut result = Vec::new();
244 use futures_util::StreamExt;
245 while let Some(doc_result) = StreamExt::next(&mut cursor).await {
246 let d = doc_result.map_err(mongo_err)?;
247 if let Ok(s) = d.get_str("data") {
248 let t: TriggerDefinitionDTO =
249 serde_json::from_str(s).map_err(|e| RustvelloError::Serialization {
250 message: e.to_string(),
251 })?;
252 result.push(t);
253 }
254 }
255 Ok(result)
256 }
257
258 async fn remove_triggers_for_task(&self, task_id: &TaskId) -> RustvelloResult<u32> {
259 let db = self.pool.db().await?;
260 let col = db.collection::<mongodb::bson::Document>(TRIGGER_COL);
261 let filter = doc! { "task_id": task_id.to_string() };
262 let result = col.delete_many(filter).await.map_err(mongo_err)?;
263 Ok(u32::try_from(result.deleted_count).unwrap_or(u32::MAX))
264 }
265
266 async fn record_valid_condition(&self, vc: &ValidCondition) -> RustvelloResult<()> {
267 let db = self.pool.db().await?;
268 let col = db.collection::<mongodb::bson::Document>(VALID_COL);
269 let json = serde_json::to_string(vc).map_err(|e| RustvelloError::Serialization {
270 message: e.to_string(),
271 })?;
272 let filter = doc! { "_id": &vc.valid_condition_id };
273 let update = doc! { "$set": { "data": &json } };
274 col.update_one(filter, update)
275 .upsert(true)
276 .await
277 .map_err(mongo_err)?;
278 Ok(())
279 }
280
281 async fn get_valid_conditions(&self) -> RustvelloResult<Vec<ValidCondition>> {
282 let db = self.pool.db().await?;
283 let col = db.collection::<mongodb::bson::Document>(VALID_COL);
284 let mut cursor = col.find(doc! {}).await.map_err(mongo_err)?;
285
286 let mut result = Vec::new();
287 use futures_util::StreamExt;
288 while let Some(doc_result) = StreamExt::next(&mut cursor).await {
289 let d = doc_result.map_err(mongo_err)?;
290 if let Ok(s) = d.get_str("data") {
291 let vc: ValidCondition =
292 serde_json::from_str(s).map_err(|e| RustvelloError::Serialization {
293 message: e.to_string(),
294 })?;
295 result.push(vc);
296 }
297 }
298 Ok(result)
299 }
300
301 async fn clear_valid_conditions(&self, ids: &[String]) -> RustvelloResult<()> {
302 if ids.is_empty() {
303 return Ok(());
304 }
305 let db = self.pool.db().await?;
306 let col = db.collection::<mongodb::bson::Document>(VALID_COL);
307 let bson_ids: Vec<mongodb::bson::Bson> = ids
308 .iter()
309 .map(|id| mongodb::bson::Bson::String(id.clone()))
310 .collect();
311 let filter = doc! { "_id": { "$in": bson_ids } };
312 col.delete_many(filter).await.map_err(mongo_err)?;
313 Ok(())
314 }
315
316 async fn get_last_cron_execution(
317 &self,
318 cond_id: &ConditionId,
319 ) -> RustvelloResult<Option<DateTime<Utc>>> {
320 let db = self.pool.db().await?;
321 let col = db.collection::<mongodb::bson::Document>(CRON_EXEC_COL);
322 let filter = doc! { "_id": cond_id.as_str() };
323 let result = col.find_one(filter).await.map_err(mongo_err)?;
324 match result {
325 Some(d) => {
326 let ts = d
327 .get_str("timestamp")
328 .map_err(|e| RustvelloError::state_backend(e.to_string()))?;
329 let dt = DateTime::parse_from_rfc3339(ts)
330 .map(|d| d.with_timezone(&Utc))
331 .map_err(|e| RustvelloError::Serialization {
332 message: format!("cron timestamp: {}", e),
333 })?;
334 Ok(Some(dt))
335 }
336 None => Ok(None),
337 }
338 }
339
340 async fn store_cron_execution(
341 &self,
342 cond_id: &ConditionId,
343 time: DateTime<Utc>,
344 expected_last: Option<DateTime<Utc>>,
345 ) -> RustvelloResult<bool> {
346 let db = self.pool.db().await?;
347 let col = db.collection::<mongodb::bson::Document>(CRON_EXEC_COL);
348
349 let filter = match expected_last {
351 Some(ts) => doc! { "_id": cond_id.as_str(), "timestamp": ts.to_rfc3339() },
352 None => doc! { "_id": cond_id.as_str(), "timestamp": { "$exists": false } },
353 };
354 let update = doc! { "$set": { "timestamp": time.to_rfc3339() } };
355
356 let result = col
357 .update_one(filter, update)
358 .upsert(expected_last.is_none())
359 .await;
360
361 match result {
362 Ok(r) => Ok(r.matched_count > 0 || r.upserted_id.is_some()),
363 Err(e) => {
364 if matches!(*e.kind, ErrorKind::Write(WriteFailure::WriteError(ref we)) if we.code == 11000)
367 {
368 Ok(false)
369 } else {
370 Err(mongo_err(e))
371 }
372 }
373 }
374 }
375
376 async fn claim_trigger_run(&self, run_id: &TriggerRunId) -> RustvelloResult<bool> {
377 let db = self.pool.db().await?;
378 let col = db.collection::<mongodb::bson::Document>(RUN_COL);
379 let doc = doc! { "_id": run_id.as_str().to_owned(), "claimed": true };
380 match col.insert_one(doc).await {
381 Ok(_) => Ok(true),
382 Err(e) => {
383 if matches!(*e.kind, ErrorKind::Write(WriteFailure::WriteError(ref we)) if we.code == 11000)
385 {
386 Ok(false)
387 } else {
388 Err(mongo_err(e))
389 }
390 }
391 }
392 }
393
394 async fn purge(&self) -> RustvelloResult<()> {
395 let db = self.pool.db().await?;
396 for col_name in [COND_COL, TRIGGER_COL, VALID_COL, CRON_EXEC_COL, RUN_COL] {
397 let col = db.collection::<mongodb::bson::Document>(col_name);
398 col.delete_many(doc! {}).await.map_err(mongo_err)?;
399 }
400 Ok(())
401 }
402
403 async fn get_all_conditions(&self) -> RustvelloResult<Vec<(ConditionId, TriggerCondition)>> {
404 let db = self.pool.db().await?;
405 let col = db.collection::<mongodb::bson::Document>(COND_COL);
406 let mut cursor = col.find(doc! {}).await.map_err(mongo_err)?;
407
408 let mut result = Vec::new();
409 use futures_util::StreamExt;
410 while let Some(doc_result) = StreamExt::next(&mut cursor).await {
411 let d = doc_result.map_err(mongo_err)?;
412 if let (Ok(id), Ok(data)) = (d.get_str("_id"), d.get_str("data")) {
413 let cond: TriggerCondition =
414 serde_json::from_str(data).map_err(|e| RustvelloError::Serialization {
415 message: e.to_string(),
416 })?;
417 result.push((ConditionId::from(id.to_string()), cond));
418 }
419 }
420 Ok(result)
421 }
422}