// rustvello_postgres/orchestrator/query.rs

1use async_trait::async_trait;
2
3use rustvello_core::error::RustvelloResult;
4use rustvello_core::orchestrator::OrchestratorQuery;
5use rustvello_proto::call::SerializedArguments;
6use rustvello_proto::identifiers::{CallId, InvocationId, TaskId};
7use rustvello_proto::status::InvocationStatus;
8
9use super::PostgresOrchestrator;
10use crate::db::pg_err;
11
12#[async_trait]
13impl OrchestratorQuery for PostgresOrchestrator {
14    async fn get_invocations_by_task(
15        &self,
16        task_id: &TaskId,
17    ) -> RustvelloResult<Vec<InvocationId>> {
18        let client = self.db.conn().await?;
19        let task_id_str = task_id.to_string();
20
21        let rows = client
22            .query(
23                "SELECT invocation_id FROM invocations WHERE task_id = $1",
24                &[&task_id_str],
25            )
26            .await
27            .map_err(pg_err)?;
28
29        Ok(rows
30            .iter()
31            .map(|r| InvocationId::from_string(r.get::<_, String>(0)))
32            .collect())
33    }
34
35    async fn get_invocations_by_call(
36        &self,
37        call_id: &CallId,
38    ) -> RustvelloResult<Vec<InvocationId>> {
39        let client = self.db.conn().await?;
40        let call_id_str = call_id.to_string();
41
42        let rows = client
43            .query(
44                "SELECT invocation_id FROM invocations WHERE call_id = $1",
45                &[&call_id_str],
46            )
47            .await
48            .map_err(pg_err)?;
49
50        Ok(rows
51            .iter()
52            .map(|r| InvocationId::from_string(r.get::<_, String>(0)))
53            .collect())
54    }
55
56    async fn get_invocations_by_status(
57        &self,
58        status: InvocationStatus,
59        task_id: Option<&TaskId>,
60    ) -> RustvelloResult<Vec<InvocationId>> {
61        let client = self.db.conn().await?;
62        let status_str = status.to_string();
63
64        let rows = if let Some(tid) = task_id {
65            let task_id_str = tid.to_string();
66            client
67                .query(
68                    "SELECT invocation_id FROM invocations WHERE status = $1 AND task_id = $2",
69                    &[&status_str, &task_id_str],
70                )
71                .await
72                .map_err(pg_err)?
73        } else {
74            client
75                .query(
76                    "SELECT invocation_id FROM invocations WHERE status = $1",
77                    &[&status_str],
78                )
79                .await
80                .map_err(pg_err)?
81        };
82
83        Ok(rows
84            .iter()
85            .map(|r| InvocationId::from_string(r.get::<_, String>(0)))
86            .collect())
87    }
88
89    async fn count_invocations(
90        &self,
91        task_id: Option<&TaskId>,
92        statuses: Option<&[InvocationStatus]>,
93    ) -> RustvelloResult<usize> {
94        let client = self.db.conn().await?;
95        let mut query = "SELECT COUNT(*) FROM invocations WHERE 1=1".to_string();
96        let mut params: Vec<Box<dyn tokio_postgres::types::ToSql + Sync + Send>> = Vec::new();
97        let mut idx = 1;
98
99        if let Some(tid) = task_id {
100            query.push_str(&format!(" AND task_id = ${idx}"));
101            params.push(Box::new(tid.to_string()));
102            idx += 1;
103        }
104        if let Some(statuses) = statuses {
105            if !statuses.is_empty() {
106                let placeholders: Vec<String> = statuses
107                    .iter()
108                    .map(|s| {
109                        let p = format!("${idx}");
110                        params.push(Box::new(s.to_string()));
111                        idx += 1;
112                        p
113                    })
114                    .collect();
115                query.push_str(&format!(" AND status IN ({})", placeholders.join(",")));
116            }
117        }
118
119        let param_refs: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> = params
120            .iter()
121            .map(|p| &**p as &(dyn tokio_postgres::types::ToSql + Sync))
122            .collect();
123        let row = client
124            .query_one(&query, &param_refs)
125            .await
126            .map_err(pg_err)?;
127        let count: i64 = row.get(0);
128        Ok(count as usize)
129    }
130
131    async fn get_invocation_ids_paginated(
132        &self,
133        task_id: Option<&TaskId>,
134        statuses: Option<&[InvocationStatus]>,
135        limit: usize,
136        offset: usize,
137    ) -> RustvelloResult<Vec<InvocationId>> {
138        let client = self.db.conn().await?;
139        let mut query = "SELECT invocation_id FROM invocations WHERE 1=1".to_string();
140        let mut params: Vec<Box<dyn tokio_postgres::types::ToSql + Sync + Send>> = Vec::new();
141        let mut idx = 1;
142
143        if let Some(tid) = task_id {
144            query.push_str(&format!(" AND task_id = ${idx}"));
145            params.push(Box::new(tid.to_string()));
146            idx += 1;
147        }
148        if let Some(statuses) = statuses {
149            if !statuses.is_empty() {
150                let placeholders: Vec<String> = statuses
151                    .iter()
152                    .map(|s| {
153                        let p = format!("${idx}");
154                        params.push(Box::new(s.to_string()));
155                        idx += 1;
156                        p
157                    })
158                    .collect();
159                query.push_str(&format!(" AND status IN ({})", placeholders.join(",")));
160            }
161        }
162        query.push_str(&format!(
163            " ORDER BY created_at LIMIT ${idx} OFFSET ${}",
164            idx + 1
165        ));
166        params.push(Box::new(limit as i64));
167        params.push(Box::new(offset as i64));
168
169        let param_refs: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> = params
170            .iter()
171            .map(|p| &**p as &(dyn tokio_postgres::types::ToSql + Sync))
172            .collect();
173        let rows = client.query(&query, &param_refs).await.map_err(pg_err)?;
174        Ok(rows
175            .iter()
176            .map(|r| InvocationId::from_string(r.get::<_, String>(0)))
177            .collect())
178    }
179
180    async fn get_blocking_invocations(&self, max_num: usize) -> RustvelloResult<Vec<InvocationId>> {
181        let client = self.db.conn().await?;
182        let rows = client
183            .query(
184                "SELECT DISTINCT w.waited_on_id FROM waiting_for w
185                 JOIN status_records sr ON w.waited_on_id = sr.invocation_id
186                 WHERE sr.status IN ('REGISTERED', 'PENDING', 'RUNNING')
187                   AND NOT EXISTS (
188                       SELECT 1 FROM waiting_for w2
189                       WHERE w2.waiter_id = w.waited_on_id
190                   )
191                 LIMIT $1",
192                &[&(max_num as i64)],
193            )
194            .await
195            .map_err(pg_err)?;
196        Ok(rows
197            .iter()
198            .map(|r| InvocationId::from_string(r.get::<_, String>(0)))
199            .collect())
200    }
201
202    async fn get_existing_invocations(
203        &self,
204        task_id: &TaskId,
205        cc_args: Option<&SerializedArguments>,
206        statuses: &[InvocationStatus],
207    ) -> RustvelloResult<Vec<InvocationId>> {
208        let client = self.db.conn().await?;
209        let task_key = task_id.to_string();
210
211        let mut params: Vec<Box<dyn tokio_postgres::types::ToSql + Sync + Send>> = Vec::new();
212        let mut idx = 1;
213
214        // Empty statuses means "no filter — return all" (matches mem/sqlite).
215        let status_clause = if statuses.is_empty() {
216            String::new()
217        } else {
218            let placeholders: Vec<String> = statuses
219                .iter()
220                .map(|s| {
221                    let p = format!("${idx}");
222                    params.push(Box::new(s.to_string()));
223                    idx += 1;
224                    p
225                })
226                .collect();
227            format!(" AND i.status IN ({})", placeholders.join(","))
228        };
229
230        let query = match cc_args {
231            Some(args) => {
232                let pairs = args.cc_arg_pairs();
233                let n_pairs = pairs.len();
234                let task_p = format!("${idx}");
235                params.push(Box::new(task_key));
236                idx += 1;
237                let pair_conds: Vec<String> = pairs
238                    .iter()
239                    .map(|(k, v)| {
240                        let kp = format!("${idx}");
241                        params.push(Box::new(k.clone()));
242                        idx += 1;
243                        let vp = format!("${idx}");
244                        params.push(Box::new(v.clone()));
245                        idx += 1;
246                        format!("(cp.arg_key = {kp} AND cp.arg_value = {vp})")
247                    })
248                    .collect();
249                let where_pairs = pair_conds.join(" OR ");
250                format!(
251                    "SELECT cp.invocation_id FROM cc_arg_pairs cp
252                     JOIN invocations i ON cp.invocation_id = i.invocation_id
253                     WHERE cp.task_id = {task_p} AND ({where_pairs}){status_clause}
254                     GROUP BY cp.invocation_id
255                     HAVING COUNT(*) = {n_pairs}"
256                )
257            }
258            None => {
259                let task_p = format!("${idx}");
260                params.push(Box::new(task_key));
261                if statuses.is_empty() {
262                    format!(
263                        "SELECT invocation_id FROM invocations
264                         WHERE task_id = {task_p}"
265                    )
266                } else {
267                    format!(
268                        "SELECT i.invocation_id FROM invocations i
269                         WHERE i.task_id = {task_p}{status_clause}"
270                    )
271                }
272            }
273        };
274
275        let param_refs: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> = params
276            .iter()
277            .map(|p| &**p as &(dyn tokio_postgres::types::ToSql + Sync))
278            .collect();
279        let rows = client.query(&query, &param_refs).await.map_err(pg_err)?;
280        Ok(rows
281            .iter()
282            .map(|r| InvocationId::from_string(r.get::<_, String>(0)))
283            .collect())
284    }
285}