//! workflow_graph_queue/scheduler.rs — event-driven DAG scheduler and the
//! shared in-memory workflow state it maintains for the polling frontend.

use std::collections::{HashMap, HashSet};
use std::sync::Arc;

use tokio::sync::RwLock;

use workflow_graph_shared::{JobStatus, Workflow};

use crate::error::SchedulerError;
use crate::traits::*;
10
/// Shared workflow state, readable by the frontend polling API.
///
/// Wrapped in `Arc<RwLock<…>>` so the scheduler (writer) and the polling
/// endpoint (readers) can share it across tasks.
pub type SharedState = Arc<RwLock<WorkflowState>>;

/// In-memory workflow state that the frontend reads via polling.
pub struct WorkflowState {
    /// All known workflows, keyed by workflow id. Job statuses inside each
    /// `Workflow` are mutated in place as queue events arrive.
    pub workflows: HashMap<String, Workflow>,
}
18
19impl WorkflowState {
20    pub fn new() -> Self {
21        Self {
22            workflows: HashMap::new(),
23        }
24    }
25}
26
27impl Default for WorkflowState {
28    fn default() -> Self {
29        Self::new()
30    }
31}
32
/// Event-driven DAG scheduler.
///
/// Listens for `JobEvent`s from the queue and enqueues downstream jobs
/// when their dependencies are satisfied. Replaces the inline orchestrator.
pub struct DagScheduler<Q: JobQueue, A: ArtifactStore> {
    /// Queue used both to enqueue ready jobs and to subscribe to job events.
    queue: Arc<Q>,
    /// Store for completed-job outputs, read back as upstream inputs when
    /// enqueueing downstream jobs.
    artifacts: Arc<A>,
    /// Shared in-memory state the frontend polls; updated on every event.
    state: SharedState,
}
42
impl<Q: JobQueue, A: ArtifactStore> DagScheduler<Q, A> {
    /// Construct a scheduler over the given queue, artifact store, and
    /// shared workflow state.
    pub fn new(queue: Arc<Q>, artifacts: Arc<A>, state: SharedState) -> Self {
        Self {
            queue,
            artifacts,
            state,
        }
    }

    /// Initiate a workflow run: reset all jobs to Queued, then enqueue root jobs.
    ///
    /// # Errors
    /// Returns `SchedulerError::WorkflowNotFound` if `workflow_id` is not
    /// present in the shared state; propagates queue errors from `enqueue`.
    pub async fn start_workflow(&self, workflow_id: &str) -> Result<(), SchedulerError> {
        // Scope the write lock so it is released before the enqueue awaits
        // below — we only clone out the (id, command) pairs we need.
        let root_jobs = {
            let mut state = self.state.write().await;
            let wf = state
                .workflows
                .get_mut(workflow_id)
                .ok_or_else(|| SchedulerError::WorkflowNotFound(workflow_id.to_string()))?;

            // Reset all jobs so a re-run starts from a clean slate.
            for job in &mut wf.jobs {
                job.status = JobStatus::Queued;
                job.duration_secs = None;
                job.started_at = None;
                job.output = None;
            }

            // Find root jobs (no dependencies)
            wf.jobs
                .iter()
                .filter(|j| j.depends_on.is_empty())
                .map(|j| (j.id.clone(), j.command.clone()))
                .collect::<Vec<_>>()
        };

        // Enqueue root jobs
        for (job_id, command) in root_jobs {
            // NOTE(review): the Job struct carries `required_labels` and
            // `max_retries`, but neither is propagated into the QueuedJob here
            // (labels empty, default retry policy) — confirm this is intended.
            let queued = QueuedJob {
                job_id,
                workflow_id: workflow_id.to_string(),
                command,
                required_labels: vec![],
                retry_policy: RetryPolicy::default(),
                attempt: 0,
                upstream_outputs: HashMap::new(),
                enqueued_at_ms: now_ms(),
                delayed_until_ms: 0, // 0 = eligible for claiming immediately
            };
            self.queue.enqueue(queued).await?;
        }

        Ok(())
    }

