Skip to main content

workflow_graph_queue/memory/
queue.rs

1use std::collections::{HashMap, HashSet, VecDeque};
2use std::time::Duration;
3
4use tokio::sync::{Mutex, broadcast};
5
6use crate::error::QueueError;
7use crate::traits::*;
8
9struct Inner {
10    pending: VecDeque<QueuedJob>,
11    active: HashMap<String, (Lease, QueuedJob)>, // keyed by lease_id
12    cancelled: HashSet<(String, String)>,        // (workflow_id, job_id)
13}
14
15pub struct InMemoryJobQueue {
16    inner: Mutex<Inner>,
17    events: broadcast::Sender<JobEvent>,
18}
19
20impl InMemoryJobQueue {
21    pub fn new() -> Self {
22        let (tx, _) = broadcast::channel(256);
23        Self {
24            inner: Mutex::new(Inner {
25                pending: VecDeque::new(),
26                active: HashMap::new(),
27                cancelled: HashSet::new(),
28            }),
29            events: tx,
30        }
31    }
32
33    fn now_ms() -> u64 {
34        std::time::SystemTime::now()
35            .duration_since(std::time::UNIX_EPOCH)
36            .unwrap_or_default()
37            .as_millis() as u64
38    }
39}
40
41impl Default for InMemoryJobQueue {
42    fn default() -> Self {
43        Self::new()
44    }
45}
46
47impl JobQueue for InMemoryJobQueue {
48    async fn enqueue(&self, job: QueuedJob) -> Result<(), QueueError> {
49        let event = JobEvent::Ready {
50            workflow_id: job.workflow_id.clone(),
51            job_id: job.job_id.clone(),
52        };
53        self.inner.lock().await.pending.push_back(job);
54        self.events.send(event).ok();
55        Ok(())
56    }
57
58    async fn claim(
59        &self,
60        worker_id: &str,
61        worker_labels: &[String],
62        lease_ttl: Duration,
63    ) -> Result<Option<(QueuedJob, Lease)>, QueueError> {
64        let mut inner = self.inner.lock().await;
65
66        let now = Self::now_ms();
67        // Find first pending job whose required_labels match and delay has elapsed
68        let pos = inner.pending.iter().position(|job| {
69            job.delayed_until_ms <= now
70                && job
71                    .required_labels
72                    .iter()
73                    .all(|label| worker_labels.contains(label))
74        });
75
76        let Some(idx) = pos else {
77            return Ok(None);
78        };
79
80        let job = inner.pending.remove(idx).unwrap();
81        let lease = Lease {
82            lease_id: uuid::Uuid::new_v4().to_string(),
83            job_id: job.job_id.clone(),
84            workflow_id: job.workflow_id.clone(),
85            worker_id: worker_id.to_string(),
86            ttl_secs: lease_ttl.as_secs(),
87            granted_at_ms: Self::now_ms(),
88        };
89
90        inner
91            .active
92            .insert(lease.lease_id.clone(), (lease.clone(), job.clone()));
93
94        let event = JobEvent::Started {
95            workflow_id: job.workflow_id.clone(),
96            job_id: job.job_id.clone(),
97            worker_id: worker_id.to_string(),
98        };
99        drop(inner);
100        self.events.send(event).ok();
101
102        Ok(Some((job, lease)))
103    }
104
105    async fn renew_lease(&self, lease_id: &str, extend_by: Duration) -> Result<(), QueueError> {
106        let mut inner = self.inner.lock().await;
107        let (lease, _) = inner
108            .active
109            .get_mut(lease_id)
110            .ok_or_else(|| QueueError::LeaseNotFound(lease_id.to_string()))?;
111
112        lease.granted_at_ms = Self::now_ms();
113        lease.ttl_secs = extend_by.as_secs();
114        Ok(())
115    }
116
117    async fn complete(
118        &self,
119        lease_id: &str,
120        outputs: HashMap<String, String>,
121    ) -> Result<(), QueueError> {
122        let mut inner = self.inner.lock().await;
123        let (_, job) = inner
124            .active
125            .remove(lease_id)
126            .ok_or_else(|| QueueError::LeaseNotFound(lease_id.to_string()))?;
127
128        let event = JobEvent::Completed {
129            workflow_id: job.workflow_id.clone(),
130            job_id: job.job_id.clone(),
131            outputs,
132        };
133        drop(inner);
134        self.events.send(event).ok();
135        Ok(())
136    }
137
138    async fn fail(&self, lease_id: &str, error: String, retryable: bool) -> Result<(), QueueError> {
139        let mut inner = self.inner.lock().await;
140        let (_, job) = inner
141            .active
142            .remove(lease_id)
143            .ok_or_else(|| QueueError::LeaseNotFound(lease_id.to_string()))?;
144
145        let should_retry = retryable && job.attempt < job.retry_policy.max_retries;
146
147        if should_retry {
148            // Re-enqueue with incremented attempt and backoff delay
149            let mut retried = job.clone();
150            retried.attempt += 1;
151            let now = Self::now_ms();
152            retried.enqueued_at_ms = now;
153            let delay_ms = retried.retry_policy.backoff.delay_ms(retried.attempt);
154            retried.delayed_until_ms = now + delay_ms;
155            inner.pending.push_back(retried);
156        }
157
158        let event = JobEvent::Failed {
159            workflow_id: job.workflow_id.clone(),
160            job_id: job.job_id.clone(),
161            error,
162            retryable: should_retry,
163        };
164        drop(inner);
165        self.events.send(event).ok();
166        Ok(())
167    }
168
169    async fn cancel(&self, workflow_id: &str, job_id: &str) -> Result<(), QueueError> {
170        let mut inner = self.inner.lock().await;
171
172        // Remove from pending if present
173        inner
174            .pending
175            .retain(|j| !(j.workflow_id == workflow_id && j.job_id == job_id));
176
177        // Mark as cancelled (so active workers can check)
178        inner
179            .cancelled
180            .insert((workflow_id.to_string(), job_id.to_string()));
181
182        let event = JobEvent::Cancelled {
183            workflow_id: workflow_id.to_string(),
184            job_id: job_id.to_string(),
185        };
186        drop(inner);
187        self.events.send(event).ok();
188        Ok(())
189    }
190
191    async fn cancel_workflow(&self, workflow_id: &str) -> Result<(), QueueError> {
192        let mut inner = self.inner.lock().await;
193
194        // Collect job IDs to cancel
195        let pending_ids: Vec<String> = inner
196            .pending
197            .iter()
198            .filter(|j| j.workflow_id == workflow_id)
199            .map(|j| j.job_id.clone())
200            .collect();
201        let active_ids: Vec<String> = inner
202            .active
203            .values()
204            .filter(|(_, j)| j.workflow_id == workflow_id)
205            .map(|(_, j)| j.job_id.clone())
206            .collect();
207
208        // Remove pending jobs
209        inner.pending.retain(|j| j.workflow_id != workflow_id);
210
211        // Mark all as cancelled
212        for id in pending_ids.iter().chain(active_ids.iter()) {
213            inner
214                .cancelled
215                .insert((workflow_id.to_string(), id.clone()));
216        }
217
218        drop(inner);
219
220        for id in pending_ids.iter().chain(active_ids.iter()) {
221            self.events
222                .send(JobEvent::Cancelled {
223                    workflow_id: workflow_id.to_string(),
224                    job_id: id.clone(),
225                })
226                .ok();
227        }
228
229        Ok(())
230    }
231
232    async fn is_cancelled(&self, workflow_id: &str, job_id: &str) -> Result<bool, QueueError> {
233        let inner = self.inner.lock().await;
234        Ok(inner
235            .cancelled
236            .contains(&(workflow_id.to_string(), job_id.to_string())))
237    }
238
239    async fn reap_expired_leases(&self) -> Result<Vec<JobEvent>, QueueError> {
240        let mut inner = self.inner.lock().await;
241        let now = Self::now_ms();
242        let mut events = Vec::new();
243
244        let expired_ids: Vec<String> = inner
245            .active
246            .iter()
247            .filter(|(_, (lease, _))| {
248                let expires_at = lease.granted_at_ms + lease.ttl_secs * 1000;
249                now > expires_at
250            })
251            .map(|(id, _)| id.clone())
252            .collect();
253
254        for lease_id in expired_ids {
255            let (lease, job) = inner.active.remove(&lease_id).unwrap();
256
257            events.push(JobEvent::LeaseExpired {
258                workflow_id: job.workflow_id.clone(),
259                job_id: job.job_id.clone(),
260                worker_id: lease.worker_id.clone(),
261            });
262
263            // Re-enqueue if retries remain (with backoff)
264            if job.attempt < job.retry_policy.max_retries {
265                let mut retried = job;
266                retried.attempt += 1;
267                retried.enqueued_at_ms = now;
268                let delay_ms = retried.retry_policy.backoff.delay_ms(retried.attempt);
269                retried.delayed_until_ms = now + delay_ms;
270                inner.pending.push_back(retried);
271            }
272        }
273
274        drop(inner);
275        for event in &events {
276            self.events.send(event.clone()).ok();
277        }
278
279        Ok(events)
280    }
281
282    fn subscribe(&self) -> broadcast::Receiver<JobEvent> {
283        self.events.subscribe()
284    }
285}
286
#[cfg(test)]
mod tests {
    use super::*;

