Skip to main content

rustvello_sqlite/orchestrator/
blocking.rs

1use std::sync::Arc;
2
3use async_trait::async_trait;
4
5use rustvello_core::error::RustvelloResult;
6use rustvello_core::orchestrator::OrchestratorBlocking;
7use rustvello_proto::identifiers::InvocationId;
8
9use crate::db::{blocking, lock_err, sql_err};
10
11use super::SqliteOrchestrator;
12
13#[async_trait]
14impl OrchestratorBlocking for SqliteOrchestrator {
15    async fn set_waiting_for(
16        &self,
17        waiter: &InvocationId,
18        waited_on: &InvocationId,
19    ) -> RustvelloResult<()> {
20        let db = Arc::clone(&self.db);
21        let waiter = waiter.clone();
22        let waited_on = waited_on.clone();
23        blocking(move || {
24            let conn = db.conn.lock().map_err(lock_err)?;
25            conn.execute(
26                "INSERT OR REPLACE INTO waiting_for (waiter_id, waited_on_id) VALUES (?1, ?2)",
27                rusqlite::params![waiter.as_str(), waited_on.as_str()],
28            )
29            .map_err(sql_err)?;
30            Ok(())
31        })
32        .await
33    }
34
35    async fn get_waiters(&self, waited_on: &InvocationId) -> RustvelloResult<Vec<InvocationId>> {
36        let db = Arc::clone(&self.db);
37        let waited_on = waited_on.clone();
38        blocking(move || {
39            let conn = db.conn.lock().map_err(lock_err)?;
40
41            let mut stmt = conn
42                .prepare("SELECT waiter_id FROM waiting_for WHERE waited_on_id = ?1")
43                .map_err(sql_err)?;
44
45            let ids: Vec<InvocationId> = stmt
46                .query_map([waited_on.as_str()], |row| {
47                    let id: String = row.get(0)?;
48                    Ok(InvocationId::from_string(id))
49                })
50                .map_err(sql_err)?
51                .collect::<Result<Vec<_>, _>>()
52                .map_err(sql_err)?;
53
54            Ok(ids)
55        })
56        .await
57    }
58
59    async fn release_waiters(
60        &self,
61        completed: &InvocationId,
62    ) -> RustvelloResult<Vec<InvocationId>> {
63        let db = Arc::clone(&self.db);
64        let completed = completed.clone();
65        blocking(move || {
66            let conn = db.conn.lock().map_err(lock_err)?;
67            let tx = conn.unchecked_transaction().map_err(sql_err)?;
68
69            let mut stmt = tx
70                .prepare("SELECT waiter_id FROM waiting_for WHERE waited_on_id = ?1")
71                .map_err(sql_err)?;
72
73            let waiters: Vec<InvocationId> = stmt
74                .query_map([completed.as_str()], |row| {
75                    let id: String = row.get(0)?;
76                    Ok(InvocationId::from_string(id))
77                })
78                .map_err(sql_err)?
79                .collect::<Result<Vec<_>, _>>()
80                .map_err(sql_err)?;
81
82            drop(stmt);
83            tx.execute(
84                "DELETE FROM waiting_for WHERE waited_on_id = ?1",
85                [completed.as_str()],
86            )
87            .map_err(sql_err)?;
88
89            tx.commit().map_err(sql_err)?;
90            Ok(waiters)
91        })
92        .await
93    }
94}