    /// Cancel a running workflow: cancel all pending/active jobs.
    ///
    /// The queue is cancelled first so no new claims can happen, then the
    /// shared state is updated so the frontend sees the cancellation.
    pub async fn cancel_workflow(&self, workflow_id: &str) -> Result<(), SchedulerError> {
        self.queue.cancel_workflow(workflow_id).await?;

        let mut state = self.state.write().await;
        if let Some(wf) = state.workflows.get_mut(workflow_id) {
            for job in &mut wf.jobs {
                // Only in-flight jobs flip to Cancelled; terminal states
                // (Success/Failure/Skipped) are left untouched.
                if job.status == JobStatus::Queued || job.status == JobStatus::Running {
                    job.status = JobStatus::Cancelled;
                }
            }
        }
        Ok(())
    }

    /// Run the scheduler event loop. Listens for queue events and drives the DAG.
    /// This should be spawned as a background task.
    pub async fn run(self: Arc<Self>) {
        let mut rx = self.queue.subscribe();

        loop {
            let event = match rx.recv().await {
                Ok(event) => event,
                // A broadcast channel drops the oldest events when this
                // receiver falls behind; a dropped Completed event means its
                // downstream jobs are never enqueued — hence the warning.
                Err(tokio::sync::broadcast::error::RecvError::Lagged(n)) => {
                    eprintln!("Scheduler lagged by {n} events, some jobs may need manual recovery");
                    continue;
                }
                Err(tokio::sync::broadcast::error::RecvError::Closed) => {
                    eprintln!("Queue event channel closed, scheduler shutting down");
                    break;
                }
            };

            // Handler errors are logged but never abort the loop: one bad
            // event must not take the scheduler down.
            if let Err(e) = self.handle_event(event).await {
                eprintln!("Scheduler error: {e}");
            }
        }
    }

    /// Dispatch a single queue event to its handler. Only `Completed` can
    /// fail (artifact store / enqueue I/O); the rest only mutate state.
    async fn handle_event(&self, event: JobEvent) -> Result<(), SchedulerError> {
        match event {
            JobEvent::Started {
                workflow_id,
                job_id,
                ..
            } => {
                self.on_job_started(&workflow_id, &job_id).await;
            }
            JobEvent::Completed {
                workflow_id,
                job_id,
                outputs,
            } => {
                self.on_job_completed(&workflow_id, &job_id, outputs)
                    .await?;
            }
            JobEvent::Failed {
                workflow_id,
                job_id,
                error,
                retryable,
            } => {
                self.on_job_failed(&workflow_id, &job_id, &error, retryable)
                    .await;
            }
            JobEvent::LeaseExpired {
                workflow_id,
                job_id,
                ..
            } => {
                self.on_lease_expired(&workflow_id, &job_id).await;
            }
            JobEvent::Cancelled {
                workflow_id,
                job_id,
            } => {
                self.on_job_cancelled(&workflow_id, &job_id).await;
            }
            JobEvent::Ready { .. } => {
                // No action needed — job is in queue waiting for a worker
            }
        }
        Ok(())
    }

    /// Mark a job Running and stamp its start time (wall clock, ms as f64).
    /// Unknown workflow/job ids are silently ignored.
    async fn on_job_started(&self, workflow_id: &str, job_id: &str) {
        let mut state = self.state.write().await;
        if let Some(wf) = state.workflows.get_mut(workflow_id)
            && let Some(job) = wf.jobs.iter_mut().find(|j| j.id == job_id)
        {
            job.status = JobStatus::Running;
            job.started_at = Some(now_ms() as f64);
        }
    }

