1use std::collections::{BTreeMap, VecDeque};
2use tokio::sync::Mutex;
3use tokio_util::sync::CancellationToken;
4
5use async_trait::async_trait;
6use tracing::instrument;
7
8use rustvello_core::broker::Broker;
9use rustvello_core::error::RustvelloResult;
10use rustvello_proto::identifiers::{InvocationId, TaskId};
11pub struct MemBroker {
31 queues: Mutex<BTreeMap<String, VecDeque<InvocationId>>>,
35 notify: tokio::sync::Notify,
37}
38
/// Queue key used for invocations routed without a task (`route_invocation`).
///
/// Untargeted retrieval (no task id, or retrieval by language) checks this queue
/// before any per-task queue.
/// NOTE(review): a task whose `to_string()` form equals this key would share the
/// global queue — presumably impossible for real task ids; confirm.
const GLOBAL_QUEUE: &str = "__global__";
40
41impl MemBroker {
42 pub fn new() -> Self {
43 Self {
44 queues: Mutex::new(BTreeMap::new()),
45 notify: tokio::sync::Notify::new(),
46 }
47 }
48}
49
50impl Default for MemBroker {
51 fn default() -> Self {
52 Self::new()
53 }
54}
55
56#[async_trait]
57impl Broker for MemBroker {
58 #[instrument(skip(self), fields(%invocation_id))]
60 async fn route_invocation(&self, invocation_id: &InvocationId) -> RustvelloResult<()> {
61 let mut queues = self.queues.lock().await;
62 queues
63 .entry(GLOBAL_QUEUE.to_owned())
64 .or_default()
65 .push_back(invocation_id.clone());
66 drop(queues);
67 self.notify.notify_one();
68 Ok(())
69 }
70
71 #[instrument(skip(self), fields(%invocation_id, %task_id))]
76 async fn route_invocation_for_task(
77 &self,
78 invocation_id: &InvocationId,
79 task_id: &TaskId,
80 ) -> RustvelloResult<()> {
81 let mut queues = self.queues.lock().await;
82 queues
83 .entry(task_id.to_string())
84 .or_default()
85 .push_back(invocation_id.clone());
86 drop(queues);
87 self.notify.notify_one();
88 Ok(())
89 }
90
91 #[instrument(skip(self))]
92 async fn retrieve_invocation(
93 &self,
94 task_id: Option<&TaskId>,
95 ) -> RustvelloResult<Option<InvocationId>> {
96 let mut queues = self.queues.lock().await;
97 if let Some(tid) = task_id {
98 return Ok(queues
100 .get_mut(&tid.to_string())
101 .and_then(VecDeque::pop_front));
102 }
103 if let Some(id) = queues.get_mut(GLOBAL_QUEUE).and_then(VecDeque::pop_front) {
105 return Ok(Some(id));
106 }
107 for (key, queue) in queues.iter_mut() {
109 if key == GLOBAL_QUEUE {
110 continue;
111 }
112 if let Some(id) = queue.pop_front() {
113 return Ok(Some(id));
114 }
115 }
116 Ok(None)
117 }
118
119 async fn retrieve_invocation_for_language(
130 &self,
131 language: &str,
132 ) -> RustvelloResult<Option<InvocationId>> {
133 let mut queues = self.queues.lock().await;
134 if let Some(id) = queues.get_mut(GLOBAL_QUEUE).and_then(VecDeque::pop_front) {
136 return Ok(Some(id));
137 }
138 let prefix = format!("{language}::");
139 for (key, queue) in queues.iter_mut() {
140 if key == GLOBAL_QUEUE {
141 continue;
142 }
143 let matches = if language.is_empty() {
145 !key.contains("::")
146 } else {
147 key.starts_with(&prefix)
148 };
149 if matches {
150 if let Some(id) = queue.pop_front() {
151 return Ok(Some(id));
152 }
153 }
154 }
155 Ok(None)
156 }
157
158 async fn count_invocations(&self, task_id: Option<&TaskId>) -> RustvelloResult<usize> {
159 let queues = self.queues.lock().await;
160 if let Some(tid) = task_id {
161 return Ok(queues.get(&tid.to_string()).map_or(0, VecDeque::len));
162 }
163 Ok(queues.values().map(VecDeque::len).sum())
164 }
165
166 async fn purge(&self, task_id: Option<&TaskId>) -> RustvelloResult<()> {
167 let mut queues = self.queues.lock().await;
168 if let Some(tid) = task_id {
169 queues.remove(&tid.to_string());
170 return Ok(());
171 }
172 queues.clear();
173 Ok(())
174 }
175
176 async fn retrieve_invocations(
178 &self,
179 max: usize,
180 task_id: Option<&TaskId>,
181 ) -> RustvelloResult<Vec<InvocationId>> {
182 let mut queues = self.queues.lock().await;
183 let capped = max.min(10_000);
184 let mut results = Vec::with_capacity(capped);
185 for _ in 0..capped {
186 let item = if let Some(tid) = task_id {
187 queues
188 .get_mut(&tid.to_string())
189 .and_then(VecDeque::pop_front)
190 } else {
191 let global = queues.get_mut(GLOBAL_QUEUE).and_then(VecDeque::pop_front);
193 if global.is_some() {
194 global
195 } else {
196 let mut found = None;
197 for (key, queue) in queues.iter_mut() {
198 if key == GLOBAL_QUEUE {
199 continue;
200 }
201 if let Some(id) = queue.pop_front() {
202 found = Some(id);
203 break;
204 }
205 }
206 found
207 }
208 };
209 match item {
210 Some(id) => results.push(id),
211 None => break,
212 }
213 }
214 Ok(results)
215 }
216
217 async fn wait_for_work(&self, cancel: &CancellationToken) -> bool {
219 tokio::select! {
220 _ = cancel.cancelled() => false,
221 _ = self.notify.notified() => true,
222 }
223 }
224}
225
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_route_and_retrieve() {
        let broker = MemBroker::new();
        let id1 = InvocationId::new();
        let id2 = InvocationId::new();

        broker.route_invocation(&id1).await.unwrap();
        broker.route_invocation(&id2).await.unwrap();

        assert_eq!(broker.count_invocations(None).await.unwrap(), 2);

        let retrieved1 = broker.retrieve_invocation(None).await.unwrap();
        assert_eq!(retrieved1, Some(id1));

        let retrieved2 = broker.retrieve_invocation(None).await.unwrap();
        assert_eq!(retrieved2, Some(id2));

        let retrieved3 = broker.retrieve_invocation(None).await.unwrap();
        assert_eq!(retrieved3, None);
    }

    #[tokio::test]
    async fn test_per_task_routing() {
        let broker = MemBroker::new();
        let task_a = TaskId::new("mod", "task_a");
        let task_b = TaskId::new("mod", "task_b");
        let id_a = InvocationId::new();
        let id_b = InvocationId::new();

        broker
            .route_invocation_for_task(&id_a, &task_a)
            .await
            .unwrap();
        broker
            .route_invocation_for_task(&id_b, &task_b)
            .await
            .unwrap();

        let got_a = broker.retrieve_invocation(Some(&task_a)).await.unwrap();
        assert_eq!(got_a, Some(id_a));
        assert_eq!(broker.count_invocations(Some(&task_b)).await.unwrap(), 1);
        assert_eq!(broker.count_invocations(None).await.unwrap(), 1);
        let got_b = broker.retrieve_invocation(None).await.unwrap();
        assert_eq!(got_b, Some(id_b));
    }

    #[tokio::test]
    async fn test_per_task_purge() {
        let broker = MemBroker::new();
        let task_a = TaskId::new("mod", "task_a");
        let task_b = TaskId::new("mod", "task_b");
        broker
            .route_invocation_for_task(&InvocationId::new(), &task_a)
            .await
            .unwrap();
        broker
            .route_invocation_for_task(&InvocationId::new(), &task_b)
            .await
            .unwrap();

        assert_eq!(broker.count_invocations(None).await.unwrap(), 2);
        broker.purge(Some(&task_a)).await.unwrap();
        assert_eq!(broker.count_invocations(None).await.unwrap(), 1);
        assert_eq!(broker.count_invocations(Some(&task_a)).await.unwrap(), 0);
        assert_eq!(broker.count_invocations(Some(&task_b)).await.unwrap(), 1);
    }

    #[tokio::test]
    async fn test_purge() {
        let broker = MemBroker::new();
        broker.route_invocation(&InvocationId::new()).await.unwrap();
        broker.route_invocation(&InvocationId::new()).await.unwrap();

        assert_eq!(broker.count_invocations(None).await.unwrap(), 2);

        broker.purge(None).await.unwrap();
        assert_eq!(broker.count_invocations(None).await.unwrap(), 0);
    }

    #[tokio::test]
    async fn test_batch_route() {
        let broker = MemBroker::new();
        let ids: Vec<InvocationId> = (0..5).map(|_| InvocationId::new()).collect();

        broker.route_invocations(&ids).await.unwrap();
        assert_eq!(broker.count_invocations(None).await.unwrap(), 5);
    }

    #[tokio::test]
    async fn test_language_routing_foreign_task() {
        let broker = MemBroker::new();
        let py_task = TaskId::foreign("python", "analytics.tasks", "train");
        let rs_task = TaskId::new("math", "add");
        let py_inv = InvocationId::new();
        let rs_inv = InvocationId::new();

        broker
            .route_invocation_for_task(&py_inv, &py_task)
            .await
            .unwrap();
        broker
            .route_invocation_for_task(&rs_inv, &rs_task)
            .await
            .unwrap();

        let got = broker
            .retrieve_invocation_for_language("python")
            .await
            .unwrap();
        assert_eq!(got, Some(py_inv));

        let got = broker
            .retrieve_invocation_for_language("python")
            .await
            .unwrap();
        assert_eq!(got, None);

        let got = broker.retrieve_invocation_for_language("").await.unwrap();
        assert_eq!(got, Some(rs_inv));
    }

    #[tokio::test]
    async fn test_language_routing_global_queue_serves_all() {
        let broker = MemBroker::new();
        let inv = InvocationId::new();

        broker.route_invocation(&inv).await.unwrap();

        let got = broker
            .retrieve_invocation_for_language("python")
            .await
            .unwrap();
        assert_eq!(got, Some(inv));
    }

    #[tokio::test]
    async fn test_retrieve_invocations_global_then_task_order() {
        let broker = MemBroker::new();
        let task_a = TaskId::new("mod", "task_a");
        let task_b = TaskId::new("mod", "task_b");
        let a1 = InvocationId::new();
        let a2 = InvocationId::new();
        let b1 = InvocationId::new();
        let global = InvocationId::new();

        broker.route_invocation_for_task(&a1, &task_a).await.unwrap();
        broker.route_invocation_for_task(&b1, &task_b).await.unwrap();
        broker.route_invocation_for_task(&a2, &task_a).await.unwrap();
        broker.route_invocation(&global).await.unwrap();

        // Global work comes first, then task queues in key order, capped at `max`.
        let batch = broker.retrieve_invocations(3, None).await.unwrap();
        assert_eq!(batch, vec![global, a1, a2]);

        let rest = broker.retrieve_invocations(10, None).await.unwrap();
        assert_eq!(rest, vec![b1]);
        assert_eq!(broker.count_invocations(None).await.unwrap(), 0);
    }

    #[tokio::test]
    async fn test_retrieve_invocations_for_task() {
        let broker = MemBroker::new();
        let task_a = TaskId::new("mod", "task_a");
        let ids: Vec<InvocationId> = (0..3).map(|_| InvocationId::new()).collect();
        for id in &ids {
            broker.route_invocation_for_task(id, &task_a).await.unwrap();
        }
        broker.route_invocation(&InvocationId::new()).await.unwrap();

        assert!(broker
            .retrieve_invocations(0, Some(&task_a))
            .await
            .unwrap()
            .is_empty());
        assert_eq!(
            broker.retrieve_invocations(2, Some(&task_a)).await.unwrap(),
            ids[..2].to_vec()
        );
        assert_eq!(
            broker.retrieve_invocations(10, Some(&task_a)).await.unwrap(),
            ids[2..].to_vec()
        );
        // Per-task batches leave the global queue alone.
        assert_eq!(broker.count_invocations(None).await.unwrap(), 1);
    }

    #[tokio::test]
    async fn test_wait_for_work() {
        let broker = MemBroker::new();

        let cancelled = CancellationToken::new();
        cancelled.cancel();
        assert!(!broker.wait_for_work(&cancelled).await);

        // Routing stores a wakeup permit, so the next wait returns immediately.
        broker.route_invocation(&InvocationId::new()).await.unwrap();
        assert!(broker.wait_for_work(&CancellationToken::new()).await);
    }
}