rustvello_postgres/orchestrator/
query.rs1use async_trait::async_trait;
2
3use rustvello_core::error::RustvelloResult;
4use rustvello_core::orchestrator::OrchestratorQuery;
5use rustvello_proto::call::SerializedArguments;
6use rustvello_proto::identifiers::{CallId, InvocationId, TaskId};
7use rustvello_proto::status::InvocationStatus;
8
9use super::PostgresOrchestrator;
10use crate::db::pg_err;
11
12#[async_trait]
13impl OrchestratorQuery for PostgresOrchestrator {
14 async fn get_invocations_by_task(
15 &self,
16 task_id: &TaskId,
17 ) -> RustvelloResult<Vec<InvocationId>> {
18 let client = self.db.conn().await?;
19 let task_id_str = task_id.to_string();
20
21 let rows = client
22 .query(
23 "SELECT invocation_id FROM invocations WHERE task_id = $1",
24 &[&task_id_str],
25 )
26 .await
27 .map_err(pg_err)?;
28
29 Ok(rows
30 .iter()
31 .map(|r| InvocationId::from_string(r.get::<_, String>(0)))
32 .collect())
33 }
34
35 async fn get_invocations_by_call(
36 &self,
37 call_id: &CallId,
38 ) -> RustvelloResult<Vec<InvocationId>> {
39 let client = self.db.conn().await?;
40 let call_id_str = call_id.to_string();
41
42 let rows = client
43 .query(
44 "SELECT invocation_id FROM invocations WHERE call_id = $1",
45 &[&call_id_str],
46 )
47 .await
48 .map_err(pg_err)?;
49
50 Ok(rows
51 .iter()
52 .map(|r| InvocationId::from_string(r.get::<_, String>(0)))
53 .collect())
54 }
55
56 async fn get_invocations_by_status(
57 &self,
58 status: InvocationStatus,
59 task_id: Option<&TaskId>,
60 ) -> RustvelloResult<Vec<InvocationId>> {
61 let client = self.db.conn().await?;
62 let status_str = status.to_string();
63
64 let rows = if let Some(tid) = task_id {
65 let task_id_str = tid.to_string();
66 client
67 .query(
68 "SELECT invocation_id FROM invocations WHERE status = $1 AND task_id = $2",
69 &[&status_str, &task_id_str],
70 )
71 .await
72 .map_err(pg_err)?
73 } else {
74 client
75 .query(
76 "SELECT invocation_id FROM invocations WHERE status = $1",
77 &[&status_str],
78 )
79 .await
80 .map_err(pg_err)?
81 };
82
83 Ok(rows
84 .iter()
85 .map(|r| InvocationId::from_string(r.get::<_, String>(0)))
86 .collect())
87 }
88
89 async fn count_invocations(
90 &self,
91 task_id: Option<&TaskId>,
92 statuses: Option<&[InvocationStatus]>,
93 ) -> RustvelloResult<usize> {
94 let client = self.db.conn().await?;
95 let mut query = "SELECT COUNT(*) FROM invocations WHERE 1=1".to_string();
96 let mut params: Vec<Box<dyn tokio_postgres::types::ToSql + Sync + Send>> = Vec::new();
97 let mut idx = 1;
98
99 if let Some(tid) = task_id {
100 query.push_str(&format!(" AND task_id = ${idx}"));
101 params.push(Box::new(tid.to_string()));
102 idx += 1;
103 }
104 if let Some(statuses) = statuses {
105 if !statuses.is_empty() {
106 let placeholders: Vec<String> = statuses
107 .iter()
108 .map(|s| {
109 let p = format!("${idx}");
110 params.push(Box::new(s.to_string()));
111 idx += 1;
112 p
113 })
114 .collect();
115 query.push_str(&format!(" AND status IN ({})", placeholders.join(",")));
116 }
117 }
118
119 let param_refs: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> = params
120 .iter()
121 .map(|p| &**p as &(dyn tokio_postgres::types::ToSql + Sync))
122 .collect();
123 let row = client
124 .query_one(&query, ¶m_refs)
125 .await
126 .map_err(pg_err)?;
127 let count: i64 = row.get(0);
128 Ok(count as usize)
129 }
130
131 async fn get_invocation_ids_paginated(
132 &self,
133 task_id: Option<&TaskId>,
134 statuses: Option<&[InvocationStatus]>,
135 limit: usize,
136 offset: usize,
137 ) -> RustvelloResult<Vec<InvocationId>> {
138 let client = self.db.conn().await?;
139 let mut query = "SELECT invocation_id FROM invocations WHERE 1=1".to_string();
140 let mut params: Vec<Box<dyn tokio_postgres::types::ToSql + Sync + Send>> = Vec::new();
141 let mut idx = 1;
142
143 if let Some(tid) = task_id {
144 query.push_str(&format!(" AND task_id = ${idx}"));
145 params.push(Box::new(tid.to_string()));
146 idx += 1;
147 }
148 if let Some(statuses) = statuses {
149 if !statuses.is_empty() {
150 let placeholders: Vec<String> = statuses
151 .iter()
152 .map(|s| {
153 let p = format!("${idx}");
154 params.push(Box::new(s.to_string()));
155 idx += 1;
156 p
157 })
158 .collect();
159 query.push_str(&format!(" AND status IN ({})", placeholders.join(",")));
160 }
161 }
162 query.push_str(&format!(
163 " ORDER BY created_at LIMIT ${idx} OFFSET ${}",
164 idx + 1
165 ));
166 params.push(Box::new(limit as i64));
167 params.push(Box::new(offset as i64));
168
169 let param_refs: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> = params
170 .iter()
171 .map(|p| &**p as &(dyn tokio_postgres::types::ToSql + Sync))
172 .collect();
173 let rows = client.query(&query, ¶m_refs).await.map_err(pg_err)?;
174 Ok(rows
175 .iter()
176 .map(|r| InvocationId::from_string(r.get::<_, String>(0)))
177 .collect())
178 }
179
180 async fn get_blocking_invocations(&self, max_num: usize) -> RustvelloResult<Vec<InvocationId>> {
181 let client = self.db.conn().await?;
182 let rows = client
183 .query(
184 "SELECT DISTINCT w.waited_on_id FROM waiting_for w
185 JOIN status_records sr ON w.waited_on_id = sr.invocation_id
186 WHERE sr.status IN ('REGISTERED', 'PENDING', 'RUNNING')
187 AND NOT EXISTS (
188 SELECT 1 FROM waiting_for w2
189 WHERE w2.waiter_id = w.waited_on_id
190 )
191 LIMIT $1",
192 &[&(max_num as i64)],
193 )
194 .await
195 .map_err(pg_err)?;
196 Ok(rows
197 .iter()
198 .map(|r| InvocationId::from_string(r.get::<_, String>(0)))
199 .collect())
200 }
201
202 async fn get_existing_invocations(
203 &self,
204 task_id: &TaskId,
205 cc_args: Option<&SerializedArguments>,
206 statuses: &[InvocationStatus],
207 ) -> RustvelloResult<Vec<InvocationId>> {
208 let client = self.db.conn().await?;
209 let task_key = task_id.to_string();
210
211 let mut params: Vec<Box<dyn tokio_postgres::types::ToSql + Sync + Send>> = Vec::new();
212 let mut idx = 1;
213
214 let status_clause = if statuses.is_empty() {
216 String::new()
217 } else {
218 let placeholders: Vec<String> = statuses
219 .iter()
220 .map(|s| {
221 let p = format!("${idx}");
222 params.push(Box::new(s.to_string()));
223 idx += 1;
224 p
225 })
226 .collect();
227 format!(" AND i.status IN ({})", placeholders.join(","))
228 };
229
230 let query = match cc_args {
231 Some(args) => {
232 let pairs = args.cc_arg_pairs();
233 let n_pairs = pairs.len();
234 let task_p = format!("${idx}");
235 params.push(Box::new(task_key));
236 idx += 1;
237 let pair_conds: Vec<String> = pairs
238 .iter()
239 .map(|(k, v)| {
240 let kp = format!("${idx}");
241 params.push(Box::new(k.clone()));
242 idx += 1;
243 let vp = format!("${idx}");
244 params.push(Box::new(v.clone()));
245 idx += 1;
246 format!("(cp.arg_key = {kp} AND cp.arg_value = {vp})")
247 })
248 .collect();
249 let where_pairs = pair_conds.join(" OR ");
250 format!(
251 "SELECT cp.invocation_id FROM cc_arg_pairs cp
252 JOIN invocations i ON cp.invocation_id = i.invocation_id
253 WHERE cp.task_id = {task_p} AND ({where_pairs}){status_clause}
254 GROUP BY cp.invocation_id
255 HAVING COUNT(*) = {n_pairs}"
256 )
257 }
258 None => {
259 let task_p = format!("${idx}");
260 params.push(Box::new(task_key));
261 if statuses.is_empty() {
262 format!(
263 "SELECT invocation_id FROM invocations
264 WHERE task_id = {task_p}"
265 )
266 } else {
267 format!(
268 "SELECT i.invocation_id FROM invocations i
269 WHERE i.task_id = {task_p}{status_clause}"
270 )
271 }
272 }
273 };
274
275 let param_refs: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> = params
276 .iter()
277 .map(|p| &**p as &(dyn tokio_postgres::types::ToSql + Sync))
278 .collect();
279 let rows = client.query(&query, ¶m_refs).await.map_err(pg_err)?;
280 Ok(rows
281 .iter()
282 .map(|r| InvocationId::from_string(r.get::<_, String>(0)))
283 .collect())
284 }
285}