    /// Handle a successful job: persist its outputs, mark it Success, and
    /// enqueue every downstream job whose dependencies are now all satisfied.
    async fn on_job_completed(
        &self,
        workflow_id: &str,
        job_id: &str,
        outputs: HashMap<String, String>,
    ) -> Result<(), SchedulerError> {
        // Store outputs first so they are readable by the time any
        // downstream job is enqueued below.
        self.artifacts
            .put_outputs(workflow_id, job_id, outputs)
            .await?;

        // Update state. The write lock is scoped to this block so it is
        // dropped before the artifact/queue awaits below.
        let ready_jobs = {
            let mut state = self.state.write().await;
            let wf = match state.workflows.get_mut(workflow_id) {
                Some(wf) => wf,
                None => return Ok(()),
            };

            // Mark this job as success
            if let Some(job) = wf.jobs.iter_mut().find(|j| j.id == job_id) {
                job.status = JobStatus::Success;
                if let Some(started) = job.started_at {
                    job.duration_secs =
                        Some(((now_ms() as f64 - started) / 1000.0).max(0.0) as u64);
                }
            }

            // Find downstream jobs whose deps are ALL succeeded.
            // O(jobs * deps) per completion — acceptable for small DAGs.
            let ready: Vec<(String, String, Vec<String>)> = wf
                .jobs
                .iter()
                .filter(|j| {
                    j.status == JobStatus::Queued
                        && j.depends_on.contains(&job_id.to_string())
                        && j.depends_on.iter().all(|dep| {
                            wf.jobs
                                .iter()
                                .find(|dj| dj.id == *dep)
                                .is_some_and(|dj| dj.status == JobStatus::Success)
                        })
                })
                .map(|j| (j.id.clone(), j.command.clone(), j.depends_on.clone()))
                .collect();

            ready
        };

        // Enqueue ready downstream jobs with upstream outputs
        for (next_id, command, deps) in ready_jobs {
            let upstream_outputs = self
                .artifacts
                .get_upstream_outputs(workflow_id, &deps)
                .await?;

            // NOTE(review): as in start_workflow, `required_labels` and the
            // per-job retry settings are not propagated here — confirm.
            let queued = QueuedJob {
                job_id: next_id,
                workflow_id: workflow_id.to_string(),
                command,
                required_labels: vec![],
                retry_policy: RetryPolicy::default(),
                attempt: 0,
                upstream_outputs,
                enqueued_at_ms: now_ms(),
                delayed_until_ms: 0,
            };
            self.queue.enqueue(queued).await?;
        }

        Ok(())
    }

    /// Handle a failure event. Retryable failures flip the job back to
    /// Queued (the queue re-enqueues it); permanent failures record the
    /// error, compute the duration, and Skip everything transitively
    /// downstream of the failed job.
    async fn on_job_failed(&self, workflow_id: &str, job_id: &str, error: &str, retryable: bool) {
        let mut state = self.state.write().await;
        let Some(wf) = state.workflows.get_mut(workflow_id) else {
            return;
        };

        if retryable {
            // Job was re-enqueued by the queue — mark as queued again
            if let Some(job) = wf.jobs.iter_mut().find(|j| j.id == job_id) {
                job.status = JobStatus::Queued;
                job.started_at = None;
            }
        } else {
            // Permanent failure — mark job and skip all downstream
            if let Some(job) = wf.jobs.iter_mut().find(|j| j.id == job_id) {
                job.status = JobStatus::Failure;
                job.output = Some(error.to_string());
                if let Some(started) = job.started_at {
                    job.duration_secs =
                        Some(((now_ms() as f64 - started) / 1000.0).max(0.0) as u64);
                }
            }

            // Skip transitive downstream
            let skip_ids = find_transitive_downstream(wf, job_id);
            for skip_id in &skip_ids {
                if let Some(j) = wf.jobs.iter_mut().find(|j| j.id == *skip_id) {
                    j.status = JobStatus::Skipped;
                }
            }
        }
    }

    /// Handle an expired lease: reflect the queue's recovery in shared state.
    async fn on_lease_expired(&self, workflow_id: &str, job_id: &str) {
        // The queue already handled re-enqueueing if retries remain.
        // We just need to update the state if the job was permanently failed.
        let mut state = self.state.write().await;
        if let Some(wf) = state.workflows.get_mut(workflow_id)
            && let Some(job) = wf.jobs.iter_mut().find(|j| j.id == job_id)
        {
            // If the queue re-enqueued it, mark as queued; otherwise mark as failure
            // We can't easily know here, so mark as Queued — the next Started event
            // will update it correctly.
            // NOTE(review): if retries were exhausted this leaves the job
            // showing Queued forever unless a Failed event also arrives —
            // confirm the queue emits one in that case.
            job.status = JobStatus::Queued;
            job.started_at = None;
        }
    }

