Skip to main content

rustvello_sqlite/state_backend/
query.rs

1use std::sync::Arc;
2
3use async_trait::async_trait;
4
5use rustvello_core::error::{RustvelloError, RustvelloResult};
6use rustvello_core::state_backend::StateBackendQuery;
7
8use rustvello_proto::identifiers::{InvocationId, TaskId};
9use rustvello_proto::invocation::WorkflowIdentity;
10
11use crate::db::{blocking, lock_err, sql_err};
12
13use super::SqliteStateBackend;
14
15#[async_trait]
16impl StateBackendQuery for SqliteStateBackend {
17    async fn get_workflow_invocations(
18        &self,
19        workflow_id: &InvocationId,
20    ) -> RustvelloResult<Vec<InvocationId>> {
21        let db = Arc::clone(&self.db);
22        let workflow_id = workflow_id.clone();
23        blocking(move || {
24            let conn = db.conn.lock().map_err(lock_err)?;
25            let mut stmt = conn
26                .prepare("SELECT invocation_id FROM invocations WHERE workflow_id = ?1")
27                .map_err(sql_err)?;
28            let ids: Vec<InvocationId> = stmt
29                .query_map([workflow_id.as_str()], |row| {
30                    let id: String = row.get(0)?;
31                    Ok(InvocationId::from_string(id))
32                })
33                .map_err(sql_err)?
34                .collect::<Result<Vec<_>, _>>()
35                .map_err(sql_err)?;
36            Ok(ids)
37        })
38        .await
39    }
40
41    async fn get_child_invocations(
42        &self,
43        parent_invocation_id: &InvocationId,
44    ) -> RustvelloResult<Vec<InvocationId>> {
45        let db = Arc::clone(&self.db);
46        let parent_invocation_id = parent_invocation_id.clone();
47        blocking(move || {
48            let conn = db.conn.lock().map_err(lock_err)?;
49            let mut stmt = conn
50                .prepare("SELECT invocation_id FROM invocations WHERE parent_invocation_id = ?1")
51                .map_err(sql_err)?;
52            let ids: Vec<InvocationId> = stmt
53                .query_map([parent_invocation_id.as_str()], |row| {
54                    let id: String = row.get(0)?;
55                    Ok(InvocationId::from_string(id))
56                })
57                .map_err(sql_err)?
58                .collect::<Result<Vec<_>, _>>()
59                .map_err(sql_err)?;
60            Ok(ids)
61        })
62        .await
63    }
64
65    async fn store_workflow_run(&self, workflow: &WorkflowIdentity) -> RustvelloResult<()> {
66        let db = Arc::clone(&self.db);
67        let workflow = workflow.clone();
68        blocking(move || {
69
70            let conn = db.conn.lock().map_err(lock_err)?;
71            conn.execute(
72                "INSERT OR REPLACE INTO workflow_runs (workflow_id, workflow_type, parent_workflow_id, depth) VALUES (?1, ?2, ?3, ?4)",
73                rusqlite::params![
74                    &workflow.workflow_id.as_str(),
75                    &workflow.workflow_type.to_string(),
76                    &workflow.parent_id.as_ref().map(|id| id.as_str().to_owned()),
77                    workflow.depth as i64,
78                ],
79            )
80            .map_err(sql_err)?;
81            Ok(())
82
83        })
84        .await
85    }
86
87    async fn get_all_workflow_types(&self) -> RustvelloResult<Vec<TaskId>> {
88        let db = Arc::clone(&self.db);
89        blocking(move || {
90            let conn = db.conn.lock().map_err(lock_err)?;
91            let mut stmt = conn
92                .prepare("SELECT DISTINCT workflow_type FROM workflow_runs")
93                .map_err(sql_err)?;
94            let types: Vec<TaskId> = stmt
95                .query_map([], |row| {
96                    let type_str: String = row.get(0)?;
97                    Ok(type_str)
98                })
99                .map_err(sql_err)?
100                .collect::<Result<Vec<_>, _>>()
101                .map_err(sql_err)?
102                .into_iter()
103                .map(|s| {
104                    s.parse::<TaskId>().map_err(|e| {
105                        RustvelloError::state_backend(format!("invalid task_id in database: {e}"))
106                    })
107                })
108                .collect::<RustvelloResult<Vec<_>>>()?;
109            Ok(types)
110        })
111        .await
112    }
113
114    async fn get_workflow_runs(
115        &self,
116        workflow_type: &TaskId,
117    ) -> RustvelloResult<Vec<WorkflowIdentity>> {
118        let db = Arc::clone(&self.db);
119        let workflow_type = workflow_type.clone();
120        blocking(move || {
121
122            let conn = db.conn.lock().map_err(lock_err)?;
123            let type_key = workflow_type.to_string();
124            let mut stmt = conn
125                .prepare(
126                    "SELECT workflow_id, workflow_type, parent_workflow_id, depth FROM workflow_runs WHERE workflow_type = ?1",
127                )
128                .map_err(sql_err)?;
129            let runs: Vec<WorkflowIdentity> = stmt
130                .query_map([&type_key], |row| {
131                    let wf_id: String = row.get(0)?;
132                    let wf_type: String = row.get(1)?;
133                    let parent_id: Option<String> = row.get(2)?;
134                    let depth: i64 = row.get(3)?;
135                    Ok((wf_id, wf_type, parent_id, depth))
136                })
137                .map_err(sql_err)?
138                .collect::<Result<Vec<_>, _>>()
139                .map_err(sql_err)?
140                .into_iter()
141                .map(|(wf_id, wf_type, parent_id, depth)| {
142                    let task_id = wf_type.parse::<TaskId>()
143                        .map_err(|e| RustvelloError::state_backend(format!("invalid workflow task_id in database: {e}")))?;
144                    Ok(WorkflowIdentity {
145                        workflow_id: InvocationId::from_string(wf_id),
146                        workflow_type: task_id,
147                        parent_id: parent_id.map(InvocationId::from_string),
148                        depth: u32::try_from(depth).unwrap_or(0),
149                    })
150                })
151                .collect::<RustvelloResult<Vec<_>>>()?;
152            Ok(runs)
153
154        })
155        .await
156    }
157
158    async fn set_workflow_data(
159        &self,
160        workflow_id: &InvocationId,
161        key: &str,
162        value: &str,
163    ) -> RustvelloResult<()> {
164        let db = Arc::clone(&self.db);
165        let workflow_id = workflow_id.clone();
166        let key = key.to_owned();
167        let value = value.to_owned();
168        blocking(move || {
169
170            let conn = db.conn.lock().map_err(lock_err)?;
171            conn.execute(
172                "INSERT OR REPLACE INTO workflow_data (workflow_id, data_key, data_value) VALUES (?1, ?2, ?3)",
173                rusqlite::params![workflow_id.as_str(), key, value],
174            )
175            .map_err(sql_err)?;
176            Ok(())
177
178        })
179        .await
180    }
181
182    async fn get_workflow_data(
183        &self,
184        workflow_id: &InvocationId,
185        key: &str,
186    ) -> RustvelloResult<Option<String>> {
187        let db = Arc::clone(&self.db);
188        let workflow_id = workflow_id.clone();
189        let key = key.to_owned();
190        blocking(move || {
191            let conn = db.conn.lock().map_err(lock_err)?;
192            let result: Option<String> = conn
193                .query_row(
194                    "SELECT data_value FROM workflow_data WHERE workflow_id = ?1 AND data_key = ?2",
195                    rusqlite::params![workflow_id.as_str(), key],
196                    |row| row.get(0),
197                )
198                .ok();
199            Ok(result)
200        })
201        .await
202    }
203
204    async fn store_app_info(&self, app_id: &str, info_json: &str) -> RustvelloResult<()> {
205        let db = Arc::clone(&self.db);
206        let app_id = app_id.to_owned();
207        let info_json = info_json.to_owned();
208        blocking(move || {
209            let conn = db.conn.lock().map_err(lock_err)?;
210            conn.execute(
211                "INSERT OR REPLACE INTO app_infos (app_id, info_json) VALUES (?1, ?2)",
212                rusqlite::params![app_id, info_json],
213            )
214            .map_err(sql_err)?;
215            Ok(())
216        })
217        .await
218    }
219
220    async fn get_app_info(&self, app_id: &str) -> RustvelloResult<Option<String>> {
221        let db = Arc::clone(&self.db);
222        let app_id = app_id.to_owned();
223        blocking(move || {
224            let conn = db.conn.lock().map_err(lock_err)?;
225            let result: Option<String> = conn
226                .query_row(
227                    "SELECT info_json FROM app_infos WHERE app_id = ?1",
228                    rusqlite::params![app_id],
229                    |row| row.get(0),
230                )
231                .ok();
232            Ok(result)
233        })
234        .await
235    }
236
237    async fn get_all_app_infos(&self) -> RustvelloResult<Vec<(String, String)>> {
238        let db = Arc::clone(&self.db);
239        blocking(move || {
240            let conn = db.conn.lock().map_err(lock_err)?;
241            let mut stmt = conn
242                .prepare("SELECT app_id, info_json FROM app_infos")
243                .map_err(sql_err)?;
244            let infos = stmt
245                .query_map([], |row| {
246                    let app_id: String = row.get(0)?;
247                    let info_json: String = row.get(1)?;
248                    Ok((app_id, info_json))
249                })
250                .map_err(sql_err)?
251                .collect::<Result<Vec<_>, _>>()
252                .map_err(sql_err)?;
253            Ok(infos)
254        })
255        .await
256    }
257
258    async fn store_workflow_sub_invocation(
259        &self,
260        workflow_id: &InvocationId,
261        sub_inv_id: &InvocationId,
262    ) -> RustvelloResult<()> {
263        let db = Arc::clone(&self.db);
264        let workflow_id = workflow_id.clone();
265        let sub_inv_id = sub_inv_id.clone();
266        blocking(move || {
267
268            let conn = db.conn.lock().map_err(lock_err)?;
269            conn.execute(
270                "INSERT OR IGNORE INTO workflow_sub_invocations (workflow_id, sub_invocation_id) VALUES (?1, ?2)",
271                rusqlite::params![workflow_id.as_str(), sub_inv_id.as_str()],
272            )
273            .map_err(sql_err)?;
274            Ok(())
275
276        })
277        .await
278    }
279
280    async fn get_workflow_sub_invocations(
281        &self,
282        workflow_id: &InvocationId,
283    ) -> RustvelloResult<Vec<InvocationId>> {
284        let db = Arc::clone(&self.db);
285        let workflow_id = workflow_id.clone();
286        blocking(move || {
287            let conn = db.conn.lock().map_err(lock_err)?;
288            let mut stmt = conn
289                .prepare(
290                    "SELECT sub_invocation_id FROM workflow_sub_invocations WHERE workflow_id = ?1",
291                )
292                .map_err(sql_err)?;
293            let ids = stmt
294                .query_map([workflow_id.as_str()], |row| {
295                    let id: String = row.get(0)?;
296                    Ok(InvocationId::from_string(id))
297                })
298                .map_err(sql_err)?
299                .collect::<Result<Vec<_>, _>>()
300                .map_err(sql_err)?;
301            Ok(ids)
302        })
303        .await
304    }
305
306    async fn get_all_workflow_runs(&self) -> RustvelloResult<Vec<WorkflowIdentity>> {
307        let db = Arc::clone(&self.db);
308        blocking(move || {
309
310            let conn = db.conn.lock().map_err(lock_err)?;
311            let mut stmt = conn
312                .prepare(
313                    "SELECT workflow_id, workflow_type, parent_workflow_id, depth FROM workflow_runs",
314                )
315                .map_err(sql_err)?;
316            let runs = stmt
317                .query_map([], |row| {
318                    Ok((
319                        row.get::<_, String>(0)?,
320                        row.get::<_, String>(1)?,
321                        row.get::<_, Option<String>>(2)?,
322                        row.get::<_, i64>(3)?,
323                    ))
324                })
325                .map_err(sql_err)?
326                .collect::<Result<Vec<_>, _>>()
327                .map_err(sql_err)?
328                .into_iter()
329                .map(|(wf_id, wf_type, parent_id, depth)| {
330                    let task_id = wf_type.parse::<TaskId>()
331                        .map_err(|e| RustvelloError::state_backend(format!("invalid workflow task_id in database: {e}")))?;
332                    Ok(WorkflowIdentity {
333                        workflow_id: InvocationId::from_string(wf_id),
334                        workflow_type: task_id,
335                        parent_id: parent_id.map(InvocationId::from_string),
336                        depth: u32::try_from(depth).unwrap_or(0),
337                    })
338                })
339                .collect::<RustvelloResult<Vec<_>>>()?;
340            Ok(runs)
341
342        })
343        .await
344    }
345}