Skip to main content

rustvello_sqlite/state_backend/
core.rs

1use std::sync::Arc;
2
3use async_trait::async_trait;
4
5use rustvello_core::error::{RustvelloError, RustvelloResult, TaskError};
6use rustvello_core::state_backend::StateBackendCore;
7
8use rustvello_proto::call::{CallDTO, SerializedArguments};
9use rustvello_proto::identifiers::{CallId, InvocationId, RunnerId, TaskId};
10use rustvello_proto::invocation::{InvocationDTO, InvocationHistory, WorkflowIdentity};
11use rustvello_proto::status::{InvocationStatus, InvocationStatusRecord};
12
13use crate::db::{blocking, lock_err, parse_status, parse_timestamp, sql_err};
14
15use super::SqliteStateBackend;
16
17#[async_trait]
18impl StateBackendCore for SqliteStateBackend {
19    async fn upsert_invocation(
20        &self,
21        invocation: &InvocationDTO,
22        call: &CallDTO,
23    ) -> RustvelloResult<()> {
24        let db = Arc::clone(&self.db);
25        let invocation = invocation.clone();
26        let call = call.clone();
27        blocking(move || {
28
29            let conn = db.conn.lock().map_err(lock_err)?;
30
31            let tx = conn.unchecked_transaction().map_err(sql_err)?;
32
33            let args_json = serde_json::to_string(&call.serialized_arguments.0)
34                .map_err(|e| RustvelloError::Serialization { message: e.to_string() })?;
35
36            let (parent_inv_id, wf_id, wf_type, wf_depth) = match &invocation.workflow {
37                Some(wf) => (
38                    invocation
39                        .parent_invocation_id
40                        .as_ref()
41                        .map(|id| id.as_str().to_owned()),
42                    Some(wf.workflow_id.as_str().to_owned()),
43                    Some(wf.workflow_type.to_string()),
44                    Some(wf.depth as i64),
45                ),
46                None => (
47                    invocation
48                        .parent_invocation_id
49                        .as_ref()
50                        .map(|id| id.as_str().to_owned()),
51                    None,
52                    None,
53                    None,
54                ),
55            };
56
57            tx.execute(
58                "INSERT OR REPLACE INTO invocations (invocation_id, task_id, call_id, status, created_at, updated_at, parent_invocation_id, workflow_id, workflow_type, workflow_depth)
59                 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10)",
60                rusqlite::params![
61                    &invocation.invocation_id.as_str(),
62                    &invocation.task_id.to_string(),
63                    &invocation.call_id.to_string(),
64                    &invocation.status.to_string(),
65                    &invocation.created_at.to_rfc3339(),
66                    &invocation.updated_at.to_rfc3339(),
67                    &parent_inv_id,
68                    &wf_id,
69                    &wf_type,
70                    &wf_depth,
71                ],
72            )
73            .map_err(sql_err)?;
74
75            tx.execute(
76                "INSERT OR REPLACE INTO calls (call_id, task_id, serialized_arguments) VALUES (?1, ?2, ?3)",
77                rusqlite::params![
78                    &call.call_id.to_string(),
79                    &call.task_id.to_string(),
80                    &args_json,
81                ],
82            )
83            .map_err(sql_err)?;
84
85            tx.commit().map_err(sql_err)?;
86
87            Ok(())
88
89        })
90        .await
91    }
92
93    async fn get_invocation(&self, invocation_id: &InvocationId) -> RustvelloResult<InvocationDTO> {
94        let db = Arc::clone(&self.db);
95        let invocation_id = invocation_id.clone();
96        blocking(move || {
97
98            let conn = db.conn.lock().map_err(lock_err)?;
99
100            let (task_id_str, call_id_str, status_str, created_str, updated_str, parent_inv_id, wf_id, wf_type, wf_depth): (
101                String,
102                String,
103                String,
104                String,
105                String,
106                Option<String>,
107                Option<String>,
108                Option<String>,
109                Option<i64>,
110            ) = conn
111                .query_row(
112                    "SELECT task_id, call_id, status, created_at, updated_at, parent_invocation_id, workflow_id, workflow_type, workflow_depth FROM invocations WHERE invocation_id = ?1",
113                    [invocation_id.as_str()],
114                    |row| Ok((row.get(0)?, row.get(1)?, row.get(2)?, row.get(3)?, row.get(4)?, row.get(5)?, row.get(6)?, row.get(7)?, row.get(8)?)),
115                )
116                .map_err(|_| RustvelloError::InvocationNotFound { invocation_id: invocation_id.clone() })?;
117
118            let task_id: TaskId = task_id_str
119                .parse()
120                .map_err(|e| RustvelloError::state_backend(format!("invalid task_id in database: {e}")))?;
121
122            let args_id = call_id_str
123                .rsplit_once(':')
124                .map_or(call_id_str.as_str(), |(_, a)| a);
125            let call_id = CallId::new(task_id.clone(), args_id);
126
127            let created_at = parse_timestamp(&created_str)?;
128            let updated_at = parse_timestamp(&updated_str)?;
129
130            let parent_invocation_id = parent_inv_id.map(InvocationId::from_string);
131
132            let workflow = match (wf_id, wf_type) {
133                (Some(wf_id_str), Some(wf_type_str)) => {
134                    let wf_task_id: TaskId = wf_type_str.parse().map_err(|e| {
135                        RustvelloError::state_backend(format!("invalid workflow task_id in database: {e}"))
136                    })?;
137                    Some(WorkflowIdentity {
138                        workflow_id: InvocationId::from_string(wf_id_str),
139                        workflow_type: wf_task_id,
140                        parent_id: None,
141                        depth: u32::try_from(wf_depth.unwrap_or(0)).unwrap_or(0),
142                    })
143                }
144                _ => None,
145            };
146
147            Ok(InvocationDTO {
148                invocation_id: invocation_id.clone(),
149                task_id,
150                call_id,
151                status: parse_status(&status_str)?,
152                created_at,
153                updated_at,
154                parent_invocation_id,
155                workflow,
156            })
157
158        })
159        .await
160    }
161
162    async fn get_call(&self, call_id: &CallId) -> RustvelloResult<CallDTO> {
163        let db = Arc::clone(&self.db);
164        let call_id = call_id.clone();
165        blocking(move || {
166            let conn = db.conn.lock().map_err(lock_err)?;
167            let call_id_str = call_id.to_string();
168
169            let (task_id_str, args_json): (String, String) = conn
170                .query_row(
171                    "SELECT task_id, serialized_arguments FROM calls WHERE call_id = ?1",
172                    [&call_id_str],
173                    |row| Ok((row.get(0)?, row.get(1)?)),
174                )
175                .map_err(|_| {
176                    RustvelloError::state_backend(format!("call not found: {}", call_id_str))
177                })?;
178
179            let task_id: TaskId = task_id_str.parse().map_err(|e| {
180                RustvelloError::state_backend(format!("invalid task_id in database: {e}"))
181            })?;
182
183            let args_map: std::collections::BTreeMap<String, String> =
184                serde_json::from_str(&args_json).map_err(|e| RustvelloError::Serialization {
185                    message: e.to_string(),
186                })?;
187
188            let args = SerializedArguments(args_map);
189
190            Ok(CallDTO {
191                call_id: call_id.clone(),
192                task_id,
193                serialized_arguments: args,
194            })
195        })
196        .await
197    }
198
199    async fn store_result(
200        &self,
201        invocation_id: &InvocationId,
202        result: &str,
203    ) -> RustvelloResult<()> {
204        let db = Arc::clone(&self.db);
205        let invocation_id = invocation_id.clone();
206        let result = result.to_owned();
207        blocking(move || {
208            let conn = db.conn.lock().map_err(lock_err)?;
209            conn.execute(
210                "INSERT OR REPLACE INTO results (invocation_id, result) VALUES (?1, ?2)",
211                rusqlite::params![invocation_id.as_str(), result],
212            )
213            .map_err(sql_err)?;
214            Ok(())
215        })
216        .await
217    }
218
219    async fn get_result(&self, invocation_id: &InvocationId) -> RustvelloResult<Option<String>> {
220        let db = Arc::clone(&self.db);
221        let invocation_id = invocation_id.clone();
222        blocking(move || {
223            let conn = db.conn.lock().map_err(lock_err)?;
224            let result: Option<String> = conn
225                .query_row(
226                    "SELECT result FROM results WHERE invocation_id = ?1",
227                    [invocation_id.as_str()],
228                    |row| row.get(0),
229                )
230                .ok();
231            Ok(result)
232        })
233        .await
234    }
235
236    async fn store_error(
237        &self,
238        invocation_id: &InvocationId,
239        error: &TaskError,
240    ) -> RustvelloResult<()> {
241        let db = Arc::clone(&self.db);
242        let invocation_id = invocation_id.clone();
243        let error = error.clone();
244        blocking(move || {
245
246            let conn = db.conn.lock().map_err(lock_err)?;
247            conn.execute(
248                "INSERT OR REPLACE INTO errors (invocation_id, error_type, message, traceback) VALUES (?1, ?2, ?3, ?4)",
249                rusqlite::params![
250                    invocation_id.as_str(),
251                    &error.error_type,
252                    &error.message,
253                    &error.traceback,
254                ],
255            )
256            .map_err(sql_err)?;
257            Ok(())
258
259        })
260        .await
261    }
262
263    async fn get_error(&self, invocation_id: &InvocationId) -> RustvelloResult<Option<TaskError>> {
264        let db = Arc::clone(&self.db);
265        let invocation_id = invocation_id.clone();
266        blocking(move || {
267            let conn = db.conn.lock().map_err(lock_err)?;
268            let result: Option<(String, String, Option<String>)> = conn
269                .query_row(
270                    "SELECT error_type, message, traceback FROM errors WHERE invocation_id = ?1",
271                    [invocation_id.as_str()],
272                    |row| Ok((row.get(0)?, row.get(1)?, row.get(2)?)),
273                )
274                .ok();
275
276            Ok(result.map(|(error_type, message, traceback)| TaskError {
277                error_type,
278                message,
279                traceback,
280            }))
281        })
282        .await
283    }
284
285    async fn add_history(&self, history: &InvocationHistory) -> RustvelloResult<()> {
286        let db = Arc::clone(&self.db);
287        let history = history.clone();
288        blocking(move || {
289
290            let conn = db.conn.lock().map_err(lock_err)?;
291            let hist_ts = history.history_timestamp.map(|ts| ts.to_rfc3339());
292            conn.execute(
293                "INSERT INTO history (invocation_id, status, runner_id, timestamp, message, history_timestamp) VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
294                rusqlite::params![
295                    &history.invocation_id.as_str(),
296                    &history.status_record.status.to_string(),
297                    &history.status_record.runner_id.as_ref().map(|r| r.as_str().to_string()),
298                    &history.status_record.timestamp.to_rfc3339(),
299                    &history.message,
300                    &hist_ts,
301                ],
302            )
303            .map_err(sql_err)?;
304            Ok(())
305
306        })
307        .await
308    }
309
310    async fn get_history(
311        &self,
312        invocation_id: &InvocationId,
313    ) -> RustvelloResult<Vec<InvocationHistory>> {
314        let db = Arc::clone(&self.db);
315        let invocation_id = invocation_id.clone();
316        blocking(move || {
317
318            let conn = db.conn.lock().map_err(lock_err)?;
319
320            let mut stmt = conn
321                .prepare(
322                    "SELECT status, runner_id, timestamp, message, history_timestamp FROM history WHERE invocation_id = ?1 ORDER BY id",
323                )
324                .map_err(sql_err)?;
325
326            let histories: Vec<InvocationHistory> = stmt
327                .query_map([invocation_id.as_str()], |row| {
328                    let status_str: String = row.get(0)?;
329                    let runner_id: Option<String> = row.get(1)?;
330                    let timestamp_str: String = row.get(2)?;
331                    let message: Option<String> = row.get(3)?;
332                    let hist_ts_str: Option<String> = row.get(4)?;
333
334                    let timestamp = chrono::DateTime::parse_from_rfc3339(&timestamp_str)
335                        .map(|dt| dt.with_timezone(&chrono::Utc))
336                        .map_err(|e| {
337                            rusqlite::Error::FromSqlConversionFailure(
338                                2,
339                                rusqlite::types::Type::Text,
340                                Box::new(std::io::Error::new(std::io::ErrorKind::InvalidData, e.to_string())),
341                            )
342                        })?;
343
344                    let history_timestamp = hist_ts_str
345                        .and_then(|s| chrono::DateTime::parse_from_rfc3339(&s).ok())
346                        .map(|dt| dt.with_timezone(&chrono::Utc));
347
348                    let status = status_str.parse::<InvocationStatus>().map_err(|e| {
349                        rusqlite::Error::FromSqlConversionFailure(
350                            0,
351                            rusqlite::types::Type::Text,
352                            Box::new(std::io::Error::new(std::io::ErrorKind::InvalidData, e)),
353                        )
354                    })?;
355
356                    Ok(InvocationHistory {
357                        invocation_id: invocation_id.clone(),
358                        status_record: InvocationStatusRecord {
359                            status,
360                            runner_id: runner_id.clone().map(RunnerId::from_string),
361                            timestamp,
362                        },
363                        message,
364                        runner_id: runner_id.map(RunnerId::from_string),
365                        registered_by_inv_id: None,
366                        history_timestamp,
367                    })
368                })
369                .map_err(sql_err)?
370                .collect::<Result<Vec<_>, _>>()
371                .map_err(sql_err)?;
372
373            Ok(histories)
374
375        })
376        .await
377    }
378
379    async fn purge(&self) -> RustvelloResult<()> {
380        let db = Arc::clone(&self.db);
381        blocking(move || {
382            let conn = db.conn.lock().map_err(lock_err)?;
383            conn.execute_batch(
384                "DELETE FROM invocations;
385                 DELETE FROM calls;
386                 DELETE FROM results;
387                 DELETE FROM errors;
388                 DELETE FROM history;
389                 DELETE FROM status_records;
390                 DELETE FROM waiting_for;
391                 DELETE FROM broker_queue;
392                 DELETE FROM workflow_runs;
393                 DELETE FROM workflow_data;
394                 DELETE FROM app_infos;
395                 DELETE FROM workflow_sub_invocations;
396                 DELETE FROM runner_contexts;",
397            )
398            .map_err(sql_err)?;
399            Ok(())
400        })
401        .await
402    }
403}