    /// Mark a single cancelled job in shared state (workflow-wide
    /// cancellation is handled by `cancel_workflow`).
    async fn on_job_cancelled(&self, workflow_id: &str, job_id: &str) {
        let mut state = self.state.write().await;
        if let Some(wf) = state.workflows.get_mut(workflow_id)
            && let Some(job) = wf.jobs.iter_mut().find(|j| j.id == job_id)
        {
            job.status = JobStatus::Cancelled;
        }
    }
}
320
321/// Find all jobs transitively downstream of the given job.
322fn find_transitive_downstream(wf: &Workflow, job_id: &str) -> Vec<String> {
323    let mut result = Vec::new();
324    let mut stack = vec![job_id.to_string()];
325
326    while let Some(current) = stack.pop() {
327        for job in &wf.jobs {
328            if job.depends_on.contains(&current) && !result.contains(&job.id) {
329                result.push(job.id.clone());
330                stack.push(job.id.clone());
331            }
332        }
333    }
334
335    result
336}
337
/// Current wall-clock time in milliseconds since the Unix epoch.
///
/// Returns 0 if the system clock reports a time before the epoch (the same
/// fallback the previous `unwrap_or_default` produced).
fn now_ms() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map(|d| d.as_millis() as u64)
        .unwrap_or(0)
}
344
#[cfg(test)]
mod tests {
    use super::*;
    use crate::memory::{InMemoryArtifactStore, InMemoryJobQueue};

    /// A three-job diamond-less DAG: `a` is the root, `b` and `c` both
    /// depend on `a`.
    fn sample_workflow() -> Workflow {
        Workflow {
            id: "wf1".into(),
            name: "test".into(),
            trigger: "manual".into(),
            jobs: vec![
                workflow_graph_shared::Job {
                    id: "a".into(),
                    name: "Job A".into(),
                    status: JobStatus::Queued,
                    command: "echo a".into(),
                    duration_secs: None,
                    started_at: None,
                    required_labels: vec![],
                    max_retries: 0,
                    attempt: 0,
                    depends_on: vec![],
                    output: None,
                    metadata: std::collections::HashMap::new(),
                    ports: vec![],
                },
                workflow_graph_shared::Job {
                    id: "b".into(),
                    name: "Job B".into(),
                    status: JobStatus::Queued,
                    command: "echo b".into(),
                    duration_secs: None,
                    started_at: None,
                    required_labels: vec![],
                    max_retries: 0,
                    attempt: 0,
                    depends_on: vec!["a".into()],
                    output: None,
                    metadata: std::collections::HashMap::new(),
                    ports: vec![],
                },
                workflow_graph_shared::Job {
                    id: "c".into(),
                    name: "Job C".into(),
                    status: JobStatus::Queued,
                    command: "echo c".into(),
                    duration_secs: None,
                    started_at: None,
                    required_labels: vec![],
                    max_retries: 0,
                    attempt: 0,
                    depends_on: vec!["a".into()],
                    output: None,
                    metadata: std::collections::HashMap::new(),
                    ports: vec![],
                },
            ],
        }
    }

    /// Build a scheduler over in-memory queue/artifact implementations with
    /// the sample workflow pre-registered in shared state.
    async fn setup() -> (
        Arc<DagScheduler<InMemoryJobQueue, InMemoryArtifactStore>>,
        Arc<InMemoryJobQueue>,
        SharedState,
    ) {
        let queue = Arc::new(InMemoryJobQueue::new());
        let artifacts = Arc::new(InMemoryArtifactStore::new());
        let state = Arc::new(RwLock::new(WorkflowState::new()));

        state
            .write()
            .await
            .workflows
            .insert("wf1".into(), sample_workflow());

        let scheduler = Arc::new(DagScheduler::new(
            queue.clone(),
            artifacts.clone(),
            state.clone(),
        ));

        (scheduler, queue, state)
    }

