// workflow_graph_queue/scheduler.rs

use std::collections::{HashMap, HashSet};
use std::sync::Arc;

use tokio::sync::RwLock;

use workflow_graph_shared::{JobStatus, Workflow};

use crate::error::SchedulerError;
use crate::traits::*;
10
/// Shared workflow state, readable by the frontend polling API.
pub type SharedState = Arc<RwLock<WorkflowState>>;

/// In-memory workflow state that the frontend reads via polling.
pub struct WorkflowState {
    /// All known workflows, keyed by workflow id.
    pub workflows: HashMap<String, Workflow>,
}
18
19impl WorkflowState {
20    pub fn new() -> Self {
21        Self {
22            workflows: HashMap::new(),
23        }
24    }
25}
26
27impl Default for WorkflowState {
28    fn default() -> Self {
29        Self::new()
30    }
31}
32
/// Event-driven DAG scheduler.
///
/// Listens for `JobEvent`s from the queue and enqueues downstream jobs
/// when their dependencies are satisfied. Replaces the inline orchestrator.
pub struct DagScheduler<Q: JobQueue, A: ArtifactStore> {
    /// Job queue that workers claim from; also the source of `JobEvent`s.
    queue: Arc<Q>,
    /// Store for job outputs, read back as upstream inputs for downstream jobs.
    artifacts: Arc<A>,
    /// Shared workflow state updated on every event; the frontend polls it.
    state: SharedState,
}
42
43impl<Q: JobQueue, A: ArtifactStore> DagScheduler<Q, A> {
44    pub fn new(queue: Arc<Q>, artifacts: Arc<A>, state: SharedState) -> Self {
45        Self {
46            queue,
47            artifacts,
48            state,
49        }
50    }
51
52    /// Initiate a workflow run: reset all jobs to Queued, then enqueue root jobs.
53    pub async fn start_workflow(&self, workflow_id: &str) -> Result<(), SchedulerError> {
54        let root_jobs = {
55            let mut state = self.state.write().await;
56            let wf = state
57                .workflows
58                .get_mut(workflow_id)
59                .ok_or_else(|| SchedulerError::WorkflowNotFound(workflow_id.to_string()))?;
60
61            // Reset all jobs
62            for job in &mut wf.jobs {
63                job.status = JobStatus::Queued;
64                job.duration_secs = None;
65                job.started_at = None;
66                job.output = None;
67            }
68
69            // Find root jobs (no dependencies)
70            wf.jobs
71                .iter()
72                .filter(|j| j.depends_on.is_empty())
73                .map(|j| (j.id.clone(), j.command.clone()))
74                .collect::<Vec<_>>()
75        };
76
77        // Enqueue root jobs
78        for (job_id, command) in root_jobs {
79            let queued = QueuedJob {
80                job_id,
81                workflow_id: workflow_id.to_string(),
82                command,
83                required_labels: vec![],
84                retry_policy: RetryPolicy::default(),
85                attempt: 0,
86                upstream_outputs: HashMap::new(),
87                enqueued_at_ms: now_ms(),
88                delayed_until_ms: 0,
89            };
90            self.queue.enqueue(queued).await?;
91        }
92
93        Ok(())
94    }
95
96    /// Cancel a running workflow: cancel all pending/active jobs.
97    pub async fn cancel_workflow(&self, workflow_id: &str) -> Result<(), SchedulerError> {
98        self.queue.cancel_workflow(workflow_id).await?;
99
100        let mut state = self.state.write().await;
101        if let Some(wf) = state.workflows.get_mut(workflow_id) {
102            for job in &mut wf.jobs {
103                if job.status == JobStatus::Queued || job.status == JobStatus::Running {
104                    job.status = JobStatus::Cancelled;
105                }
106            }
107        }
108        Ok(())
109    }
110
111    /// Run the scheduler event loop. Listens for queue events and drives the DAG.
112    /// This should be spawned as a background task.
113    pub async fn run(self: Arc<Self>) {
114        let mut rx = self.queue.subscribe();
115
116        loop {
117            let event = match rx.recv().await {
118                Ok(event) => event,
119                Err(tokio::sync::broadcast::error::RecvError::Lagged(n)) => {
120                    eprintln!("Scheduler lagged by {n} events, some jobs may need manual recovery");
121                    continue;
122                }
123                Err(tokio::sync::broadcast::error::RecvError::Closed) => {
124                    eprintln!("Queue event channel closed, scheduler shutting down");
125                    break;
126                }
127            };
128
129            if let Err(e) = self.handle_event(event).await {
130                eprintln!("Scheduler error: {e}");
131            }
132        }
133    }
134
135    async fn handle_event(&self, event: JobEvent) -> Result<(), SchedulerError> {
136        match event {
137            JobEvent::Started {
138                workflow_id,
139                job_id,
140                ..
141            } => {
142                self.on_job_started(&workflow_id, &job_id).await;
143            }
144            JobEvent::Completed {
145                workflow_id,
146                job_id,
147                outputs,
148            } => {
149                self.on_job_completed(&workflow_id, &job_id, outputs)
150                    .await?;
151            }
152            JobEvent::Failed {
153                workflow_id,
154                job_id,
155                error,
156                retryable,
157            } => {
158                self.on_job_failed(&workflow_id, &job_id, &error, retryable)
159                    .await;
160            }
161            JobEvent::LeaseExpired {
162                workflow_id,
163                job_id,
164                ..
165            } => {
166                self.on_lease_expired(&workflow_id, &job_id).await;
167            }
168            JobEvent::Cancelled {
169                workflow_id,
170                job_id,
171            } => {
172                self.on_job_cancelled(&workflow_id, &job_id).await;
173            }
174            JobEvent::Ready { .. } => {
175                // No action needed — job is in queue waiting for a worker
176            }
177        }
178        Ok(())
179    }
180
181    async fn on_job_started(&self, workflow_id: &str, job_id: &str) {
182        let mut state = self.state.write().await;
183        if let Some(wf) = state.workflows.get_mut(workflow_id)
184            && let Some(job) = wf.jobs.iter_mut().find(|j| j.id == job_id)
185        {
186            job.status = JobStatus::Running;
187            job.started_at = Some(now_ms() as f64);
188        }
189    }
190
191    async fn on_job_completed(
192        &self,
193        workflow_id: &str,
194        job_id: &str,
195        outputs: HashMap<String, String>,
196    ) -> Result<(), SchedulerError> {
197        // Store outputs
198        self.artifacts
199            .put_outputs(workflow_id, job_id, outputs)
200            .await?;
201
202        // Update state
203        let ready_jobs = {
204            let mut state = self.state.write().await;
205            let wf = match state.workflows.get_mut(workflow_id) {
206                Some(wf) => wf,
207                None => return Ok(()),
208            };
209
210            // Mark this job as success
211            if let Some(job) = wf.jobs.iter_mut().find(|j| j.id == job_id) {
212                job.status = JobStatus::Success;
213                if let Some(started) = job.started_at {
214                    job.duration_secs =
215                        Some(((now_ms() as f64 - started) / 1000.0).max(0.0) as u64);
216                }
217            }
218
219            // Find downstream jobs whose deps are ALL succeeded
220            let ready: Vec<(String, String, Vec<String>)> = wf
221                .jobs
222                .iter()
223                .filter(|j| {
224                    j.status == JobStatus::Queued
225                        && j.depends_on.contains(&job_id.to_string())
226                        && j.depends_on.iter().all(|dep| {
227                            wf.jobs
228                                .iter()
229                                .find(|dj| dj.id == *dep)
230                                .is_some_and(|dj| dj.status == JobStatus::Success)
231                        })
232                })
233                .map(|j| (j.id.clone(), j.command.clone(), j.depends_on.clone()))
234                .collect();
235
236            ready
237        };
238
239        // Enqueue ready downstream jobs with upstream outputs
240        for (next_id, command, deps) in ready_jobs {
241            let upstream_outputs = self
242                .artifacts
243                .get_upstream_outputs(workflow_id, &deps)
244                .await?;
245
246            let queued = QueuedJob {
247                job_id: next_id,
248                workflow_id: workflow_id.to_string(),
249                command,
250                required_labels: vec![],
251                retry_policy: RetryPolicy::default(),
252                attempt: 0,
253                upstream_outputs,
254                enqueued_at_ms: now_ms(),
255                delayed_until_ms: 0,
256            };
257            self.queue.enqueue(queued).await?;
258        }
259
260        Ok(())
261    }
262
263    async fn on_job_failed(&self, workflow_id: &str, job_id: &str, error: &str, retryable: bool) {
264        let mut state = self.state.write().await;
265        let Some(wf) = state.workflows.get_mut(workflow_id) else {
266            return;
267        };
268
269        if retryable {
270            // Job was re-enqueued by the queue — mark as queued again
271            if let Some(job) = wf.jobs.iter_mut().find(|j| j.id == job_id) {
272                job.status = JobStatus::Queued;
273                job.started_at = None;
274            }
275        } else {
276            // Permanent failure — mark job and skip all downstream
277            if let Some(job) = wf.jobs.iter_mut().find(|j| j.id == job_id) {
278                job.status = JobStatus::Failure;
279                job.output = Some(error.to_string());
280                if let Some(started) = job.started_at {
281                    job.duration_secs =
282                        Some(((now_ms() as f64 - started) / 1000.0).max(0.0) as u64);
283                }
284            }
285
286            // Skip transitive downstream
287            let skip_ids = find_transitive_downstream(wf, job_id);
288            for skip_id in &skip_ids {
289                if let Some(j) = wf.jobs.iter_mut().find(|j| j.id == *skip_id) {
290                    j.status = JobStatus::Skipped;
291                }
292            }
293        }
294    }
295
296    async fn on_lease_expired(&self, workflow_id: &str, job_id: &str) {
297        // The queue already handled re-enqueueing if retries remain.
298        // We just need to update the state if the job was permanently failed.
299        let mut state = self.state.write().await;
300        if let Some(wf) = state.workflows.get_mut(workflow_id)
301            && let Some(job) = wf.jobs.iter_mut().find(|j| j.id == job_id)
302        {
303            // If the queue re-enqueued it, mark as queued; otherwise mark as failure
304            // We can't easily know here, so mark as Queued — the next Started event
305            // will update it correctly.
306            job.status = JobStatus::Queued;
307            job.started_at = None;
308        }
309    }
310
311    async fn on_job_cancelled(&self, workflow_id: &str, job_id: &str) {
312        let mut state = self.state.write().await;
313        if let Some(wf) = state.workflows.get_mut(workflow_id)
314            && let Some(job) = wf.jobs.iter_mut().find(|j| j.id == job_id)
315        {
316            job.status = JobStatus::Cancelled;
317        }
318    }
319}
320
321/// Find all jobs transitively downstream of the given job.
322fn find_transitive_downstream(wf: &Workflow, job_id: &str) -> Vec<String> {
323    let mut result = Vec::new();
324    let mut stack = vec![job_id.to_string()];
325
326    while let Some(current) = stack.pop() {
327        for job in &wf.jobs {
328            if job.depends_on.contains(&current) && !result.contains(&job.id) {
329                result.push(job.id.clone());
330                stack.push(job.id.clone());
331            }
332        }
333    }
334
335    result
336}
337
/// Current wall-clock time in milliseconds since the Unix epoch.
///
/// If the system clock reads earlier than the epoch, returns 0 instead of
/// panicking (same fallback the original `unwrap_or_default` provided).
fn now_ms() -> u64 {
    match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_millis() as u64,
        Err(_) => 0,
    }
}
344
#[cfg(test)]
mod tests {
    use super::*;
    use crate::memory::{InMemoryArtifactStore, InMemoryJobQueue};

