Skip to main content

rustvello_sqlite/orchestrator/
query.rs

1use std::sync::Arc;
2
3use async_trait::async_trait;
4
5use rustvello_core::error::RustvelloResult;
6use rustvello_core::orchestrator::OrchestratorQuery;
7use rustvello_proto::call::SerializedArguments;
8use rustvello_proto::identifiers::{CallId, InvocationId, TaskId};
9use rustvello_proto::status::InvocationStatus;
10
11use crate::db::{blocking, lock_err, sql_err};
12
13use super::SqliteOrchestrator;
14
15#[async_trait]
16impl OrchestratorQuery for SqliteOrchestrator {
17    async fn get_invocations_by_task(
18        &self,
19        task_id: &TaskId,
20    ) -> RustvelloResult<Vec<InvocationId>> {
21        let db = Arc::clone(&self.db);
22        let task_id = task_id.clone();
23        blocking(move || {
24            let conn = db.conn.lock().map_err(lock_err)?;
25            let task_id_str = task_id.to_string();
26
27            let mut stmt = conn
28                .prepare("SELECT invocation_id FROM invocations WHERE task_id = ?1")
29                .map_err(sql_err)?;
30
31            let ids: Vec<InvocationId> = stmt
32                .query_map([&task_id_str], |row| {
33                    let id: String = row.get(0)?;
34                    Ok(InvocationId::from_string(id))
35                })
36                .map_err(sql_err)?
37                .collect::<Result<Vec<_>, _>>()
38                .map_err(sql_err)?;
39
40            Ok(ids)
41        })
42        .await
43    }
44
45    async fn get_invocations_by_call(
46        &self,
47        call_id: &CallId,
48    ) -> RustvelloResult<Vec<InvocationId>> {
49        let db = Arc::clone(&self.db);
50        let call_id = call_id.clone();
51        blocking(move || {
52            let conn = db.conn.lock().map_err(lock_err)?;
53            let call_id_str = call_id.to_string();
54
55            let mut stmt = conn
56                .prepare("SELECT invocation_id FROM invocations WHERE call_id = ?1")
57                .map_err(sql_err)?;
58
59            let ids: Vec<InvocationId> = stmt
60                .query_map([&call_id_str], |row| {
61                    let id: String = row.get(0)?;
62                    Ok(InvocationId::from_string(id))
63                })
64                .map_err(sql_err)?
65                .collect::<Result<Vec<_>, _>>()
66                .map_err(sql_err)?;
67
68            Ok(ids)
69        })
70        .await
71    }
72
73    async fn get_invocations_by_status(
74        &self,
75        status: InvocationStatus,
76        task_id: Option<&TaskId>,
77    ) -> RustvelloResult<Vec<InvocationId>> {
78        let db = Arc::clone(&self.db);
79        let task_id = task_id.cloned();
80        blocking(move || {
81            let conn = db.conn.lock().map_err(lock_err)?;
82            let status_str = status.to_string();
83
84            let ids: Vec<InvocationId> = if let Some(tid) = task_id {
85                let task_id_str = tid.to_string();
86                let mut stmt = conn
87                    .prepare(
88                        "SELECT invocation_id FROM invocations WHERE status = ?1 AND task_id = ?2",
89                    )
90                    .map_err(sql_err)?;
91                let result: Vec<InvocationId> = stmt
92                    .query_map(rusqlite::params![&status_str, &task_id_str], |row| {
93                        let id: String = row.get(0)?;
94                        Ok(InvocationId::from_string(id))
95                    })
96                    .map_err(sql_err)?
97                    .collect::<Result<Vec<_>, _>>()
98                    .map_err(sql_err)?;
99                result
100            } else {
101                let mut stmt = conn
102                    .prepare("SELECT invocation_id FROM invocations WHERE status = ?1")
103                    .map_err(sql_err)?;
104                let result: Vec<InvocationId> = stmt
105                    .query_map([&status_str], |row| {
106                        let id: String = row.get(0)?;
107                        Ok(InvocationId::from_string(id))
108                    })
109                    .map_err(sql_err)?
110                    .collect::<Result<Vec<_>, _>>()
111                    .map_err(sql_err)?;
112                result
113            };
114
115            Ok(ids)
116        })
117        .await
118    }
119
120    async fn count_invocations(
121        &self,
122        task_id: Option<&TaskId>,
123        statuses: Option<&[InvocationStatus]>,
124    ) -> RustvelloResult<usize> {
125        let db = Arc::clone(&self.db);
126        let task_id = task_id.cloned();
127        let statuses = statuses.map(<[InvocationStatus]>::to_vec);
128        blocking(move || {
129            let conn = db.conn.lock().map_err(lock_err)?;
130
131            let mut sql = String::from("SELECT COUNT(*) FROM status_records sr");
132            let mut params: Vec<String> = Vec::new();
133            let mut where_clauses = Vec::new();
134
135            if let Some(tid) = task_id {
136                sql.push_str(" JOIN invocations inv ON sr.invocation_id = inv.invocation_id");
137                where_clauses.push(format!("inv.task_id = ?{}", params.len() + 1));
138                params.push(tid.to_string());
139            }
140
141            if let Some(ss) = statuses {
142                if !ss.is_empty() {
143                    let placeholders: Vec<String> = (0..ss.len())
144                        .map(|i| format!("?{}", params.len() + i + 1))
145                        .collect();
146                    where_clauses.push(format!("sr.status IN ({})", placeholders.join(",")));
147                    for s in ss {
148                        params.push(s.to_string());
149                    }
150                }
151            }
152
153            if !where_clauses.is_empty() {
154                sql.push_str(" WHERE ");
155                sql.push_str(&where_clauses.join(" AND "));
156            }
157
158            let count: usize = conn
159                .query_row(&sql, rusqlite::params_from_iter(params.iter()), |row| {
160                    row.get(0)
161                })
162                .map_err(sql_err)?;
163            Ok(count)
164        })
165        .await
166    }
167
168    async fn get_invocation_ids_paginated(
169        &self,
170        task_id: Option<&TaskId>,
171        statuses: Option<&[InvocationStatus]>,
172        limit: usize,
173        offset: usize,
174    ) -> RustvelloResult<Vec<InvocationId>> {
175        let db = Arc::clone(&self.db);
176        let task_id = task_id.cloned();
177        let statuses = statuses.map(<[InvocationStatus]>::to_vec);
178        blocking(move || {
179            let conn = db.conn.lock().map_err(lock_err)?;
180            let mut sql = String::from("SELECT sr.invocation_id FROM status_records sr");
181            let mut params: Vec<String> = Vec::new();
182            let mut where_clauses = Vec::new();
183
184            if let Some(tid) = task_id {
185                sql.push_str(" JOIN invocations inv ON sr.invocation_id = inv.invocation_id");
186                where_clauses.push(format!("inv.task_id = ?{}", params.len() + 1));
187                params.push(tid.to_string());
188            }
189
190            if let Some(ss) = statuses {
191                if !ss.is_empty() {
192                    let placeholders: Vec<String> = (0..ss.len())
193                        .map(|i| format!("?{}", params.len() + i + 1))
194                        .collect();
195                    where_clauses.push(format!("sr.status IN ({})", placeholders.join(",")));
196                    for s in ss {
197                        params.push(s.to_string());
198                    }
199                }
200            }
201
202            if !where_clauses.is_empty() {
203                sql.push_str(" WHERE ");
204                sql.push_str(&where_clauses.join(" AND "));
205            }
206
207            sql.push_str(&format!(
208                " LIMIT ?{} OFFSET ?{}",
209                params.len() + 1,
210                params.len() + 2
211            ));
212            params.push(limit.to_string());
213            params.push(offset.to_string());
214
215            let mut stmt = conn.prepare(&sql).map_err(sql_err)?;
216            let ids: Vec<InvocationId> = stmt
217                .query_map(rusqlite::params_from_iter(params.iter()), |row| {
218                    let id: String = row.get(0)?;
219                    Ok(InvocationId::from_string(id))
220                })
221                .map_err(sql_err)?
222                .collect::<Result<Vec<_>, _>>()
223                .map_err(sql_err)?;
224
225            Ok(ids)
226        })
227        .await
228    }
229
230    async fn get_blocking_invocations(&self, max_num: usize) -> RustvelloResult<Vec<InvocationId>> {
231        let db = Arc::clone(&self.db);
232        blocking(move || {
233            let conn = db.conn.lock().map_err(lock_err)?;
234            let mut stmt = conn
235                .prepare(
236                    "SELECT DISTINCT wf.waited_on_id FROM waiting_for wf
237                     JOIN status_records sr ON wf.waited_on_id = sr.invocation_id
238                     WHERE sr.status IN ('REGISTERED', 'PENDING', 'RUNNING')
239                       AND NOT EXISTS (
240                           SELECT 1 FROM waiting_for wf2
241                           WHERE wf2.waiter_id = wf.waited_on_id
242                       )
243                     LIMIT ?1",
244                )
245                .map_err(sql_err)?;
246            let ids: Vec<InvocationId> = stmt
247                .query_map([max_num as i64], |row| {
248                    let id: String = row.get(0)?;
249                    Ok(InvocationId::from_string(id))
250                })
251                .map_err(sql_err)?
252                .collect::<Result<Vec<_>, _>>()
253                .map_err(sql_err)?;
254            Ok(ids)
255        })
256        .await
257    }
258
259    async fn get_existing_invocations(
260        &self,
261        task_id: &TaskId,
262        cc_args: Option<&SerializedArguments>,
263        statuses: &[InvocationStatus],
264    ) -> RustvelloResult<Vec<InvocationId>> {
265        let db = Arc::clone(&self.db);
266        let task_id = task_id.clone();
267        let cc_args = cc_args.cloned();
268        let statuses = statuses.to_vec();
269        blocking(move || {
270            let conn = db.conn.lock().map_err(lock_err)?;
271            let task_key = task_id.to_string();
272
273            let mut params: Vec<String> = statuses
274                .iter()
275                .map(std::string::ToString::to_string)
276                .collect();
277            let status_clause = if statuses.is_empty() {
278                String::new()
279            } else {
280                let placeholders: Vec<String> =
281                    (0..statuses.len()).map(|i| format!("?{}", i + 1)).collect();
282                format!(" AND i.status IN ({})", placeholders.join(","))
283            };
284
285            let sql = match cc_args {
286                Some(ref args) => {
287                    let pairs = args.cc_arg_pairs();
288                    let n_pairs = pairs.len();
289                    let task_idx = params.len() + 1;
290                    params.push(task_key);
291                    let mut pair_conds = Vec::with_capacity(n_pairs);
292                    for (k, v) in &pairs {
293                        let ki = params.len() + 1;
294                        let vi = params.len() + 2;
295                        params.push(k.clone());
296                        params.push(v.clone());
297                        pair_conds.push(format!("(cp.arg_key = ?{ki} AND cp.arg_value = ?{vi})"));
298                    }
299                    let where_pairs = pair_conds.join(" OR ");
300                    format!(
301                        "SELECT cp.invocation_id FROM cc_arg_pairs cp
302                         JOIN invocations i ON cp.invocation_id = i.invocation_id
303                         WHERE cp.task_id = ?{task_idx} AND ({where_pairs}){status_clause}
304                         GROUP BY cp.invocation_id
305                         HAVING COUNT(*) = {n_pairs}"
306                    )
307                }
308                None => {
309                    let task_idx = params.len() + 1;
310                    params.push(task_key);
311                    if statuses.is_empty() {
312                        format!(
313                            "SELECT invocation_id FROM invocations
314                             WHERE task_id = ?{task_idx}"
315                        )
316                    } else {
317                        format!(
318                            "SELECT invocation_id FROM invocations i
319                             WHERE i.task_id = ?{task_idx}{status_clause}"
320                        )
321                    }
322                }
323            };
324
325            let mut stmt = conn.prepare(&sql).map_err(sql_err)?;
326            let ids: Vec<InvocationId> = stmt
327                .query_map(rusqlite::params_from_iter(params.iter()), |row| {
328                    let id: String = row.get(0)?;
329                    Ok(InvocationId::from_string(id))
330                })
331                .map_err(sql_err)?
332                .collect::<Result<Vec<_>, _>>()
333                .map_err(sql_err)?;
334
335            Ok(ids)
336        })
337        .await
338    }
339}