Skip to main content

rustvello_mem/
broker.rs

1use std::collections::{BTreeMap, VecDeque};
2use tokio::sync::Mutex;
3use tokio_util::sync::CancellationToken;
4
5use async_trait::async_trait;
6use tracing::instrument;
7
8use rustvello_core::broker::Broker;
9use rustvello_core::error::RustvelloResult;
10use rustvello_proto::identifiers::{InvocationId, TaskId};
/// In-memory broker with a global queue and per-task queues.
///
/// Not suitable for production — all data is lost on process exit.
/// Useful for unit tests and local development.
///
/// # Queue semantics
///
/// - [`Broker::route_invocation`]: pushes to the global queue (task ID unknown at call site).
/// - [`Broker::route_invocation_for_task`]: pushes to a task-specific queue; used by callers
///   that know the task ID (e.g. `RustvelloApp::submit_call`).
/// - [`Broker::retrieve_invocation`] with `None`: pops from the global queue first, then falls
///   back to the first non-empty task queue in sorted key order (this is *not* round-robin:
///   a queue whose key sorts earlier is always served before later ones). The fallback
///   ensures that invocations routed via the task-aware path are also visible to runners
///   that poll without a filter.
/// - [`Broker::retrieve_invocation`] with `Some(task_id)`: pops only from the task-specific queue.
///
/// # Notify-based wakeup
///
/// Workers can call [`Broker::wait_for_work`] instead of polling with sleep.
/// Every routing call signals a `tokio::sync::Notify` via `notify_one`, waking one waiter.
pub struct MemBroker {
    /// Queues keyed by queue name.
    /// `GLOBAL_QUEUE` is used for invocations routed without a task_id.
    /// Each `TaskId` string (its `to_string()` form) maps to its own per-task queue.
    /// The map is ordered, so iteration visits queues in sorted key order.
    queues: Mutex<BTreeMap<String, VecDeque<InvocationId>>>,
    /// Notification channel for waking idle workers.
    notify: tokio::sync::Notify,
}
38
/// Map key of the shared queue that holds invocations routed without a task ID.
const GLOBAL_QUEUE: &str = "__global__";
40
41impl MemBroker {
42    pub fn new() -> Self {
43        Self {
44            queues: Mutex::new(BTreeMap::new()),
45            notify: tokio::sync::Notify::new(),
46        }
47    }
48}
49
50impl Default for MemBroker {
51    fn default() -> Self {
52        Self::new()
53    }
54}
55
56#[async_trait]
57impl Broker for MemBroker {
58    /// Route to the global queue (task ID unknown at this call site).
59    #[instrument(skip(self), fields(%invocation_id))]
60    async fn route_invocation(&self, invocation_id: &InvocationId) -> RustvelloResult<()> {
61        let mut queues = self.queues.lock().await;
62        queues
63            .entry(GLOBAL_QUEUE.to_owned())
64            .or_default()
65            .push_back(invocation_id.clone());
66        drop(queues);
67        self.notify.notify_one();
68        Ok(())
69    }
70
71    /// Route to the task-specific queue.
72    ///
73    /// Callers that know the task ID should prefer this over `route_invocation`
74    /// so that `retrieve_invocation(Some(task_id))` can return a filtered result.
75    #[instrument(skip(self), fields(%invocation_id, %task_id))]
76    async fn route_invocation_for_task(
77        &self,
78        invocation_id: &InvocationId,
79        task_id: &TaskId,
80    ) -> RustvelloResult<()> {
81        let mut queues = self.queues.lock().await;
82        queues
83            .entry(task_id.to_string())
84            .or_default()
85            .push_back(invocation_id.clone());
86        drop(queues);
87        self.notify.notify_one();
88        Ok(())
89    }
90
91    #[instrument(skip(self))]
92    async fn retrieve_invocation(
93        &self,
94        task_id: Option<&TaskId>,
95    ) -> RustvelloResult<Option<InvocationId>> {
96        let mut queues = self.queues.lock().await;
97        if let Some(tid) = task_id {
98            // Task-filtered retrieval: pop from the task-specific queue only.
99            return Ok(queues
100                .get_mut(&tid.to_string())
101                .and_then(VecDeque::pop_front));
102        }
103        // Global retrieval: drain global queue first, then any task queue.
104        if let Some(id) = queues.get_mut(GLOBAL_QUEUE).and_then(VecDeque::pop_front) {
105            return Ok(Some(id));
106        }
107        // Fall back to the first non-empty task queue (in iteration order).
108        for (key, queue) in queues.iter_mut() {
109            if key == GLOBAL_QUEUE {
110                continue;
111            }
112            if let Some(id) = queue.pop_front() {
113                return Ok(Some(id));
114            }
115        }
116        Ok(None)
117    }
118
119    /// Retrieve from queues matching a specific language.
120    ///
121    /// **Behavior:** First checks the global queue, then per-task queues
122    /// whose keys start with `"language::"`. Because the global queue is
123    /// checked first, a single-language worker can drain globally-routed
124    /// invocations before language-agnostic workers see them.
125    ///
126    /// Queue keys for foreign tasks use the format `"language::module.name"`,
127    /// so we match keys that start with `"language::"`. For local tasks
128    /// (no language prefix), they are only retrieved if `language` is empty.
129    async fn retrieve_invocation_for_language(
130        &self,
131        language: &str,
132    ) -> RustvelloResult<Option<InvocationId>> {
133        let mut queues = self.queues.lock().await;
134        // First check the global queue (serves all languages).
135        if let Some(id) = queues.get_mut(GLOBAL_QUEUE).and_then(VecDeque::pop_front) {
136            return Ok(Some(id));
137        }
138        let prefix = format!("{language}::");
139        for (key, queue) in queues.iter_mut() {
140            if key == GLOBAL_QUEUE {
141                continue;
142            }
143            // Match: foreign task keys start with "language::"; local keys have no "::"
144            let matches = if language.is_empty() {
145                !key.contains("::")
146            } else {
147                key.starts_with(&prefix)
148            };
149            if matches {
150                if let Some(id) = queue.pop_front() {
151                    return Ok(Some(id));
152                }
153            }
154        }
155        Ok(None)
156    }
157
158    async fn count_invocations(&self, task_id: Option<&TaskId>) -> RustvelloResult<usize> {
159        let queues = self.queues.lock().await;
160        if let Some(tid) = task_id {
161            return Ok(queues.get(&tid.to_string()).map_or(0, VecDeque::len));
162        }
163        Ok(queues.values().map(VecDeque::len).sum())
164    }
165
166    async fn purge(&self, task_id: Option<&TaskId>) -> RustvelloResult<()> {
167        let mut queues = self.queues.lock().await;
168        if let Some(tid) = task_id {
169            queues.remove(&tid.to_string());
170            return Ok(());
171        }
172        queues.clear();
173        Ok(())
174    }
175
176    /// Batch retrieval: single lock acquisition drains up to `max` items.
177    async fn retrieve_invocations(
178        &self,
179        max: usize,
180        task_id: Option<&TaskId>,
181    ) -> RustvelloResult<Vec<InvocationId>> {
182        let mut queues = self.queues.lock().await;
183        let capped = max.min(10_000);
184        let mut results = Vec::with_capacity(capped);
185        for _ in 0..capped {
186            let item = if let Some(tid) = task_id {
187                queues
188                    .get_mut(&tid.to_string())
189                    .and_then(VecDeque::pop_front)
190            } else {
191                // Global first, then any task queue
192                let global = queues.get_mut(GLOBAL_QUEUE).and_then(VecDeque::pop_front);
193                if global.is_some() {
194                    global
195                } else {
196                    let mut found = None;
197                    for (key, queue) in queues.iter_mut() {
198                        if key == GLOBAL_QUEUE {
199                            continue;
200                        }
201                        if let Some(id) = queue.pop_front() {
202                            found = Some(id);
203                            break;
204                        }
205                    }
206                    found
207                }
208            };
209            match item {
210                Some(id) => results.push(id),
211                None => break,
212            }
213        }
214        Ok(results)
215    }
216
217    /// Zero-cost wait: blocks until new work is routed or cancelled.
218    async fn wait_for_work(&self, cancel: &CancellationToken) -> bool {
219        tokio::select! {
220            _ = cancel.cancelled() => false,
221            _ = self.notify.notified() => true,
222        }
223    }
224}
225
226#[cfg(test)]
227mod tests {
228    use super::*;
229
230    #[tokio::test]
231    async fn test_route_and_retrieve() {
232        let broker = MemBroker::new();
233        let id1 = InvocationId::new();
234        let id2 = InvocationId::new();
235
236        broker.route_invocation(&id1).await.unwrap();
237        broker.route_invocation(&id2).await.unwrap();
238
239        assert_eq!(broker.count_invocations(None).await.unwrap(), 2);
240
241        let retrieved1 = broker.retrieve_invocation(None).await.unwrap();
242        assert_eq!(retrieved1, Some(id1));
243
244        let retrieved2 = broker.retrieve_invocation(None).await.unwrap();
245        assert_eq!(retrieved2, Some(id2));
246
247        let retrieved3 = broker.retrieve_invocation(None).await.unwrap();
248        assert_eq!(retrieved3, None);
249    }
250
251    #[tokio::test]
252    async fn test_per_task_routing() {
253        let broker = MemBroker::new();
254        let task_a = TaskId::new("mod", "task_a");
255        let task_b = TaskId::new("mod", "task_b");
256        let id_a = InvocationId::new();
257        let id_b = InvocationId::new();
258
259        broker
260            .route_invocation_for_task(&id_a, &task_a)
261            .await
262            .unwrap();
263        broker
264            .route_invocation_for_task(&id_b, &task_b)
265            .await
266            .unwrap();
267
268        // Per-task retrieval should return only the matching task's invocation
269        let got_a = broker.retrieve_invocation(Some(&task_a)).await.unwrap();
270        assert_eq!(got_a, Some(id_a));
271        // task_b's queue still has one item
272        assert_eq!(broker.count_invocations(Some(&task_b)).await.unwrap(), 1);
273        // Total = 1 (only task_b remains)
274        assert_eq!(broker.count_invocations(None).await.unwrap(), 1);
275        // Global retrieve should pick up the task_b item from the task queue fallback
276        let got_b = broker.retrieve_invocation(None).await.unwrap();
277        assert_eq!(got_b, Some(id_b));
278    }
279
280    #[tokio::test]
281    async fn test_per_task_purge() {
282        let broker = MemBroker::new();
283        let task_a = TaskId::new("mod", "task_a");
284        let task_b = TaskId::new("mod", "task_b");
285        broker
286            .route_invocation_for_task(&InvocationId::new(), &task_a)
287            .await
288            .unwrap();
289        broker
290            .route_invocation_for_task(&InvocationId::new(), &task_b)
291            .await
292            .unwrap();
293
294        assert_eq!(broker.count_invocations(None).await.unwrap(), 2);
295        broker.purge(Some(&task_a)).await.unwrap();
296        assert_eq!(broker.count_invocations(None).await.unwrap(), 1);
297        assert_eq!(broker.count_invocations(Some(&task_a)).await.unwrap(), 0);
298        assert_eq!(broker.count_invocations(Some(&task_b)).await.unwrap(), 1);
299    }
300
301    #[tokio::test]
302    async fn test_purge() {
303        let broker = MemBroker::new();
304        broker.route_invocation(&InvocationId::new()).await.unwrap();
305        broker.route_invocation(&InvocationId::new()).await.unwrap();
306
307        assert_eq!(broker.count_invocations(None).await.unwrap(), 2);
308
309        broker.purge(None).await.unwrap();
310        assert_eq!(broker.count_invocations(None).await.unwrap(), 0);
311    }
312
313    #[tokio::test]
314    async fn test_batch_route() {
315        let broker = MemBroker::new();
316        let ids: Vec<InvocationId> = (0..5).map(|_| InvocationId::new()).collect();
317
318        broker.route_invocations(&ids).await.unwrap();
319        assert_eq!(broker.count_invocations(None).await.unwrap(), 5);
320    }
321
322    #[tokio::test]
323    async fn test_language_routing_foreign_task() {
324        let broker = MemBroker::new();
325        let py_task = TaskId::foreign("python", "analytics.tasks", "train");
326        let rs_task = TaskId::new("math", "add");
327        let py_inv = InvocationId::new();
328        let rs_inv = InvocationId::new();
329
330        broker
331            .route_invocation_for_task(&py_inv, &py_task)
332            .await
333            .unwrap();
334        broker
335            .route_invocation_for_task(&rs_inv, &rs_task)
336            .await
337            .unwrap();
338
339        // Python worker should get only the python invocation
340        let got = broker
341            .retrieve_invocation_for_language("python")
342            .await
343            .unwrap();
344        assert_eq!(got, Some(py_inv));
345
346        // Python queue is now empty
347        let got = broker
348            .retrieve_invocation_for_language("python")
349            .await
350            .unwrap();
351        assert_eq!(got, None);
352
353        // Local (empty lang) worker should get the rust task (no "::" in key)
354        let got = broker.retrieve_invocation_for_language("").await.unwrap();
355        assert_eq!(got, Some(rs_inv));
356    }
357
358    #[tokio::test]
359    async fn test_language_routing_global_queue_serves_all() {
360        let broker = MemBroker::new();
361        let inv = InvocationId::new();
362
363        // Route via global queue (no task ID)
364        broker.route_invocation(&inv).await.unwrap();
365
366        // Any language worker should be able to get it
367        let got = broker
368            .retrieve_invocation_for_language("python")
369            .await
370            .unwrap();
371        assert_eq!(got, Some(inv));
372    }
373}