    /// Diamond-less 3-job fixture: root `a`, with `b` and `c` both depending
    /// on `a`. All jobs start `Queued` with no labels and no retries.
    fn sample_workflow() -> Workflow {
        Workflow {
            id: "wf1".into(),
            name: "test".into(),
            trigger: "manual".into(),
            jobs: vec![
                workflow_graph_shared::Job {
                    id: "a".into(),
                    name: "Job A".into(),
                    status: JobStatus::Queued,
                    command: "echo a".into(),
                    duration_secs: None,
                    started_at: None,
                    required_labels: vec![],
                    max_retries: 0,
                    attempt: 0,
                    depends_on: vec![],
                    output: None,
                },
                workflow_graph_shared::Job {
                    id: "b".into(),
                    name: "Job B".into(),
                    status: JobStatus::Queued,
                    command: "echo b".into(),
                    duration_secs: None,
                    started_at: None,
                    required_labels: vec![],
                    max_retries: 0,
                    attempt: 0,
                    depends_on: vec!["a".into()],
                    output: None,
                },
                workflow_graph_shared::Job {
                    id: "c".into(),
                    name: "Job C".into(),
                    status: JobStatus::Queued,
                    command: "echo c".into(),
                    duration_secs: None,
                    started_at: None,
                    required_labels: vec![],
                    max_retries: 0,
                    attempt: 0,
                    depends_on: vec!["a".into()],
                    output: None,
                },
            ],
        }
    }

