Skip to main content

rustvello_sqlite/state_backend/
runner.rs

1use std::sync::Arc;
2
3use async_trait::async_trait;
4
5use rustvello_core::error::RustvelloResult;
6use rustvello_core::state_backend::{StateBackendRunner, StoredRunnerContext};
7
8use rustvello_proto::identifiers::{InvocationId, RunnerId};
9use rustvello_proto::invocation::InvocationHistory;
10use rustvello_proto::status::InvocationStatusRecord;
11
12use crate::db::{blocking, lock_err, parse_status, parse_timestamp, sql_err};
13
14use super::SqliteStateBackend;
15
16#[async_trait]
17impl StateBackendRunner for SqliteStateBackend {
18    async fn store_runner_context(&self, context: &StoredRunnerContext) -> RustvelloResult<()> {
19        let db = Arc::clone(&self.db);
20        let context = context.clone();
21        blocking(move || {
22
23            let conn = db.conn.lock().map_err(lock_err)?;
24            conn.execute(
25                "INSERT OR REPLACE INTO runner_contexts
26                 (runner_id, runner_cls, pid, hostname, thread_id, started_at, parent_runner_id, parent_runner_cls)
27                 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)",
28                rusqlite::params![
29                    &context.runner_id,
30                    &context.runner_cls,
31                    context.pid as i64,
32                    &context.hostname,
33                    context.thread_id as i64,
34                    context.started_at.to_rfc3339(),
35                    &context.parent_runner_id,
36                    &context.parent_runner_cls,
37                ],
38            )
39            .map_err(sql_err)?;
40            Ok(())
41
42        })
43        .await
44    }
45
46    async fn get_runner_context(
47        &self,
48        runner_id: &str,
49    ) -> RustvelloResult<Option<StoredRunnerContext>> {
50        let db = Arc::clone(&self.db);
51        let runner_id = runner_id.to_owned();
52        blocking(move || {
53            let conn = db.conn.lock().map_err(lock_err)?;
54            let result = conn
55                .query_row(
56                    "SELECT runner_id, runner_cls, pid, hostname, thread_id, started_at,
57                            parent_runner_id, parent_runner_cls
58                     FROM runner_contexts WHERE runner_id = ?1",
59                    rusqlite::params![runner_id],
60                    parse_runner_context,
61                )
62                .ok();
63            Ok(result)
64        })
65        .await
66    }
67
68    async fn get_runner_contexts_by_parent(
69        &self,
70        parent_runner_id: &str,
71    ) -> RustvelloResult<Vec<StoredRunnerContext>> {
72        let db = Arc::clone(&self.db);
73        let parent_runner_id = parent_runner_id.to_owned();
74        blocking(move || {
75            let conn = db.conn.lock().map_err(lock_err)?;
76            let mut stmt = conn
77                .prepare(
78                    "SELECT runner_id, runner_cls, pid, hostname, thread_id, started_at,
79                            parent_runner_id, parent_runner_cls
80                     FROM runner_contexts WHERE parent_runner_id = ?1",
81                )
82                .map_err(sql_err)?;
83            let contexts = stmt
84                .query_map([parent_runner_id], parse_runner_context)
85                .map_err(sql_err)?
86                .collect::<Result<Vec<_>, _>>()
87                .map_err(sql_err)?;
88            Ok(contexts)
89        })
90        .await
91    }
92
93    async fn get_invocation_ids_by_runner(
94        &self,
95        runner_id: &str,
96        limit: usize,
97        offset: usize,
98    ) -> RustvelloResult<Vec<InvocationId>> {
99        let db = Arc::clone(&self.db);
100        let runner_id = runner_id.to_owned();
101        blocking(move || {
102            let conn = db.conn.lock().map_err(lock_err)?;
103            let sql = if limit > 0 {
104                "SELECT DISTINCT invocation_id FROM history WHERE runner_id = ?1 LIMIT ?2 OFFSET ?3"
105            } else {
106                "SELECT DISTINCT invocation_id FROM history WHERE runner_id = ?1 LIMIT -1 OFFSET ?3"
107            };
108            let mut stmt = conn.prepare(sql).map_err(sql_err)?;
109            let ids = stmt
110                .query_map(
111                    rusqlite::params![runner_id, limit as i64, offset as i64],
112                    |row| {
113                        let id: String = row.get(0)?;
114                        Ok(InvocationId::from_string(id))
115                    },
116                )
117                .map_err(sql_err)?
118                .collect::<Result<Vec<_>, _>>()
119                .map_err(sql_err)?;
120            Ok(ids)
121        })
122        .await
123    }
124
125    async fn count_invocations_by_runner(&self, runner_id: &str) -> RustvelloResult<usize> {
126        let db = Arc::clone(&self.db);
127        let runner_id = runner_id.to_owned();
128        blocking(move || {
129            let conn = db.conn.lock().map_err(lock_err)?;
130            let count: i64 = conn
131                .query_row(
132                    "SELECT COUNT(DISTINCT invocation_id) FROM history WHERE runner_id = ?1",
133                    rusqlite::params![runner_id],
134                    |row| row.get(0),
135                )
136                .map_err(sql_err)?;
137            Ok(count as usize)
138        })
139        .await
140    }
141
142    async fn get_history_in_timerange(
143        &self,
144        start: chrono::DateTime<chrono::Utc>,
145        end: chrono::DateTime<chrono::Utc>,
146        limit: usize,
147        offset: usize,
148    ) -> RustvelloResult<Vec<InvocationHistory>> {
149        let db = Arc::clone(&self.db);
150        blocking(move || {
151            let conn = db.conn.lock().map_err(lock_err)?;
152            let start_str = start.to_rfc3339();
153            let end_str = end.to_rfc3339();
154            let effective_ts = "COALESCE(history_timestamp, timestamp)";
155            let sql = format!(
156                "SELECT invocation_id, status, runner_id, timestamp, message, history_timestamp
157                 FROM history WHERE {effective_ts} >= ?1 AND {effective_ts} <= ?2
158                 ORDER BY {effective_ts} ASC LIMIT ?3 OFFSET ?4"
159            );
160            let limit_val: i64 = if limit > 0 { limit as i64 } else { -1 };
161            let mut stmt = conn.prepare(&sql).map_err(sql_err)?;
162            let histories = stmt
163                .query_map(
164                    rusqlite::params![&start_str, &end_str, limit_val, offset as i64],
165                    |row| {
166                        let inv_id: String = row.get(0)?;
167                        let status_str: String = row.get(1)?;
168                        let runner_id: Option<String> = row.get(2)?;
169                        let ts_str: String = row.get(3)?;
170                        let message: Option<String> = row.get(4)?;
171                        let hist_ts_str: Option<String> = row.get(5)?;
172                        Ok((inv_id, status_str, runner_id, ts_str, message, hist_ts_str))
173                    },
174                )
175                .map_err(sql_err)?
176                .collect::<Result<Vec<_>, _>>()
177                .map_err(sql_err)?
178                .into_iter()
179                .map(
180                    |(inv_id, status_str, runner_id, ts_str, message, hist_ts_str)| {
181                        let status = parse_status(&status_str)?;
182                        let timestamp = parse_timestamp(&ts_str)?;
183                        let history_timestamp =
184                            hist_ts_str.map(|s| parse_timestamp(&s)).transpose()?;
185                        Ok(InvocationHistory {
186                            invocation_id: InvocationId::from_string(inv_id),
187                            status_record: InvocationStatusRecord {
188                                status,
189                                runner_id: runner_id.clone().map(RunnerId::from_string),
190                                timestamp,
191                            },
192                            message,
193                            runner_id: runner_id.map(RunnerId::from_string),
194                            registered_by_inv_id: None,
195                            history_timestamp,
196                        })
197                    },
198                )
199                .collect::<RustvelloResult<Vec<_>>>()?;
200            Ok(histories)
201        })
202        .await
203    }
204
205    async fn get_matching_runner_contexts(
206        &self,
207        partial_id: &str,
208    ) -> RustvelloResult<Vec<StoredRunnerContext>> {
209        let db = Arc::clone(&self.db);
210        let partial_id = partial_id.to_owned();
211        blocking(move || {
212            let conn = db.conn.lock().map_err(lock_err)?;
213            let pattern = format!("%{partial_id}%");
214            let mut stmt = conn
215                .prepare(
216                    "SELECT runner_id, runner_cls, pid, hostname, thread_id, started_at,
217                            parent_runner_id, parent_runner_cls
218                     FROM runner_contexts WHERE runner_id LIKE ?1",
219                )
220                .map_err(sql_err)?;
221            let contexts = stmt
222                .query_map([&pattern], parse_runner_context)
223                .map_err(sql_err)?
224                .collect::<Result<Vec<_>, _>>()
225                .map_err(sql_err)?;
226            Ok(contexts)
227        })
228        .await
229    }
230}
231
232/// Parse a `runner_contexts` row into a `StoredRunnerContext`.
233fn parse_runner_context(row: &rusqlite::Row<'_>) -> rusqlite::Result<StoredRunnerContext> {
234    let runner_id: String = row.get(0)?;
235    let runner_cls: String = row.get(1)?;
236    let pid: i64 = row.get(2)?;
237    let hostname: String = row.get(3)?;
238    let thread_id: i64 = row.get(4)?;
239    let started_at_str: String = row.get(5)?;
240    let parent_runner_id: Option<String> = row.get(6)?;
241    let parent_runner_cls: Option<String> = row.get(7)?;
242
243    let started_at = chrono::DateTime::parse_from_rfc3339(&started_at_str)
244        .map_or_else(|_| chrono::Utc::now(), |dt| dt.with_timezone(&chrono::Utc));
245
246    Ok(StoredRunnerContext {
247        runner_id,
248        runner_cls,
249        pid: u32::try_from(pid).unwrap_or(0),
250        hostname,
251        thread_id: u64::try_from(thread_id).unwrap_or(0),
252        started_at,
253        parent_runner_id,
254        parent_runner_cls,
255    })
256}