Skip to main content

rustvello_postgres/orchestrator/
status.rs

1//! PostgreSQL-backed [`OrchestratorStatus`] implementation.
2
3use async_trait::async_trait;
4use chrono::Utc;
5
6use rustvello_core::error::{RustvelloError, RustvelloResult};
7use rustvello_core::orchestrator::OrchestratorStatus;
8use rustvello_proto::call::CallDTO;
9use rustvello_proto::identifiers::{InvocationId, RunnerId};
10use rustvello_proto::status::{InvocationStatus, InvocationStatusRecord};
11
12use super::PostgresOrchestrator;
13use crate::db::{parse_status, pg_err};
14
15#[async_trait]
16impl OrchestratorStatus for PostgresOrchestrator {
17    async fn register_invocation(&self, call: &CallDTO) -> RustvelloResult<InvocationId> {
18        let invocation_id = InvocationId::new();
19        let now = Utc::now();
20        let status_str = InvocationStatus::Registered.to_string();
21        let task_id_str = call.task_id.to_string();
22        let call_id_str = call.call_id.to_string();
23
24        let mut client = self.db.conn().await?;
25        let tx = client.transaction().await.map_err(pg_err)?;
26
27        tx.execute(
28                "INSERT INTO invocations (invocation_id, task_id, call_id, status, created_at, updated_at)
29                 VALUES ($1, $2, $3, $4, $5, $6)",
30                &[
31                    &invocation_id.as_str(),
32                    &task_id_str,
33                    &call_id_str,
34                    &status_str,
35                    &now,
36                    &now,
37                ],
38            )
39            .await
40            .map_err(pg_err)?;
41
42        tx.execute(
43            "INSERT INTO status_records (invocation_id, status, runner_id, timestamp)
44                 VALUES ($1, $2, NULL, $3)
45                 ON CONFLICT (invocation_id) DO UPDATE SET status = $2, timestamp = $3",
46            &[&invocation_id.as_str(), &status_str, &now],
47        )
48        .await
49        .map_err(pg_err)?;
50
51        tx.commit().await.map_err(pg_err)?;
52
53        Ok(invocation_id)
54    }
55
56    async fn get_invocation_status(
57        &self,
58        invocation_id: &InvocationId,
59    ) -> RustvelloResult<InvocationStatusRecord> {
60        let client = self.db.conn().await?;
61
62        let row = client
63            .query_opt(
64                "SELECT status, runner_id, timestamp FROM status_records WHERE invocation_id = $1",
65                &[&invocation_id.as_str()],
66            )
67            .await
68            .map_err(pg_err)?
69            .ok_or_else(|| RustvelloError::InvocationNotFound {
70                invocation_id: invocation_id.clone(),
71            })?;
72
73        let status_str: String = row.get(0);
74        let runner_id_opt: Option<String> = row.get(1);
75        let timestamp: chrono::DateTime<Utc> = row.get(2);
76
77        Ok(InvocationStatusRecord {
78            status: parse_status(&status_str)?,
79            runner_id: runner_id_opt.map(RunnerId::from_string),
80            timestamp,
81        })
82    }
83
84    async fn set_invocation_status(
85        &self,
86        invocation_id: &InvocationId,
87        status: InvocationStatus,
88        runner_id: Option<&RunnerId>,
89    ) -> RustvelloResult<InvocationStatusRecord> {
90        use rustvello_proto::status::status_record_transition;
91
92        let mut client = self.db.conn().await?;
93
94        // Use a transaction to atomically check-and-set status.
95        let tx = client.transaction().await.map_err(pg_err)?;
96
97        let row = tx
98            .query_opt(
99                "SELECT status, runner_id, timestamp FROM status_records WHERE invocation_id = $1 FOR UPDATE",
100                &[&invocation_id.as_str()],
101            )
102            .await
103            .map_err(pg_err)?
104            .ok_or_else(|| RustvelloError::InvocationNotFound { invocation_id: invocation_id.clone() })?;
105
106        let current_status_str: String = row.get(0);
107        let current_runner_id_str: Option<String> = row.get(1);
108        let current_ts: chrono::DateTime<Utc> = row.get(2);
109        let current_status = parse_status(&current_status_str)?;
110        let current_record = InvocationStatusRecord {
111            status: current_status,
112            runner_id: current_runner_id_str.map(RunnerId::from_string),
113            timestamp: current_ts,
114        };
115
116        let new_record = status_record_transition(Some(&current_record), status, runner_id)
117            .map_err(|e| {
118                rustvello_core::error::status_machine_error_to_rustvello(
119                    e,
120                    invocation_id,
121                    current_status,
122                )
123            })?;
124
125        let status_str = status.to_string();
126        let runner_id_str = new_record
127            .runner_id
128            .as_ref()
129            .map(|r| r.as_str().to_string());
130
131        tx.execute(
132            "UPDATE status_records SET status = $1, runner_id = $2, timestamp = $3 WHERE invocation_id = $4",
133            &[&status_str, &runner_id_str as &(dyn tokio_postgres::types::ToSql + Sync), &new_record.timestamp, &invocation_id.as_str()],
134        )
135        .await
136        .map_err(pg_err)?;
137
138        tx.execute(
139            "UPDATE invocations SET status = $1, updated_at = $2 WHERE invocation_id = $3",
140            &[&status_str, &new_record.timestamp, &invocation_id.as_str()],
141        )
142        .await
143        .map_err(pg_err)?;
144
145        tx.commit().await.map_err(pg_err)?;
146
147        Ok(new_record)
148    }
149
150    async fn register_invocation_with_id(
151        &self,
152        invocation_id: &InvocationId,
153        call: &CallDTO,
154        runner_id: Option<&RunnerId>,
155    ) -> RustvelloResult<InvocationStatusRecord> {
156        let now = Utc::now();
157        let status = InvocationStatus::Registered;
158        let status_str = status.to_string();
159        let task_id_str = call.task_id.to_string();
160        let call_id_str = call.call_id.to_string();
161        let runner_id_str = runner_id.map(|r| r.as_str().to_string());
162
163        let mut client = self.db.conn().await?;
164        let tx = client.transaction().await.map_err(pg_err)?;
165
166        tx.execute(
167            "INSERT INTO invocations (invocation_id, task_id, call_id, status, created_at, updated_at)
168             VALUES ($1, $2, $3, $4, $5, $6)
169             ON CONFLICT (invocation_id) DO NOTHING",
170            &[&invocation_id.as_str(), &task_id_str, &call_id_str, &status_str, &now, &now],
171        )
172        .await
173        .map_err(pg_err)?;
174
175        tx.execute(
176            "INSERT INTO status_records (invocation_id, status, runner_id, timestamp)
177             VALUES ($1, $2, $3, $4)
178             ON CONFLICT (invocation_id) DO UPDATE SET status = $2, runner_id = $3, timestamp = $4",
179            &[
180                &invocation_id.as_str(),
181                &status_str,
182                &runner_id_str as &(dyn tokio_postgres::types::ToSql + Sync),
183                &now,
184            ],
185        )
186        .await
187        .map_err(pg_err)?;
188
189        tx.commit().await.map_err(pg_err)?;
190
191        Ok(InvocationStatusRecord {
192            status,
193            runner_id: runner_id.cloned(),
194            timestamp: now,
195        })
196    }
197
198    async fn increment_invocation_retries(
199        &self,
200        invocation_id: &InvocationId,
201    ) -> RustvelloResult<u32> {
202        let client = self.db.conn().await?;
203        let row = client
204            .query_one(
205                "INSERT INTO retries (invocation_id, count) VALUES ($1, 1)
206                 ON CONFLICT (invocation_id) DO UPDATE SET count = retries.count + 1
207                 RETURNING count",
208                &[&invocation_id.as_str()],
209            )
210            .await
211            .map_err(pg_err)?;
212        let count: i32 = row.get(0);
213        Ok(u32::try_from(count).unwrap_or(0))
214    }
215
216    async fn get_invocation_retries(&self, invocation_id: &InvocationId) -> RustvelloResult<u32> {
217        let client = self.db.conn().await?;
218        let row = client
219            .query_opt(
220                "SELECT count FROM retries WHERE invocation_id = $1",
221                &[&invocation_id.as_str()],
222            )
223            .await
224            .map_err(pg_err)?;
225        Ok(row.map_or(0, |r| u32::try_from(r.get::<_, i32>(0)).unwrap_or(0)))
226    }
227
228    async fn remove_invocation(&self, invocation_id: &InvocationId) -> RustvelloResult<()> {
229        let mut client = self.db.conn().await?;
230        let tx = client.transaction().await.map_err(pg_err)?;
231        tx.execute(
232            "DELETE FROM cc_arg_pairs WHERE invocation_id = $1",
233            &[&invocation_id.as_str()],
234        )
235        .await
236        .map_err(pg_err)?;
237        tx.execute(
238            "DELETE FROM waiting_for WHERE waiter_id = $1 OR waited_on_id = $1",
239            &[&invocation_id.as_str()],
240        )
241        .await
242        .map_err(pg_err)?;
243        tx.execute(
244            "DELETE FROM retries WHERE invocation_id = $1",
245            &[&invocation_id.as_str()],
246        )
247        .await
248        .map_err(pg_err)?;
249        tx.execute(
250            "DELETE FROM status_records WHERE invocation_id = $1",
251            &[&invocation_id.as_str()],
252        )
253        .await
254        .map_err(pg_err)?;
255        tx.execute(
256            "DELETE FROM invocations WHERE invocation_id = $1",
257            &[&invocation_id.as_str()],
258        )
259        .await
260        .map_err(pg_err)?;
261        tx.commit().await.map_err(pg_err)?;
262        Ok(())
263    }
264
265    async fn purge(&self) -> RustvelloResult<()> {
266        let mut client = self.db.conn().await?;
267        let tx = client.transaction().await.map_err(pg_err)?;
268        tx.execute("DELETE FROM cc_arg_pairs", &[])
269            .await
270            .map_err(pg_err)?;
271        tx.execute("DELETE FROM waiting_for", &[])
272            .await
273            .map_err(pg_err)?;
274        tx.execute("DELETE FROM retries", &[])
275            .await
276            .map_err(pg_err)?;
277        tx.execute("DELETE FROM runner_heartbeats", &[])
278            .await
279            .map_err(pg_err)?;
280        tx.execute("DELETE FROM status_records", &[])
281            .await
282            .map_err(pg_err)?;
283        tx.execute("DELETE FROM invocations", &[])
284            .await
285            .map_err(pg_err)?;
286        tx.commit().await.map_err(pg_err)?;
287        Ok(())
288    }
289
290    async fn schedule_auto_purge(&self, _invocation_id: &InvocationId) -> RustvelloResult<()> {
291        Err(RustvelloError::NotSupported {
292            backend: "Postgres".into(),
293            method: "schedule_auto_purge".into(),
294        })
295    }
296
297    async fn run_auto_purge(&self, _max_age_secs: u64) -> RustvelloResult<Vec<InvocationId>> {
298        Err(RustvelloError::NotSupported {
299            backend: "Postgres".into(),
300            method: "run_auto_purge".into(),
301        })
302    }
303}