    /// Build an in-memory scheduler with the sample workflow pre-registered
    /// under id "wf1". Returns the scheduler, the queue, and the shared state.
    async fn setup() -> (
        Arc<DagScheduler<InMemoryJobQueue, InMemoryArtifactStore>>,
        Arc<InMemoryJobQueue>,
        SharedState,
    ) {
        let queue = Arc::new(InMemoryJobQueue::new());
        let artifacts = Arc::new(InMemoryArtifactStore::new());
        let state = Arc::new(RwLock::new(WorkflowState::new()));

        state
            .write()
            .await
            .workflows
            .insert("wf1".into(), sample_workflow());

        let scheduler = Arc::new(DagScheduler::new(
            queue.clone(),
            artifacts.clone(),
            state.clone(),
        ));

        (scheduler, queue, state)
    }

    /// Starting a workflow must enqueue exactly the dependency-free roots.
    #[tokio::test]
    async fn test_start_workflow_enqueues_roots() {
        let (scheduler, queue, _state) = setup().await;

        scheduler.start_workflow("wf1").await.unwrap();

        // Only job "a" (root) should be enqueued
        let (job, _lease) = queue
            .claim("w1", &[], std::time::Duration::from_secs(30))
            .await
            .unwrap()
            .unwrap();
        assert_eq!(job.job_id, "a");

        // No more jobs available — "b" and "c" wait on "a"
        assert!(
            queue
                .claim("w1", &[], std::time::Duration::from_secs(30))
                .await
                .unwrap()
                .is_none()
        );
    }

