//! rustvello_sqlite/orchestrator/status.rs — `OrchestratorStatus` implementation for `SqliteOrchestrator`.

use std::sync::Arc;

use async_trait::async_trait;
use chrono::Utc;
use rusqlite::OptionalExtension;

use rustvello_core::error::{RustvelloError, RustvelloResult};
use rustvello_core::orchestrator::OrchestratorStatus;
use rustvello_proto::call::CallDTO;
use rustvello_proto::identifiers::{InvocationId, RunnerId};
use rustvello_proto::status::{InvocationStatus, InvocationStatusRecord};

use crate::db::{blocking, lock_err, parse_status, parse_timestamp, sql_err};

use super::SqliteOrchestrator;

16#[async_trait]
17impl OrchestratorStatus for SqliteOrchestrator {
18    async fn register_invocation(&self, call: &CallDTO) -> RustvelloResult<InvocationId> {
19        let invocation_id = InvocationId::new();
20        self.register_invocation_with_id(&invocation_id, call, None)
21            .await?;
22        Ok(invocation_id)
23    }
24
25    async fn register_invocation_with_id(
26        &self,
27        invocation_id: &InvocationId,
28        call: &CallDTO,
29        runner_id: Option<&RunnerId>,
30    ) -> RustvelloResult<InvocationStatusRecord> {
31        let db = Arc::clone(&self.db);
32        let invocation_id = invocation_id.clone();
33        let call = call.clone();
34        let runner_id = runner_id.cloned();
35        blocking(move || {
36
37            let now = Utc::now();
38            let now_str = now.to_rfc3339();
39            let status = InvocationStatus::Registered;
40            let status_str = status.to_string();
41            let task_id_str = call.task_id.to_string();
42            let call_id_str = call.call_id.to_string();
43            let runner_id_str = runner_id.as_ref().map(|r| r.as_str().to_owned());
44
45            let conn = db.conn.lock().map_err(lock_err)?;
46            let tx = conn.unchecked_transaction().map_err(sql_err)?;
47
48            tx.execute(
49                "INSERT OR IGNORE INTO invocations (invocation_id, task_id, call_id, status, created_at, updated_at)
50                 VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
51                rusqlite::params![
52                    invocation_id.as_str(),
53                    &task_id_str,
54                    &call_id_str,
55                    &status_str,
56                    &now_str,
57                    &now_str,
58                ],
59            )
60            .map_err(sql_err)?;
61
62            tx.execute(
63                "INSERT OR REPLACE INTO status_records (invocation_id, status, runner_id, timestamp)
64                 VALUES (?1, ?2, ?3, ?4)",
65                rusqlite::params![invocation_id.as_str(), &status_str, &runner_id_str, &now_str],
66            )
67            .map_err(sql_err)?;
68
69            tx.commit().map_err(sql_err)?;
70
71            Ok(InvocationStatusRecord {
72                status,
73                runner_id,
74                timestamp: now,
75            })
76
77        })
78        .await
79    }
80
81    async fn increment_invocation_retries(
82        &self,
83        invocation_id: &InvocationId,
84    ) -> RustvelloResult<u32> {
85        let db = Arc::clone(&self.db);
86        let invocation_id = invocation_id.clone();
87        blocking(move || {
88            let conn = db.conn.lock().map_err(lock_err)?;
89            let tx = conn.unchecked_transaction().map_err(sql_err)?;
90            tx.execute(
91                "INSERT INTO retries (invocation_id, retry_count) VALUES (?1, 1)
92                 ON CONFLICT(invocation_id) DO UPDATE SET retry_count = retry_count + 1",
93                [invocation_id.as_str()],
94            )
95            .map_err(sql_err)?;
96            let count: u32 = tx
97                .query_row(
98                    "SELECT retry_count FROM retries WHERE invocation_id = ?1",
99                    [invocation_id.as_str()],
100                    |row| row.get(0),
101                )
102                .map_err(sql_err)?;
103            tx.commit().map_err(sql_err)?;
104            Ok(count)
105        })
106        .await
107    }
108
109    async fn get_invocation_retries(&self, invocation_id: &InvocationId) -> RustvelloResult<u32> {
110        let db = Arc::clone(&self.db);
111        let invocation_id = invocation_id.clone();
112        blocking(move || {
113            let conn = db.conn.lock().map_err(lock_err)?;
114            let count: u32 = conn
115                .query_row(
116                    "SELECT retry_count FROM retries WHERE invocation_id = ?1",
117                    [invocation_id.as_str()],
118                    |row| row.get(0),
119                )
120                .unwrap_or(0);
121            Ok(count)
122        })
123        .await
124    }
125
126    async fn remove_invocation(&self, invocation_id: &InvocationId) -> RustvelloResult<()> {
127        let db = Arc::clone(&self.db);
128        let invocation_id = invocation_id.clone();
129        blocking(move || {
130            let conn = db.conn.lock().map_err(lock_err)?;
131            let tx = conn.unchecked_transaction().map_err(sql_err)?;
132            let id = invocation_id.as_str();
133            tx.execute("DELETE FROM status_records WHERE invocation_id = ?1", [id])
134                .map_err(sql_err)?;
135            tx.execute("DELETE FROM cc_arg_pairs WHERE invocation_id = ?1", [id])
136                .map_err(sql_err)?;
137            tx.execute(
138                "DELETE FROM waiting_for WHERE waiter_id = ?1 OR waited_on_id = ?1",
139                [id],
140            )
141            .map_err(sql_err)?;
142            tx.execute("DELETE FROM retries WHERE invocation_id = ?1", [id])
143                .map_err(sql_err)?;
144            tx.execute("DELETE FROM invocations WHERE invocation_id = ?1", [id])
145                .map_err(sql_err)?;
146            tx.execute(
147                "DELETE FROM auto_purge_schedule WHERE invocation_id = ?1",
148                [id],
149            )
150            .map_err(sql_err)?;
151            tx.commit().map_err(sql_err)?;
152            Ok(())
153        })
154        .await
155    }
156
157    async fn get_invocation_status(
158        &self,
159        invocation_id: &InvocationId,
160    ) -> RustvelloResult<InvocationStatusRecord> {
161        let db = Arc::clone(&self.db);
162        let invocation_id = invocation_id.clone();
163        blocking(move || {
164
165            let conn = db.conn.lock().map_err(lock_err)?;
166
167            let (status_str, runner_id_opt, timestamp_str): (String, Option<String>, String) = conn
168                .query_row(
169                    "SELECT status, runner_id, timestamp FROM status_records WHERE invocation_id = ?1",
170                    [invocation_id.as_str()],
171                    |row| Ok((row.get(0)?, row.get(1)?, row.get(2)?)),
172                )
173                .map_err(|_| RustvelloError::InvocationNotFound {
174                    invocation_id: invocation_id.clone(),
175                })?;
176
177            Ok(InvocationStatusRecord {
178                status: parse_status(&status_str)?,
179                runner_id: runner_id_opt.map(RunnerId::from_string),
180                timestamp: parse_timestamp(&timestamp_str)?,
181            })
182
183        })
184        .await
185    }
186
187    async fn set_invocation_status(
188        &self,
189        invocation_id: &InvocationId,
190        status: InvocationStatus,
191        runner_id: Option<&RunnerId>,
192    ) -> RustvelloResult<InvocationStatusRecord> {
193        let db = Arc::clone(&self.db);
194        let invocation_id = invocation_id.clone();
195        let runner_id = runner_id.cloned();
196        blocking(move || {
197
198            use rustvello_proto::status::status_record_transition;
199
200            let conn = db.conn.lock().map_err(lock_err)?;
201
202            let tx = conn.unchecked_transaction().map_err(sql_err)?;
203
204            let (current_status_str, current_runner_id_str, current_ts_str): (
205                String,
206                Option<String>,
207                String,
208            ) = tx
209                .query_row(
210                    "SELECT status, runner_id, timestamp FROM status_records WHERE invocation_id = ?1",
211                    [invocation_id.as_str()],
212                    |row| Ok((row.get(0)?, row.get(1)?, row.get(2)?)),
213                )
214                .map_err(|_| RustvelloError::InvocationNotFound {
215                    invocation_id: invocation_id.clone(),
216                })?;
217            let current_status = parse_status(&current_status_str)?;
218            let current_record = InvocationStatusRecord {
219                status: current_status,
220                runner_id: current_runner_id_str.map(RunnerId::from_string),
221                timestamp: chrono::DateTime::parse_from_rfc3339(&current_ts_str)
222                    .map_or_else(|_| Utc::now(), |dt| dt.with_timezone(&Utc)),
223            };
224
225            let new_record = status_record_transition(Some(&current_record), status, runner_id.as_ref())
226                .map_err(|e| {
227                    rustvello_core::error::status_machine_error_to_rustvello(
228                        e,
229                        &invocation_id,
230                        current_status,
231                    )
232                })?;
233
234            let now_str = new_record.timestamp.to_rfc3339();
235            let status_str = status.to_string();
236            let runner_id_str = new_record.runner_id.as_ref().map(|r| r.as_str().to_owned());
237
238            tx.execute(
239                "UPDATE status_records SET status = ?1, runner_id = ?2, timestamp = ?3 WHERE invocation_id = ?4",
240                rusqlite::params![&status_str, &runner_id_str, &now_str, invocation_id.as_str()],
241            )
242            .map_err(sql_err)?;
243
244            tx.execute(
245                "UPDATE invocations SET status = ?1, updated_at = ?2 WHERE invocation_id = ?3",
246                rusqlite::params![&status_str, &now_str, invocation_id.as_str()],
247            )
248            .map_err(sql_err)?;
249
250            tx.commit().map_err(sql_err)?;
251
252            Ok(new_record)
253
254        })
255        .await
256    }
257
258    async fn purge(&self) -> RustvelloResult<()> {
259        let db = Arc::clone(&self.db);
260        blocking(move || {
261            let conn = db.conn.lock().map_err(lock_err)?;
262            conn.execute_batch(
263                "DELETE FROM cc_arg_pairs;
264                 DELETE FROM waiting_for;
265                 DELETE FROM status_records;
266                 DELETE FROM retries;
267                 DELETE FROM runner_heartbeats;
268                 DELETE FROM auto_purge_schedule;
269                 DELETE FROM invocations;",
270            )
271            .map_err(sql_err)?;
272            Ok(())
273        })
274        .await
275    }
276
277    async fn schedule_auto_purge(&self, invocation_id: &InvocationId) -> RustvelloResult<()> {
278        let db = Arc::clone(&self.db);
279        let invocation_id = invocation_id.clone();
280        blocking(move || {
281
282            let now_str = Utc::now().to_rfc3339();
283            let conn = db.conn.lock().map_err(lock_err)?;
284            conn.execute(
285                "INSERT OR REPLACE INTO auto_purge_schedule (invocation_id, scheduled_at) VALUES (?1, ?2)",
286                rusqlite::params![invocation_id.as_str(), &now_str],
287            )
288            .map_err(sql_err)?;
289            Ok(())
290
291        })
292        .await
293    }
294
295    async fn run_auto_purge(&self, max_age_secs: u64) -> RustvelloResult<Vec<InvocationId>> {
296        let db = Arc::clone(&self.db);
297        let expired: Vec<String> = blocking(move || {
298            let threshold = Utc::now()
299                - chrono::Duration::seconds(i64::try_from(max_age_secs).unwrap_or(i64::MAX));
300            let threshold_str = threshold.to_rfc3339();
301
302            let conn = db.conn.lock().map_err(lock_err)?;
303            let tx = conn.unchecked_transaction().map_err(sql_err)?;
304            let mut stmt = tx
305                .prepare("SELECT invocation_id FROM auto_purge_schedule WHERE scheduled_at <= ?1")
306                .map_err(sql_err)?;
307            let rows: Vec<String> = stmt
308                .query_map([&threshold_str], |row| row.get(0))
309                .map_err(sql_err)?
310                .collect::<Result<Vec<String>, _>>()
311                .map_err(sql_err)?;
312            drop(stmt);
313            tx.execute(
314                "DELETE FROM auto_purge_schedule WHERE scheduled_at <= ?1",
315                [&threshold_str],
316            )
317            .map_err(sql_err)?;
318            tx.commit().map_err(sql_err)?;
319            Ok(rows)
320        })
321        .await?;
322
323        let mut purged = Vec::new();
324        for id_str in expired {
325            let inv_id = InvocationId::from_string(id_str);
326            if self.remove_invocation(&inv_id).await.is_ok() {
327                purged.push(inv_id);
328            }
329        }
330        Ok(purged)
331    }
332}