rustvello_sqlite/orchestrator/
blocking.rs1use std::sync::Arc;
2
3use async_trait::async_trait;
4
5use rustvello_core::error::RustvelloResult;
6use rustvello_core::orchestrator::OrchestratorBlocking;
7use rustvello_proto::identifiers::InvocationId;
8
9use crate::db::{blocking, lock_err, sql_err};
10
11use super::SqliteOrchestrator;
12
13#[async_trait]
14impl OrchestratorBlocking for SqliteOrchestrator {
15 async fn set_waiting_for(
16 &self,
17 waiter: &InvocationId,
18 waited_on: &InvocationId,
19 ) -> RustvelloResult<()> {
20 let db = Arc::clone(&self.db);
21 let waiter = waiter.clone();
22 let waited_on = waited_on.clone();
23 blocking(move || {
24 let conn = db.conn.lock().map_err(lock_err)?;
25 conn.execute(
26 "INSERT OR REPLACE INTO waiting_for (waiter_id, waited_on_id) VALUES (?1, ?2)",
27 rusqlite::params![waiter.as_str(), waited_on.as_str()],
28 )
29 .map_err(sql_err)?;
30 Ok(())
31 })
32 .await
33 }
34
35 async fn get_waiters(&self, waited_on: &InvocationId) -> RustvelloResult<Vec<InvocationId>> {
36 let db = Arc::clone(&self.db);
37 let waited_on = waited_on.clone();
38 blocking(move || {
39 let conn = db.conn.lock().map_err(lock_err)?;
40
41 let mut stmt = conn
42 .prepare("SELECT waiter_id FROM waiting_for WHERE waited_on_id = ?1")
43 .map_err(sql_err)?;
44
45 let ids: Vec<InvocationId> = stmt
46 .query_map([waited_on.as_str()], |row| {
47 let id: String = row.get(0)?;
48 Ok(InvocationId::from_string(id))
49 })
50 .map_err(sql_err)?
51 .collect::<Result<Vec<_>, _>>()
52 .map_err(sql_err)?;
53
54 Ok(ids)
55 })
56 .await
57 }
58
59 async fn release_waiters(
60 &self,
61 completed: &InvocationId,
62 ) -> RustvelloResult<Vec<InvocationId>> {
63 let db = Arc::clone(&self.db);
64 let completed = completed.clone();
65 blocking(move || {
66 let conn = db.conn.lock().map_err(lock_err)?;
67 let tx = conn.unchecked_transaction().map_err(sql_err)?;
68
69 let mut stmt = tx
70 .prepare("SELECT waiter_id FROM waiting_for WHERE waited_on_id = ?1")
71 .map_err(sql_err)?;
72
73 let waiters: Vec<InvocationId> = stmt
74 .query_map([completed.as_str()], |row| {
75 let id: String = row.get(0)?;
76 Ok(InvocationId::from_string(id))
77 })
78 .map_err(sql_err)?
79 .collect::<Result<Vec<_>, _>>()
80 .map_err(sql_err)?;
81
82 drop(stmt);
83 tx.execute(
84 "DELETE FROM waiting_for WHERE waited_on_id = ?1",
85 [completed.as_str()],
86 )
87 .map_err(sql_err)?;
88
89 tx.commit().map_err(sql_err)?;
90 Ok(waiters)
91 })
92 .await
93 }
94}