    #[tokio::test]
    async fn test_start_workflow_enqueues_roots() {
        let (scheduler, queue, _state) = setup().await;

        scheduler.start_workflow("wf1").await.unwrap();

        // Only job "a" (root) should be enqueued
        let (job, _lease) = queue
            .claim("w1", &[], std::time::Duration::from_secs(30))
            .await
            .unwrap()
            .unwrap();
        assert_eq!(job.job_id, "a");

        // No more jobs available
        assert!(
            queue
                .claim("w1", &[], std::time::Duration::from_secs(30))
                .await
                .unwrap()
                .is_none()
        );
    }

    /// Drives a → (b, c) by hand: events are delivered via `handle_event`
    /// directly rather than through the `run` loop, so the test stays
    /// deterministic.
    #[tokio::test]
    async fn test_completed_enqueues_downstream() {
        let (scheduler, queue, state) = setup().await;

        scheduler.start_workflow("wf1").await.unwrap();

        // Claim and complete job A
        let (_, lease) = queue
            .claim("w1", &[], std::time::Duration::from_secs(30))
            .await
            .unwrap()
            .unwrap();

        // Process the Started event
        scheduler
            .handle_event(JobEvent::Started {
                workflow_id: "wf1".into(),
                job_id: "a".into(),
                worker_id: "w1".into(),
            })
            .await
            .unwrap();

        // Complete job A
        queue
            .complete(&lease.lease_id, HashMap::new())
            .await
            .unwrap();

        // Process the Completed event — should enqueue B and C
        scheduler
            .handle_event(JobEvent::Completed {
                workflow_id: "wf1".into(),
                job_id: "a".into(),
                outputs: HashMap::new(),
            })
            .await
            .unwrap();

        // Both B and C should now be claimable
        let (job1, _) = queue
            .claim("w1", &[], std::time::Duration::from_secs(30))
            .await
            .unwrap()
            .unwrap();
        let (job2, _) = queue
            .claim("w1", &[], std::time::Duration::from_secs(30))
            .await
            .unwrap()
            .unwrap();

        // Claim order is unspecified, so sort before comparing.
        let mut ids = vec![job1.job_id, job2.job_id];
        ids.sort();
        assert_eq!(ids, vec!["b", "c"]);

        // Check state
        let s = state.read().await;
        let wf = &s.workflows["wf1"];
        assert_eq!(
            wf.jobs.iter().find(|j| j.id == "a").unwrap().status,
            JobStatus::Success
        );
    }

    /// A permanent failure of the root must fail `a` and skip both
    /// downstream jobs.
    #[tokio::test]
    async fn test_failure_skips_downstream() {
        let (scheduler, _queue, state) = setup().await;

        scheduler.start_workflow("wf1").await.unwrap();

        // Simulate job A failing permanently
        scheduler
            .handle_event(JobEvent::Failed {
                workflow_id: "wf1".into(),
                job_id: "a".into(),
                error: "boom".into(),
                retryable: false,
            })
            .await
            .unwrap();

        let s = state.read().await;
        let wf = &s.workflows["wf1"];
        assert_eq!(
            wf.jobs.iter().find(|j| j.id == "a").unwrap().status,
            JobStatus::Failure
        );
        assert_eq!(
            wf.jobs.iter().find(|j| j.id == "b").unwrap().status,
            JobStatus::Skipped
        );
        assert_eq!(
            wf.jobs.iter().find(|j| j.id == "c").unwrap().status,
            JobStatus::Skipped
        );
    }

    /// Cancelling right after start must flip every (still Queued) job to
    /// Cancelled.
    #[tokio::test]
    async fn test_cancel_workflow() {
        let (scheduler, _queue, state) = setup().await;

        scheduler.start_workflow("wf1").await.unwrap();
        scheduler.cancel_workflow("wf1").await.unwrap();

        let s = state.read().await;
        let wf = &s.workflows["wf1"];
        for job in &wf.jobs {
            assert!(
                job.status == JobStatus::Cancelled,
                "job {} should be cancelled, got {:?}",
                job.id,
                job.status
            );
        }
    }
}