Skip to main content

rustvello_postgres/orchestrator/
blocking.rs

1use async_trait::async_trait;
2
3use rustvello_core::error::RustvelloResult;
4use rustvello_core::orchestrator::OrchestratorBlocking;
5use rustvello_proto::identifiers::InvocationId;
6
7use super::PostgresOrchestrator;
8use crate::db::pg_err;
9
10#[async_trait]
11impl OrchestratorBlocking for PostgresOrchestrator {
12    async fn set_waiting_for(
13        &self,
14        waiter: &InvocationId,
15        waited_on: &InvocationId,
16    ) -> RustvelloResult<()> {
17        let client = self.db.conn().await?;
18        client
19            .execute(
20                "INSERT INTO waiting_for (waiter_id, waited_on_id) VALUES ($1, $2)
21                 ON CONFLICT DO NOTHING",
22                &[&waiter.as_str(), &waited_on.as_str()],
23            )
24            .await
25            .map_err(pg_err)?;
26        Ok(())
27    }
28
29    async fn get_waiters(&self, waited_on: &InvocationId) -> RustvelloResult<Vec<InvocationId>> {
30        let client = self.db.conn().await?;
31
32        let rows = client
33            .query(
34                "SELECT waiter_id FROM waiting_for WHERE waited_on_id = $1",
35                &[&waited_on.as_str()],
36            )
37            .await
38            .map_err(pg_err)?;
39
40        Ok(rows
41            .iter()
42            .map(|r| InvocationId::from_string(r.get::<_, String>(0)))
43            .collect())
44    }
45
46    async fn release_waiters(
47        &self,
48        completed: &InvocationId,
49    ) -> RustvelloResult<Vec<InvocationId>> {
50        let client = self.db.conn().await?;
51
52        // Atomically select and delete using a CTE.
53        let rows = client
54            .query(
55                "DELETE FROM waiting_for WHERE waited_on_id = $1 RETURNING waiter_id",
56                &[&completed.as_str()],
57            )
58            .await
59            .map_err(pg_err)?;
60
61        Ok(rows
62            .iter()
63            .map(|r| InvocationId::from_string(r.get::<_, String>(0)))
64            .collect())
65    }
66}