    /// Lease TTL used by every claim that is not testing expiry.
    const LEASE: Duration = Duration::from_secs(30);

    /// Builds a first-attempt job in workflow `wf1` that is claimable
    /// immediately.
    fn make_job(
        job_id: &'static str,
        command: &'static str,
        labels: &[&'static str],
        retry_policy: RetryPolicy,
    ) -> QueuedJob {
        QueuedJob {
            job_id: job_id.into(),
            workflow_id: "wf1".into(),
            command: command.into(),
            required_labels: labels.iter().map(|&label| label.into()).collect(),
            retry_policy,
            attempt: 0,
            upstream_outputs: HashMap::new(),
            enqueued_at_ms: 0,
            delayed_until_ms: 0,
        }
    }

    #[tokio::test]
    async fn test_enqueue_and_claim() {
        let queue = InMemoryJobQueue::new();
        queue
            .enqueue(make_job("j1", "echo hello", &[], RetryPolicy::default()))
            .await
            .unwrap();

        let claimed = queue.claim("w1", &[], LEASE).await.unwrap();
        assert!(claimed.is_some());
        let (claimed_job, lease) = claimed.unwrap();
        assert_eq!(claimed_job.job_id, "j1");
        assert_eq!(lease.worker_id, "w1");

        // The only job is leased now, so a second worker comes away empty.
        let second = queue.claim("w2", &[], LEASE).await.unwrap();
        assert!(second.is_none());
    }

    #[tokio::test]
    async fn test_claim_respects_labels() {
        let queue = InMemoryJobQueue::new();
        queue
            .enqueue(make_job("j1", "echo hello", &["docker"], RetryPolicy::default()))
            .await
            .unwrap();

        // A worker lacking the `docker` label is skipped...
        let unlabelled = queue.claim("w1", &[], LEASE).await.unwrap();
        assert!(unlabelled.is_none());

        // ...while one offering it gets the job.
        let docker_labels = ["docker".to_string()];
        let labelled = queue.claim("w2", &docker_labels, LEASE).await.unwrap();
        assert!(labelled.is_some());
    }

    #[tokio::test]
    async fn test_complete() {
        let queue = InMemoryJobQueue::new();
        let mut rx = queue.subscribe();

        queue
            .enqueue(make_job("j1", "echo", &[], RetryPolicy::default()))
            .await
            .unwrap();
        let _ready = rx.recv().await;

        let (_, lease) = queue.claim("w1", &[], LEASE).await.unwrap().unwrap();
        let _started = rx.recv().await;

        let outputs = HashMap::from([("result".to_string(), "success".to_string())]);
        queue.complete(&lease.lease_id, outputs).await.unwrap();

        match rx.recv().await {
            Ok(JobEvent::Completed {
                job_id, outputs, ..
            }) => {
                assert_eq!(job_id, "j1");
                assert_eq!(outputs.get("result").unwrap(), "success");
            }
            _ => panic!("expected Completed event"),
        }
    }

    #[tokio::test]
    async fn test_fail_with_retry() {
        let queue = InMemoryJobQueue::new();
        let policy = RetryPolicy {
            max_retries: 2,
            backoff: BackoffStrategy::None,
        };
        queue
            .enqueue(make_job("j1", "echo", &[], policy))
            .await
            .unwrap();
        let (_, lease) = queue.claim("w1", &[], LEASE).await.unwrap().unwrap();

        // A retryable failure with budget left puts the job back in the queue.
        queue
            .fail(&lease.lease_id, "oops".into(), true)
            .await
            .unwrap();

        let (retried, _) = queue.claim("w1", &[], LEASE).await.unwrap().unwrap();
        assert_eq!(retried.attempt, 1);
    }

    #[tokio::test]
    async fn test_cancel() {
        let queue = InMemoryJobQueue::new();
        queue
            .enqueue(make_job("j1", "echo", &[], RetryPolicy::default()))
            .await
            .unwrap();
        queue.cancel("wf1", "j1").await.unwrap();

        // The cancelled job is gone from the pending queue...
        let claimed = queue.claim("w1", &[], LEASE).await.unwrap();
        assert!(claimed.is_none());

        // ...and remembered as cancelled.
        assert!(queue.is_cancelled("wf1", "j1").await.unwrap());
    }

    #[tokio::test]
    async fn test_reap_expired_leases() {
        let queue = InMemoryJobQueue::new();
        let policy = RetryPolicy {
            max_retries: 1,
            backoff: BackoffStrategy::None,
        };
        queue
            .enqueue(make_job("j1", "echo", &[], policy))
            .await
            .unwrap();

        // A zero-second TTL makes the lease expire as soon as time moves on.
        let (_, _lease) = queue
            .claim("w1", &[], Duration::from_secs(0))
            .await
            .unwrap()
            .unwrap();
        tokio::time::sleep(Duration::from_millis(10)).await;

        let events = queue.reap_expired_leases().await.unwrap();
        assert_eq!(events.len(), 1);
        assert!(matches!(&events[0], JobEvent::LeaseExpired { job_id, .. } if job_id == "j1"));

        // One retry is allowed, so the job is claimable again as attempt 1.
        let (retried, _) = queue.claim("w2", &[], LEASE).await.unwrap().unwrap();
        assert_eq!(retried.attempt, 1);
    }
}