rustvello_sqlite/orchestrator/
query.rs1use std::sync::Arc;
2
3use async_trait::async_trait;
4
5use rustvello_core::error::RustvelloResult;
6use rustvello_core::orchestrator::OrchestratorQuery;
7use rustvello_proto::call::SerializedArguments;
8use rustvello_proto::identifiers::{CallId, InvocationId, TaskId};
9use rustvello_proto::status::InvocationStatus;
10
11use crate::db::{blocking, lock_err, sql_err};
12
13use super::SqliteOrchestrator;
14
15#[async_trait]
16impl OrchestratorQuery for SqliteOrchestrator {
17 async fn get_invocations_by_task(
18 &self,
19 task_id: &TaskId,
20 ) -> RustvelloResult<Vec<InvocationId>> {
21 let db = Arc::clone(&self.db);
22 let task_id = task_id.clone();
23 blocking(move || {
24 let conn = db.conn.lock().map_err(lock_err)?;
25 let task_id_str = task_id.to_string();
26
27 let mut stmt = conn
28 .prepare("SELECT invocation_id FROM invocations WHERE task_id = ?1")
29 .map_err(sql_err)?;
30
31 let ids: Vec<InvocationId> = stmt
32 .query_map([&task_id_str], |row| {
33 let id: String = row.get(0)?;
34 Ok(InvocationId::from_string(id))
35 })
36 .map_err(sql_err)?
37 .collect::<Result<Vec<_>, _>>()
38 .map_err(sql_err)?;
39
40 Ok(ids)
41 })
42 .await
43 }
44
45 async fn get_invocations_by_call(
46 &self,
47 call_id: &CallId,
48 ) -> RustvelloResult<Vec<InvocationId>> {
49 let db = Arc::clone(&self.db);
50 let call_id = call_id.clone();
51 blocking(move || {
52 let conn = db.conn.lock().map_err(lock_err)?;
53 let call_id_str = call_id.to_string();
54
55 let mut stmt = conn
56 .prepare("SELECT invocation_id FROM invocations WHERE call_id = ?1")
57 .map_err(sql_err)?;
58
59 let ids: Vec<InvocationId> = stmt
60 .query_map([&call_id_str], |row| {
61 let id: String = row.get(0)?;
62 Ok(InvocationId::from_string(id))
63 })
64 .map_err(sql_err)?
65 .collect::<Result<Vec<_>, _>>()
66 .map_err(sql_err)?;
67
68 Ok(ids)
69 })
70 .await
71 }
72
73 async fn get_invocations_by_status(
74 &self,
75 status: InvocationStatus,
76 task_id: Option<&TaskId>,
77 ) -> RustvelloResult<Vec<InvocationId>> {
78 let db = Arc::clone(&self.db);
79 let task_id = task_id.cloned();
80 blocking(move || {
81 let conn = db.conn.lock().map_err(lock_err)?;
82 let status_str = status.to_string();
83
84 let ids: Vec<InvocationId> = if let Some(tid) = task_id {
85 let task_id_str = tid.to_string();
86 let mut stmt = conn
87 .prepare(
88 "SELECT invocation_id FROM invocations WHERE status = ?1 AND task_id = ?2",
89 )
90 .map_err(sql_err)?;
91 let result: Vec<InvocationId> = stmt
92 .query_map(rusqlite::params![&status_str, &task_id_str], |row| {
93 let id: String = row.get(0)?;
94 Ok(InvocationId::from_string(id))
95 })
96 .map_err(sql_err)?
97 .collect::<Result<Vec<_>, _>>()
98 .map_err(sql_err)?;
99 result
100 } else {
101 let mut stmt = conn
102 .prepare("SELECT invocation_id FROM invocations WHERE status = ?1")
103 .map_err(sql_err)?;
104 let result: Vec<InvocationId> = stmt
105 .query_map([&status_str], |row| {
106 let id: String = row.get(0)?;
107 Ok(InvocationId::from_string(id))
108 })
109 .map_err(sql_err)?
110 .collect::<Result<Vec<_>, _>>()
111 .map_err(sql_err)?;
112 result
113 };
114
115 Ok(ids)
116 })
117 .await
118 }
119
120 async fn count_invocations(
121 &self,
122 task_id: Option<&TaskId>,
123 statuses: Option<&[InvocationStatus]>,
124 ) -> RustvelloResult<usize> {
125 let db = Arc::clone(&self.db);
126 let task_id = task_id.cloned();
127 let statuses = statuses.map(<[InvocationStatus]>::to_vec);
128 blocking(move || {
129 let conn = db.conn.lock().map_err(lock_err)?;
130
131 let mut sql = String::from("SELECT COUNT(*) FROM status_records sr");
132 let mut params: Vec<String> = Vec::new();
133 let mut where_clauses = Vec::new();
134
135 if let Some(tid) = task_id {
136 sql.push_str(" JOIN invocations inv ON sr.invocation_id = inv.invocation_id");
137 where_clauses.push(format!("inv.task_id = ?{}", params.len() + 1));
138 params.push(tid.to_string());
139 }
140
141 if let Some(ss) = statuses {
142 if !ss.is_empty() {
143 let placeholders: Vec<String> = (0..ss.len())
144 .map(|i| format!("?{}", params.len() + i + 1))
145 .collect();
146 where_clauses.push(format!("sr.status IN ({})", placeholders.join(",")));
147 for s in ss {
148 params.push(s.to_string());
149 }
150 }
151 }
152
153 if !where_clauses.is_empty() {
154 sql.push_str(" WHERE ");
155 sql.push_str(&where_clauses.join(" AND "));
156 }
157
158 let count: usize = conn
159 .query_row(&sql, rusqlite::params_from_iter(params.iter()), |row| {
160 row.get(0)
161 })
162 .map_err(sql_err)?;
163 Ok(count)
164 })
165 .await
166 }
167
168 async fn get_invocation_ids_paginated(
169 &self,
170 task_id: Option<&TaskId>,
171 statuses: Option<&[InvocationStatus]>,
172 limit: usize,
173 offset: usize,
174 ) -> RustvelloResult<Vec<InvocationId>> {
175 let db = Arc::clone(&self.db);
176 let task_id = task_id.cloned();
177 let statuses = statuses.map(<[InvocationStatus]>::to_vec);
178 blocking(move || {
179 let conn = db.conn.lock().map_err(lock_err)?;
180 let mut sql = String::from("SELECT sr.invocation_id FROM status_records sr");
181 let mut params: Vec<String> = Vec::new();
182 let mut where_clauses = Vec::new();
183
184 if let Some(tid) = task_id {
185 sql.push_str(" JOIN invocations inv ON sr.invocation_id = inv.invocation_id");
186 where_clauses.push(format!("inv.task_id = ?{}", params.len() + 1));
187 params.push(tid.to_string());
188 }
189
190 if let Some(ss) = statuses {
191 if !ss.is_empty() {
192 let placeholders: Vec<String> = (0..ss.len())
193 .map(|i| format!("?{}", params.len() + i + 1))
194 .collect();
195 where_clauses.push(format!("sr.status IN ({})", placeholders.join(",")));
196 for s in ss {
197 params.push(s.to_string());
198 }
199 }
200 }
201
202 if !where_clauses.is_empty() {
203 sql.push_str(" WHERE ");
204 sql.push_str(&where_clauses.join(" AND "));
205 }
206
207 sql.push_str(&format!(
208 " LIMIT ?{} OFFSET ?{}",
209 params.len() + 1,
210 params.len() + 2
211 ));
212 params.push(limit.to_string());
213 params.push(offset.to_string());
214
215 let mut stmt = conn.prepare(&sql).map_err(sql_err)?;
216 let ids: Vec<InvocationId> = stmt
217 .query_map(rusqlite::params_from_iter(params.iter()), |row| {
218 let id: String = row.get(0)?;
219 Ok(InvocationId::from_string(id))
220 })
221 .map_err(sql_err)?
222 .collect::<Result<Vec<_>, _>>()
223 .map_err(sql_err)?;
224
225 Ok(ids)
226 })
227 .await
228 }
229
230 async fn get_blocking_invocations(&self, max_num: usize) -> RustvelloResult<Vec<InvocationId>> {
231 let db = Arc::clone(&self.db);
232 blocking(move || {
233 let conn = db.conn.lock().map_err(lock_err)?;
234 let mut stmt = conn
235 .prepare(
236 "SELECT DISTINCT wf.waited_on_id FROM waiting_for wf
237 JOIN status_records sr ON wf.waited_on_id = sr.invocation_id
238 WHERE sr.status IN ('REGISTERED', 'PENDING', 'RUNNING')
239 AND NOT EXISTS (
240 SELECT 1 FROM waiting_for wf2
241 WHERE wf2.waiter_id = wf.waited_on_id
242 )
243 LIMIT ?1",
244 )
245 .map_err(sql_err)?;
246 let ids: Vec<InvocationId> = stmt
247 .query_map([max_num as i64], |row| {
248 let id: String = row.get(0)?;
249 Ok(InvocationId::from_string(id))
250 })
251 .map_err(sql_err)?
252 .collect::<Result<Vec<_>, _>>()
253 .map_err(sql_err)?;
254 Ok(ids)
255 })
256 .await
257 }
258
259 async fn get_existing_invocations(
260 &self,
261 task_id: &TaskId,
262 cc_args: Option<&SerializedArguments>,
263 statuses: &[InvocationStatus],
264 ) -> RustvelloResult<Vec<InvocationId>> {
265 let db = Arc::clone(&self.db);
266 let task_id = task_id.clone();
267 let cc_args = cc_args.cloned();
268 let statuses = statuses.to_vec();
269 blocking(move || {
270 let conn = db.conn.lock().map_err(lock_err)?;
271 let task_key = task_id.to_string();
272
273 let mut params: Vec<String> = statuses
274 .iter()
275 .map(std::string::ToString::to_string)
276 .collect();
277 let status_clause = if statuses.is_empty() {
278 String::new()
279 } else {
280 let placeholders: Vec<String> =
281 (0..statuses.len()).map(|i| format!("?{}", i + 1)).collect();
282 format!(" AND i.status IN ({})", placeholders.join(","))
283 };
284
285 let sql = match cc_args {
286 Some(ref args) => {
287 let pairs = args.cc_arg_pairs();
288 let n_pairs = pairs.len();
289 let task_idx = params.len() + 1;
290 params.push(task_key);
291 let mut pair_conds = Vec::with_capacity(n_pairs);
292 for (k, v) in &pairs {
293 let ki = params.len() + 1;
294 let vi = params.len() + 2;
295 params.push(k.clone());
296 params.push(v.clone());
297 pair_conds.push(format!("(cp.arg_key = ?{ki} AND cp.arg_value = ?{vi})"));
298 }
299 let where_pairs = pair_conds.join(" OR ");
300 format!(
301 "SELECT cp.invocation_id FROM cc_arg_pairs cp
302 JOIN invocations i ON cp.invocation_id = i.invocation_id
303 WHERE cp.task_id = ?{task_idx} AND ({where_pairs}){status_clause}
304 GROUP BY cp.invocation_id
305 HAVING COUNT(*) = {n_pairs}"
306 )
307 }
308 None => {
309 let task_idx = params.len() + 1;
310 params.push(task_key);
311 if statuses.is_empty() {
312 format!(
313 "SELECT invocation_id FROM invocations
314 WHERE task_id = ?{task_idx}"
315 )
316 } else {
317 format!(
318 "SELECT invocation_id FROM invocations i
319 WHERE i.task_id = ?{task_idx}{status_clause}"
320 )
321 }
322 }
323 };
324
325 let mut stmt = conn.prepare(&sql).map_err(sql_err)?;
326 let ids: Vec<InvocationId> = stmt
327 .query_map(rusqlite::params_from_iter(params.iter()), |row| {
328 let id: String = row.get(0)?;
329 Ok(InvocationId::from_string(id))
330 })
331 .map_err(sql_err)?
332 .collect::<Result<Vec<_>, _>>()
333 .map_err(sql_err)?;
334
335 Ok(ids)
336 })
337 .await
338 }
339}