use std::sync::Arc;

use async_trait::async_trait;
use rusqlite::OptionalExtension;

use rustvello_core::error::{RustvelloError, RustvelloResult, TaskError};
use rustvello_core::state_backend::StateBackendCore;

use rustvello_proto::call::{CallDTO, SerializedArguments};
use rustvello_proto::identifiers::{CallId, InvocationId, RunnerId, TaskId};
use rustvello_proto::invocation::{InvocationDTO, InvocationHistory, WorkflowIdentity};
use rustvello_proto::status::{InvocationStatus, InvocationStatusRecord};

use crate::db::{blocking, lock_err, parse_status, parse_timestamp, sql_err};

use super::SqliteStateBackend;
16
17#[async_trait]
18impl StateBackendCore for SqliteStateBackend {
19 async fn upsert_invocation(
20 &self,
21 invocation: &InvocationDTO,
22 call: &CallDTO,
23 ) -> RustvelloResult<()> {
24 let db = Arc::clone(&self.db);
25 let invocation = invocation.clone();
26 let call = call.clone();
27 blocking(move || {
28
29 let conn = db.conn.lock().map_err(lock_err)?;
30
31 let tx = conn.unchecked_transaction().map_err(sql_err)?;
32
33 let args_json = serde_json::to_string(&call.serialized_arguments.0)
34 .map_err(|e| RustvelloError::Serialization { message: e.to_string() })?;
35
36 let (parent_inv_id, wf_id, wf_type, wf_depth) = match &invocation.workflow {
37 Some(wf) => (
38 invocation
39 .parent_invocation_id
40 .as_ref()
41 .map(|id| id.as_str().to_owned()),
42 Some(wf.workflow_id.as_str().to_owned()),
43 Some(wf.workflow_type.to_string()),
44 Some(wf.depth as i64),
45 ),
46 None => (
47 invocation
48 .parent_invocation_id
49 .as_ref()
50 .map(|id| id.as_str().to_owned()),
51 None,
52 None,
53 None,
54 ),
55 };
56
57 tx.execute(
58 "INSERT OR REPLACE INTO invocations (invocation_id, task_id, call_id, status, created_at, updated_at, parent_invocation_id, workflow_id, workflow_type, workflow_depth)
59 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10)",
60 rusqlite::params![
61 &invocation.invocation_id.as_str(),
62 &invocation.task_id.to_string(),
63 &invocation.call_id.to_string(),
64 &invocation.status.to_string(),
65 &invocation.created_at.to_rfc3339(),
66 &invocation.updated_at.to_rfc3339(),
67 &parent_inv_id,
68 &wf_id,
69 &wf_type,
70 &wf_depth,
71 ],
72 )
73 .map_err(sql_err)?;
74
75 tx.execute(
76 "INSERT OR REPLACE INTO calls (call_id, task_id, serialized_arguments) VALUES (?1, ?2, ?3)",
77 rusqlite::params![
78 &call.call_id.to_string(),
79 &call.task_id.to_string(),
80 &args_json,
81 ],
82 )
83 .map_err(sql_err)?;
84
85 tx.commit().map_err(sql_err)?;
86
87 Ok(())
88
89 })
90 .await
91 }
92
93 async fn get_invocation(&self, invocation_id: &InvocationId) -> RustvelloResult<InvocationDTO> {
94 let db = Arc::clone(&self.db);
95 let invocation_id = invocation_id.clone();
96 blocking(move || {
97
98 let conn = db.conn.lock().map_err(lock_err)?;
99
100 let (task_id_str, call_id_str, status_str, created_str, updated_str, parent_inv_id, wf_id, wf_type, wf_depth): (
101 String,
102 String,
103 String,
104 String,
105 String,
106 Option<String>,
107 Option<String>,
108 Option<String>,
109 Option<i64>,
110 ) = conn
111 .query_row(
112 "SELECT task_id, call_id, status, created_at, updated_at, parent_invocation_id, workflow_id, workflow_type, workflow_depth FROM invocations WHERE invocation_id = ?1",
113 [invocation_id.as_str()],
114 |row| Ok((row.get(0)?, row.get(1)?, row.get(2)?, row.get(3)?, row.get(4)?, row.get(5)?, row.get(6)?, row.get(7)?, row.get(8)?)),
115 )
116 .map_err(|_| RustvelloError::InvocationNotFound { invocation_id: invocation_id.clone() })?;
117
118 let task_id: TaskId = task_id_str
119 .parse()
120 .map_err(|e| RustvelloError::state_backend(format!("invalid task_id in database: {e}")))?;
121
122 let args_id = call_id_str
123 .rsplit_once(':')
124 .map_or(call_id_str.as_str(), |(_, a)| a);
125 let call_id = CallId::new(task_id.clone(), args_id);
126
127 let created_at = parse_timestamp(&created_str)?;
128 let updated_at = parse_timestamp(&updated_str)?;
129
130 let parent_invocation_id = parent_inv_id.map(InvocationId::from_string);
131
132 let workflow = match (wf_id, wf_type) {
133 (Some(wf_id_str), Some(wf_type_str)) => {
134 let wf_task_id: TaskId = wf_type_str.parse().map_err(|e| {
135 RustvelloError::state_backend(format!("invalid workflow task_id in database: {e}"))
136 })?;
137 Some(WorkflowIdentity {
138 workflow_id: InvocationId::from_string(wf_id_str),
139 workflow_type: wf_task_id,
140 parent_id: None,
141 depth: u32::try_from(wf_depth.unwrap_or(0)).unwrap_or(0),
142 })
143 }
144 _ => None,
145 };
146
147 Ok(InvocationDTO {
148 invocation_id: invocation_id.clone(),
149 task_id,
150 call_id,
151 status: parse_status(&status_str)?,
152 created_at,
153 updated_at,
154 parent_invocation_id,
155 workflow,
156 })
157
158 })
159 .await
160 }
161
162 async fn get_call(&self, call_id: &CallId) -> RustvelloResult<CallDTO> {
163 let db = Arc::clone(&self.db);
164 let call_id = call_id.clone();
165 blocking(move || {
166 let conn = db.conn.lock().map_err(lock_err)?;
167 let call_id_str = call_id.to_string();
168
169 let (task_id_str, args_json): (String, String) = conn
170 .query_row(
171 "SELECT task_id, serialized_arguments FROM calls WHERE call_id = ?1",
172 [&call_id_str],
173 |row| Ok((row.get(0)?, row.get(1)?)),
174 )
175 .map_err(|_| {
176 RustvelloError::state_backend(format!("call not found: {}", call_id_str))
177 })?;
178
179 let task_id: TaskId = task_id_str.parse().map_err(|e| {
180 RustvelloError::state_backend(format!("invalid task_id in database: {e}"))
181 })?;
182
183 let args_map: std::collections::BTreeMap<String, String> =
184 serde_json::from_str(&args_json).map_err(|e| RustvelloError::Serialization {
185 message: e.to_string(),
186 })?;
187
188 let args = SerializedArguments(args_map);
189
190 Ok(CallDTO {
191 call_id: call_id.clone(),
192 task_id,
193 serialized_arguments: args,
194 })
195 })
196 .await
197 }
198
199 async fn store_result(
200 &self,
201 invocation_id: &InvocationId,
202 result: &str,
203 ) -> RustvelloResult<()> {
204 let db = Arc::clone(&self.db);
205 let invocation_id = invocation_id.clone();
206 let result = result.to_owned();
207 blocking(move || {
208 let conn = db.conn.lock().map_err(lock_err)?;
209 conn.execute(
210 "INSERT OR REPLACE INTO results (invocation_id, result) VALUES (?1, ?2)",
211 rusqlite::params![invocation_id.as_str(), result],
212 )
213 .map_err(sql_err)?;
214 Ok(())
215 })
216 .await
217 }
218
219 async fn get_result(&self, invocation_id: &InvocationId) -> RustvelloResult<Option<String>> {
220 let db = Arc::clone(&self.db);
221 let invocation_id = invocation_id.clone();
222 blocking(move || {
223 let conn = db.conn.lock().map_err(lock_err)?;
224 let result: Option<String> = conn
225 .query_row(
226 "SELECT result FROM results WHERE invocation_id = ?1",
227 [invocation_id.as_str()],
228 |row| row.get(0),
229 )
230 .ok();
231 Ok(result)
232 })
233 .await
234 }
235
236 async fn store_error(
237 &self,
238 invocation_id: &InvocationId,
239 error: &TaskError,
240 ) -> RustvelloResult<()> {
241 let db = Arc::clone(&self.db);
242 let invocation_id = invocation_id.clone();
243 let error = error.clone();
244 blocking(move || {
245
246 let conn = db.conn.lock().map_err(lock_err)?;
247 conn.execute(
248 "INSERT OR REPLACE INTO errors (invocation_id, error_type, message, traceback) VALUES (?1, ?2, ?3, ?4)",
249 rusqlite::params![
250 invocation_id.as_str(),
251 &error.error_type,
252 &error.message,
253 &error.traceback,
254 ],
255 )
256 .map_err(sql_err)?;
257 Ok(())
258
259 })
260 .await
261 }
262
263 async fn get_error(&self, invocation_id: &InvocationId) -> RustvelloResult<Option<TaskError>> {
264 let db = Arc::clone(&self.db);
265 let invocation_id = invocation_id.clone();
266 blocking(move || {
267 let conn = db.conn.lock().map_err(lock_err)?;
268 let result: Option<(String, String, Option<String>)> = conn
269 .query_row(
270 "SELECT error_type, message, traceback FROM errors WHERE invocation_id = ?1",
271 [invocation_id.as_str()],
272 |row| Ok((row.get(0)?, row.get(1)?, row.get(2)?)),
273 )
274 .ok();
275
276 Ok(result.map(|(error_type, message, traceback)| TaskError {
277 error_type,
278 message,
279 traceback,
280 }))
281 })
282 .await
283 }
284
285 async fn add_history(&self, history: &InvocationHistory) -> RustvelloResult<()> {
286 let db = Arc::clone(&self.db);
287 let history = history.clone();
288 blocking(move || {
289
290 let conn = db.conn.lock().map_err(lock_err)?;
291 let hist_ts = history.history_timestamp.map(|ts| ts.to_rfc3339());
292 conn.execute(
293 "INSERT INTO history (invocation_id, status, runner_id, timestamp, message, history_timestamp) VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
294 rusqlite::params![
295 &history.invocation_id.as_str(),
296 &history.status_record.status.to_string(),
297 &history.status_record.runner_id.as_ref().map(|r| r.as_str().to_string()),
298 &history.status_record.timestamp.to_rfc3339(),
299 &history.message,
300 &hist_ts,
301 ],
302 )
303 .map_err(sql_err)?;
304 Ok(())
305
306 })
307 .await
308 }
309
310 async fn get_history(
311 &self,
312 invocation_id: &InvocationId,
313 ) -> RustvelloResult<Vec<InvocationHistory>> {
314 let db = Arc::clone(&self.db);
315 let invocation_id = invocation_id.clone();
316 blocking(move || {
317
318 let conn = db.conn.lock().map_err(lock_err)?;
319
320 let mut stmt = conn
321 .prepare(
322 "SELECT status, runner_id, timestamp, message, history_timestamp FROM history WHERE invocation_id = ?1 ORDER BY id",
323 )
324 .map_err(sql_err)?;
325
326 let histories: Vec<InvocationHistory> = stmt
327 .query_map([invocation_id.as_str()], |row| {
328 let status_str: String = row.get(0)?;
329 let runner_id: Option<String> = row.get(1)?;
330 let timestamp_str: String = row.get(2)?;
331 let message: Option<String> = row.get(3)?;
332 let hist_ts_str: Option<String> = row.get(4)?;
333
334 let timestamp = chrono::DateTime::parse_from_rfc3339(×tamp_str)
335 .map(|dt| dt.with_timezone(&chrono::Utc))
336 .map_err(|e| {
337 rusqlite::Error::FromSqlConversionFailure(
338 2,
339 rusqlite::types::Type::Text,
340 Box::new(std::io::Error::new(std::io::ErrorKind::InvalidData, e.to_string())),
341 )
342 })?;
343
344 let history_timestamp = hist_ts_str
345 .and_then(|s| chrono::DateTime::parse_from_rfc3339(&s).ok())
346 .map(|dt| dt.with_timezone(&chrono::Utc));
347
348 let status = status_str.parse::<InvocationStatus>().map_err(|e| {
349 rusqlite::Error::FromSqlConversionFailure(
350 0,
351 rusqlite::types::Type::Text,
352 Box::new(std::io::Error::new(std::io::ErrorKind::InvalidData, e)),
353 )
354 })?;
355
356 Ok(InvocationHistory {
357 invocation_id: invocation_id.clone(),
358 status_record: InvocationStatusRecord {
359 status,
360 runner_id: runner_id.clone().map(RunnerId::from_string),
361 timestamp,
362 },
363 message,
364 runner_id: runner_id.map(RunnerId::from_string),
365 registered_by_inv_id: None,
366 history_timestamp,
367 })
368 })
369 .map_err(sql_err)?
370 .collect::<Result<Vec<_>, _>>()
371 .map_err(sql_err)?;
372
373 Ok(histories)
374
375 })
376 .await
377 }
378
379 async fn purge(&self) -> RustvelloResult<()> {
380 let db = Arc::clone(&self.db);
381 blocking(move || {
382 let conn = db.conn.lock().map_err(lock_err)?;
383 conn.execute_batch(
384 "DELETE FROM invocations;
385 DELETE FROM calls;
386 DELETE FROM results;
387 DELETE FROM errors;
388 DELETE FROM history;
389 DELETE FROM status_records;
390 DELETE FROM waiting_for;
391 DELETE FROM broker_queue;
392 DELETE FROM workflow_runs;
393 DELETE FROM workflow_data;
394 DELETE FROM app_infos;
395 DELETE FROM workflow_sub_invocations;
396 DELETE FROM runner_contexts;",
397 )
398 .map_err(sql_err)?;
399 Ok(())
400 })
401 .await
402 }
403}