    /// Completing the root must make both direct children claimable and mark
    /// the root `Success` in shared state.
    #[tokio::test]
    async fn test_completed_enqueues_downstream() {
        let (scheduler, queue, state) = setup().await;

        scheduler.start_workflow("wf1").await.unwrap();

        // Claim and complete job A
        let (_, lease) = queue
            .claim("w1", &[], std::time::Duration::from_secs(30))
            .await
            .unwrap()
            .unwrap();

        // Process the Started event
        scheduler
            .handle_event(JobEvent::Started {
                workflow_id: "wf1".into(),
                job_id: "a".into(),
                worker_id: "w1".into(),
            })
            .await
            .unwrap();

        // Complete job A
        queue
            .complete(&lease.lease_id, HashMap::new())
            .await
            .unwrap();

        // Process the Completed event — should enqueue B and C
        scheduler
            .handle_event(JobEvent::Completed {
                workflow_id: "wf1".into(),
                job_id: "a".into(),
                outputs: HashMap::new(),
            })
            .await
            .unwrap();

        // Both B and C should now be claimable
        let (job1, _) = queue
            .claim("w1", &[], std::time::Duration::from_secs(30))
            .await
            .unwrap()
            .unwrap();
        let (job2, _) = queue
            .claim("w1", &[], std::time::Duration::from_secs(30))
            .await
            .unwrap()
            .unwrap();

        // Claim order is unspecified, so compare the sorted id pair
        let mut ids = vec![job1.job_id, job2.job_id];
        ids.sort();
        assert_eq!(ids, vec!["b", "c"]);

        // Check state
        let s = state.read().await;
        let wf = &s.workflows["wf1"];
        assert_eq!(
            wf.jobs.iter().find(|j| j.id == "a").unwrap().status,
            JobStatus::Success
        );
    }

    /// A permanent failure of the root must mark it `Failure` and mark every
    /// transitively-downstream job `Skipped`.
    #[tokio::test]
    async fn test_failure_skips_downstream() {
        let (scheduler, _queue, state) = setup().await;

        scheduler.start_workflow("wf1").await.unwrap();

        // Simulate job A failing permanently
        scheduler
            .handle_event(JobEvent::Failed {
                workflow_id: "wf1".into(),
                job_id: "a".into(),
                error: "boom".into(),
                retryable: false,
            })
            .await
            .unwrap();

        let s = state.read().await;
        let wf = &s.workflows["wf1"];
        assert_eq!(
            wf.jobs.iter().find(|j| j.id == "a").unwrap().status,
            JobStatus::Failure
        );
        assert_eq!(
            wf.jobs.iter().find(|j| j.id == "b").unwrap().status,
            JobStatus::Skipped
        );
        assert_eq!(
            wf.jobs.iter().find(|j| j.id == "c").unwrap().status,
            JobStatus::Skipped
        );
    }

    /// Cancelling a started workflow must leave every job `Cancelled`.
    #[tokio::test]
    async fn test_cancel_workflow() {
        let (scheduler, _queue, state) = setup().await;

        scheduler.start_workflow("wf1").await.unwrap();
        scheduler.cancel_workflow("wf1").await.unwrap();

        let s = state.read().await;
        let wf = &s.workflows["wf1"];
        for job in &wf.jobs {
            assert!(
                job.status == JobStatus::Cancelled,
                "job {} should be cancelled, got {:?}",
                job.id,
                job.status
            );
        }
    }
}