Skip to main content

rustvello_sqlite/
broker.rs

1use std::sync::Arc;
2
3use async_trait::async_trait;
4
5use rustvello_core::broker::Broker;
6use rustvello_core::error::RustvelloResult;
7use rustvello_proto::identifiers::{InvocationId, TaskId};
8
9use crate::db::{blocking, lock_err, sql_err, Database};
10
11/// SQLite-backed broker implementation.
12///
13/// Persists the queue to a SQLite database, surviving process restarts.
14pub struct SqliteBroker {
15    db: Arc<Database>,
16}
17
18impl SqliteBroker {
19    pub fn new(db: Arc<Database>) -> Self {
20        Self { db }
21    }
22}
23
24#[async_trait]
25impl Broker for SqliteBroker {
26    async fn route_invocation(&self, invocation_id: &InvocationId) -> RustvelloResult<()> {
27        let db = Arc::clone(&self.db);
28        let id = invocation_id.clone();
29        blocking(move || {
30            let conn = db.conn.lock().map_err(lock_err)?;
31            conn.execute(
32                "INSERT INTO broker_queue (invocation_id) VALUES (?1)",
33                [id.as_str()],
34            )
35            .map_err(sql_err)?;
36            Ok(())
37        })
38        .await
39    }
40
41    async fn retrieve_invocation(
42        &self,
43        task_id: Option<&TaskId>,
44    ) -> RustvelloResult<Option<InvocationId>> {
45        let db = Arc::clone(&self.db);
46        let task_id = task_id.cloned();
47        blocking(move || {
48            let conn = db.conn.lock().map_err(lock_err)?;
49
50            let tx = conn.unchecked_transaction().map_err(sql_err)?;
51
52            let result: Option<(i64, String)> = if let Some(ref tid) = task_id {
53                tx.query_row(
54                    "SELECT bq.id, bq.invocation_id FROM broker_queue bq \
55                     JOIN invocations inv ON bq.invocation_id = inv.invocation_id \
56                     WHERE inv.task_id = ?1 \
57                     ORDER BY bq.id ASC LIMIT 1",
58                    [&tid.to_string()],
59                    |row| Ok((row.get(0)?, row.get(1)?)),
60                )
61                .ok()
62            } else {
63                tx.query_row(
64                    "SELECT id, invocation_id FROM broker_queue ORDER BY id ASC LIMIT 1",
65                    [],
66                    |row| Ok((row.get(0)?, row.get(1)?)),
67                )
68                .ok()
69            };
70
71            if let Some((row_id, inv_id)) = result {
72                tx.execute("DELETE FROM broker_queue WHERE id = ?1", [row_id])
73                    .map_err(sql_err)?;
74                tx.commit().map_err(sql_err)?;
75                Ok(Some(InvocationId::from_string(inv_id)))
76            } else {
77                Ok(None)
78            }
79        })
80        .await
81    }
82
83    async fn count_invocations(&self, task_id: Option<&TaskId>) -> RustvelloResult<usize> {
84        let db = Arc::clone(&self.db);
85        let task_id = task_id.cloned();
86        blocking(move || {
87            let conn = db.conn.lock().map_err(lock_err)?;
88            let count: i64 = if let Some(ref tid) = task_id {
89                conn.query_row(
90                    "SELECT COUNT(*) FROM broker_queue bq \
91                     JOIN invocations inv ON bq.invocation_id = inv.invocation_id \
92                     WHERE inv.task_id = ?1",
93                    [&tid.to_string()],
94                    |row| row.get(0),
95                )
96                .map_err(sql_err)?
97            } else {
98                conn.query_row("SELECT COUNT(*) FROM broker_queue", [], |row| row.get(0))
99                    .map_err(sql_err)?
100            };
101            Ok(count as usize)
102        })
103        .await
104    }
105
106    async fn purge(&self, task_id: Option<&TaskId>) -> RustvelloResult<()> {
107        let db = Arc::clone(&self.db);
108        let task_id = task_id.cloned();
109        blocking(move || {
110            let conn = db.conn.lock().map_err(lock_err)?;
111            if let Some(ref tid) = task_id {
112                conn.execute(
113                    "DELETE FROM broker_queue WHERE invocation_id IN (\
114                     SELECT bq.invocation_id FROM broker_queue bq \
115                     JOIN invocations inv ON bq.invocation_id = inv.invocation_id \
116                     WHERE inv.task_id = ?1)",
117                    [&tid.to_string()],
118                )
119                .map_err(sql_err)?;
120            } else {
121                conn.execute("DELETE FROM broker_queue", [])
122                    .map_err(sql_err)?;
123            }
124            Ok(())
125        })
126        .await
127    }
128
129    async fn retrieve_invocation_for_language(
130        &self,
131        language: &str,
132    ) -> RustvelloResult<Option<InvocationId>> {
133        let db = Arc::clone(&self.db);
134        let language = language.to_owned();
135        blocking(move || {
136            let conn = db.conn.lock().map_err(lock_err)?;
137            let tx = conn.unchecked_transaction().map_err(sql_err)?;
138
139            // First check the global queue: items without an invocations table entry
140            // (routed via route_invocation without task context).
141            let global: Option<(i64, String)> = tx
142                .query_row(
143                    "SELECT bq.id, bq.invocation_id FROM broker_queue bq \
144                     LEFT JOIN invocations inv ON bq.invocation_id = inv.invocation_id \
145                     WHERE inv.invocation_id IS NULL \
146                     ORDER BY bq.id ASC LIMIT 1",
147                    [],
148                    |row| Ok((row.get(0)?, row.get(1)?)),
149                )
150                .ok();
151
152            let result = if global.is_some() {
153                global
154            } else {
155                // Fall back to language-specific items.
156                let prefix = format!("{language}::");
157                tx.query_row(
158                    "SELECT bq.id, bq.invocation_id FROM broker_queue bq \
159                     JOIN invocations inv ON bq.invocation_id = inv.invocation_id \
160                     WHERE inv.task_id LIKE ?1 || '%' \
161                     ORDER BY bq.id ASC LIMIT 1",
162                    [&prefix],
163                    |row| Ok((row.get(0)?, row.get(1)?)),
164                )
165                .ok()
166            };
167
168            if let Some((row_id, inv_id)) = result {
169                tx.execute("DELETE FROM broker_queue WHERE id = ?1", [row_id])
170                    .map_err(sql_err)?;
171                tx.commit().map_err(sql_err)?;
172                Ok(Some(InvocationId::from_string(inv_id)))
173            } else {
174                Ok(None)
175            }
176        })
177        .await
178    }
179}
180
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a broker over a fresh in-memory database.
    fn make_broker() -> SqliteBroker {
        let db = Arc::new(Database::in_memory().unwrap());
        SqliteBroker::new(db)
    }

    #[tokio::test]
    async fn test_route_and_retrieve() {
        let broker = make_broker();
        let id1 = InvocationId::new();
        let id2 = InvocationId::new();

        broker.route_invocation(&id1).await.unwrap();
        broker.route_invocation(&id2).await.unwrap();

        assert_eq!(broker.count_invocations(None).await.unwrap(), 2);

        let r1 = broker.retrieve_invocation(None).await.unwrap();
        assert_eq!(r1.unwrap().as_str(), id1.as_str());

        let r2 = broker.retrieve_invocation(None).await.unwrap();
        assert_eq!(r2.unwrap().as_str(), id2.as_str());

        assert!(broker.retrieve_invocation(None).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_empty_queue() {
        let broker = make_broker();
        assert_eq!(broker.count_invocations(None).await.unwrap(), 0);
        assert!(broker.retrieve_invocation(None).await.unwrap().is_none());
        // Purging an already-empty queue is a no-op, not an error.
        broker.purge(None).await.unwrap();
        assert_eq!(broker.count_invocations(None).await.unwrap(), 0);
    }

    #[tokio::test]
    async fn test_purge() {
        let broker = make_broker();
        broker.route_invocation(&InvocationId::new()).await.unwrap();
        broker.route_invocation(&InvocationId::new()).await.unwrap();

        broker.purge(None).await.unwrap();
        assert_eq!(broker.count_invocations(None).await.unwrap(), 0);
        assert!(broker.retrieve_invocation(None).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_language_retrieval_serves_global_items_fifo() {
        let broker = make_broker();
        let id1 = InvocationId::new();
        let id2 = InvocationId::new();
        broker.route_invocation(&id1).await.unwrap();
        broker.route_invocation(&id2).await.unwrap();

        // Entries routed without an `invocations` row form the global queue and are
        // handed out oldest-first regardless of the requested language.
        let r1 = broker
            .retrieve_invocation_for_language("python")
            .await
            .unwrap();
        assert_eq!(r1.unwrap().as_str(), id1.as_str());

        let r2 = broker.retrieve_invocation_for_language("rust").await.unwrap();
        assert_eq!(r2.unwrap().as_str(), id2.as_str());

        assert!(broker
            .retrieve_invocation_for_language("python")
            .await
            .unwrap()
            .is_none());
        assert_eq!(broker.count_invocations(None).await.unwrap(), 0);
    }
}